Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/circt/Dialect/Synth/Synth.td
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def Synth_Dialect : Dialect {
including meta operations for synthesis decisions, logic representations
like AIG and MIG, and synthesis pipeline infrastructure.
}];

let hasConstantMaterializer = 1;
}

include "circt/Dialect/Synth/SynthOps.td"
Expand Down
12 changes: 12 additions & 0 deletions include/circt/Dialect/Synth/SynthOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,24 @@ def MajorityInverterOp : SynthOp<"mig.maj_inv",
DenseBoolArrayAttr:$inverted);
let results = (outs AnyType:$result);
let hasVerifier = true;
let hasCanonicalizeMethod = true;
let hasFolder = true;

let assemblyFormat = [{
custom<VariadicInvertibleOperands>($inputs, type($result), $inverted,
attr-dict)
}];
let cppNamespace = "::circt::synth::mig";
let extraClassDeclaration = [{
// Evaluate the operation with the given input values.
APInt evaluate(ArrayRef<APInt> inputs);

// Check if the input is inverted.
bool isInverted(size_t idx) {
return getInverted()[idx];
}
}];

}

def AndInverterOp : SynthOp<"aig.and_inv", [SameOperandsAndResultType, Pure]> {
Expand Down
11 changes: 11 additions & 0 deletions lib/Dialect/Synth/SynthDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "circt/Dialect/Synth/SynthDialect.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/Synth/SynthOps.h"

using namespace circt;
Expand All @@ -19,4 +20,14 @@ void SynthDialect::initialize() {
>();
}

Operation *SynthDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
// Integer constants.
if (auto intType = dyn_cast<IntegerType>(type))
if (auto attrValue = dyn_cast<IntegerAttr>(value))
return hw::ConstantOp::create(builder, loc, type, attrValue);
return nullptr;
}

#include "circt/Dialect/Synth/SynthDialect.cpp.inc"
119 changes: 119 additions & 0 deletions lib/Dialect/Synth/SynthOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Support/CustomDirectiveImpl.h"
#include "circt/Support/Naming.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/APInt.h"
#include "llvm/Support/Casting.h"

using namespace mlir;
using namespace circt;
Expand All @@ -27,6 +32,120 @@ LogicalResult MajorityInverterOp::verify() {
return success();
}

llvm::APInt MajorityInverterOp::evaluate(ArrayRef<APInt> inputs) {
assert(inputs.size() == getNumOperands() &&
"Number of inputs must match number of operands");

if (inputs.size() == 3) {
auto a = (isInverted(0) ? ~inputs[0] : inputs[0]);
auto b = (isInverted(1) ? ~inputs[1] : inputs[1]);
auto c = (isInverted(2) ? ~inputs[2] : inputs[2]);
return (a & b) | (a & c) | (b & c);
}

// General case for odd number of inputs != 3
auto width = inputs[0].getBitWidth();
APInt result(width, 0);

for (size_t bit = 0; bit < width; ++bit) {
size_t count = 0;
for (size_t i = 0; i < inputs.size(); ++i) {
// Count the number of 1s, considering inversion.
if (isInverted(i) ^ inputs[i][bit])
count++;
}

if (count > inputs.size() / 2)
result.setBit(bit);
}

return result;
}

OpFoldResult MajorityInverterOp::fold(FoldAdaptor adaptor) {
// TODO: Implement maj(x, 1, 1) = 1, maj(x, 0, 0) = 0

SmallVector<APInt, 3> inputValues;
for (auto input : adaptor.getInputs()) {
auto attr = llvm::dyn_cast_or_null<IntegerAttr>(input);
if (!attr)
return {};
inputValues.push_back(attr.getValue());
}

auto result = evaluate(inputValues);
return IntegerAttr::get(getType(), result);
}

LogicalResult MajorityInverterOp::canonicalize(MajorityInverterOp op,
PatternRewriter &rewriter) {
if (op.getNumOperands() == 1) {
if (op.getInverted()[0])
return failure();
rewriter.replaceOp(op, op.getOperand(0));
return success();
}

// For now, only support 3 operands.
if (op.getNumOperands() != 3)
return failure();

// Return if the idx-th operand is a constant (inverted if necessary),
// otherwise return std::nullopt.
auto getConstant = [&](unsigned index) -> std::optional<llvm::APInt> {
APInt value;
if (mlir::matchPattern(op.getInputs()[index], mlir::m_ConstantInt(&value)))
return op.isInverted(index) ? ~value : value;
return std::nullopt;
};

// Replace the op with the idx-th operand (inverted if necessary).
auto replaceWithIndex = [&](int index) {
bool inverted = op.isInverted(index);
if (inverted)
rewriter.replaceOpWithNewOp<MajorityInverterOp>(
op, op.getType(), op.getOperand(index), true);
else
rewriter.replaceOp(op, op.getOperand(index));
return success();
};

// Pattern match following cases:
// maj_inv(x, x, y) -> x
// maj_inv(x, y, not y) -> x
for (int i = 0; i < 2; ++i) {
for (int j = i + 1; j < 3; ++j) {
int k = 3 - (i + j);
assert(k >= 0 && k < 3);
// If we have two identical operands, we can fold.
if (op.getOperand(i) == op.getOperand(j)) {
// If they are inverted differently, we can fold to the third.
if (op.isInverted(i) != op.isInverted(j)) {
return replaceWithIndex(k);
}
rewriter.replaceOp(op, op.getOperand(i));
return success();
}

// If i and j are constant.
if (auto c1 = getConstant(i)) {
if (auto c2 = getConstant(j)) {
// If both constants are equal, we can fold.
if (*c1 == *c2) {
rewriter.replaceOpWithNewOp<hw::ConstantOp>(
op, op.getType(), mlir::IntegerAttr::get(op.getType(), *c1));
return success();
}
// If constants are complementary, we can fold.
if (*c1 == ~*c2)
return replaceWithIndex(k);
}
}
}
}
return failure();
}

//===----------------------------------------------------------------------===//
// AIG Operations
//===----------------------------------------------------------------------===//
Expand Down
56 changes: 56 additions & 0 deletions test/Dialect/Synth/canonicalizer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,59 @@ hw.module @DoubleInversion(in %a: i1, in %b: i1, out o1: i1, out o2: i1, out o3:
%5 = synth.aig.and_inv not %4, not %a : i1
hw.output %1, %3, %5 : i1, i1, i1
}

// CHECK-LABEL: hw.module @maj_inv_basic
hw.module @maj_inv_basic(in %a : i1, in %b : i1, out o1 : i1, out o2 : i1, out o3 : i1, out o4 : i1) {
// Single operand, not inverted -> replace with operand
%0 = synth.mig.maj_inv %a : i1

// Three operands, two same -> replace with the operand
%1 = synth.mig.maj_inv %a, %a, %b : i1

// Three operands, two complementary -> replace with the third
%2 = synth.mig.maj_inv %a, %b, not %b : i1

// Two operands same but one inverted -> replace with the third
%3 = synth.mig.maj_inv %a, not %a, %b : i1

// CHECK: hw.output %a, %a, %a, %b : i1, i1, i1, i1
hw.output %0, %1, %2, %3 : i1, i1, i1, i1
}

// CHECK-LABEL: hw.module @maj_inv_constants_canonicalization
hw.module @maj_inv_constants_canonicalization(in %a : i1, out o1 : i1, out o2 : i1, out o3 : i1) {
// Two constants equal -> replace with constant
%c = hw.constant 1 : i1
%0 = synth.mig.maj_inv %c, %c, %a : i1

// Two constants complementary -> replace with the third
%c0 = hw.constant 0 : i1
%c1 = hw.constant 1 : i1
%1 = synth.mig.maj_inv %c0, %c1, %a : i1

// Two constants equal with one inverted -> replace with constant
%2 = synth.mig.maj_inv not %c0, %c1, %a : i1

// CHECK: hw.output %true, %a, %true : i1, i1, i1
hw.output %0, %1, %2 : i1, i1, i1
}

// CHECK-LABEL: hw.module @maj_inv_constants_fold
hw.module @maj_inv_constants_fold(out o1 : i2, out o2 : i2) {
%c1 = hw.constant 1 : i2
%c2 = hw.constant 2 : i2
%c3 = hw.constant 3 : i2

// not(%c1) = 10
// %c2 = 10
// not(%c3) = 00
// --------------------
// maj(10, 10, 00) = 10
// Fold to 2
%0 = synth.mig.maj_inv not %c1, %c2, not %c3 : i2

// Fold to 3
%1 = synth.mig.maj_inv not %c1, not %c2, %c3, %c1, %c2 : i2
// CHECK: hw.output %c-2_i2, %c-1_i2 : i2, i2
hw.output %0, %1 : i2, i2
}