Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion include/circt/Dialect/Synth/Transforms/SynthPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,19 @@ def GenericLutMapper : CutRewriterPassBase<"synth-generic-lut-mapper",
}

def LowerVariadic : Pass<"synth-lower-variadic", "hw::HWModuleOp"> {
let summary = "Lower variadic AndInverter operations to binary AndInverter";
let summary = "Lower variadic operations to binary operations";
let description = [{
This pass lowers variadic operations to binary operations using a
delay-aware algorithm. For commutative operations, it builds a balanced
tree by combining values with the earliest arrival times first to minimize
the critical path.
}];
let options = [
ListOption<"opNames", "op-names", "std::string",
"Specify operation names to lower (empty means all)">,
Option<"timingAware", "timing-aware", "bool", "true",
"Lower operators with timing information">
];
}

def LowerWordToBits : Pass<"synth-lower-word-to-bits", "hw::HWModuleOp"> {
Expand Down
5 changes: 5 additions & 0 deletions include/circt/Dialect/Synth/Transforms/SynthesisPipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ struct SynthOptimizationPipelineOptions
PassOptions::Option<bool> disableWordToBits{
*this, "disable-word-to-bits",
llvm::cl::desc("Disable LowerWordToBits pass"), llvm::cl::init(false)};

PassOptions::Option<bool> timingAware{
*this, "timing-aware",
llvm::cl::desc("Lower operators in a timing-aware fashion"),
llvm::cl::init(false)};
};

//===----------------------------------------------------------------------===//
Expand Down
23 changes: 23 additions & 0 deletions integration_test/circt-synth/lower-variadic.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// REQUIRES: libz3
// REQUIRES: circt-lec-jit

// RUN: circt-opt %s -convert-synth-to-comb -o %t.before.mlir
// RUN: circt-opt %s -synth-lower-variadic -convert-synth-to-comb -o %t.after.mlir
// RUN: circt-lec %t.before.mlir %t.after.mlir -c1=AndInverter -c2=AndInverter --shared-libs=%libz3 | FileCheck %s --check-prefix=AND_INVERTER_LEC
// AND_INVERTER_LEC: c1 == c2
hw.module @AndInverter(in %a: i2, in %b: i2, in %c: i2, in %d: i2, in %e: i2, in %f: i2, in %g: i2, out o1: i2) {
%0 = synth.aig.and_inv %d, not %e : i2
%1 = synth.aig.and_inv not %c, not %0, %f : i2
%2 = synth.aig.and_inv %a, not %b, not %1, %g : i2
hw.output %2 : i2
}

// RUN: circt-lec %t.before.mlir %t.after.mlir -c1=VariadicCombOps -c2=VariadicCombOps --shared-libs=%libz3 | FileCheck %s --check-prefix=VARIADIC_COMB_OPS_LEC
// VARIADIC_COMB_OPS_LEC: c1 == c2
hw.module @VariadicCombOps(in %a: i2, in %b: i2, in %c: i2, in %d: i2, in %e: i2, in %f: i2,
out out_and: i2, out out_or: i2, out out_xor: i2) {
%0 = comb.and %a, %b, %c, %d, %e, %f : i2
%1 = comb.or %a, %b, %c, %d, %e, %f : i2
%2 = comb.xor %a, %b, %c, %d, %e, %f : i2
hw.output %0, %1, %2 : i2, i2, i2
}
236 changes: 176 additions & 60 deletions lib/Dialect/Synth/Transforms/LowerVariadic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,20 @@
//
//===----------------------------------------------------------------------===//
//
// This pass lowers variadic AndInverter operations to binary AndInverter
// operations.
// This pass lowers variadic operations to binary operations using a
// delay-aware algorithm for commutative operations.
//
//===----------------------------------------------------------------------===//

#include "circt/Dialect/Comb/CombDialect.h"
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/Synth/Analysis/LongestPathAnalysis.h"
#include "circt/Dialect/Synth/SynthOps.h"
#include "circt/Dialect/Synth/Transforms/SynthPasses.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/PointerIntPair.h"
#include "llvm/ADT/PriorityQueue.h"

#define DEBUG_TYPE "synth-lower-variadic"

Expand All @@ -29,79 +34,190 @@ using namespace circt;
using namespace synth;

//===----------------------------------------------------------------------===//
// Rewrite patterns
// Lower Variadic pass
//===----------------------------------------------------------------------===//

namespace {
static Value lowerVariadicAndInverterOp(aig::AndInverterOp op,
OperandRange operands,
ArrayRef<bool> inverts,
PatternRewriter &rewriter) {
switch (operands.size()) {
case 0:
assert(0 && "cannot be called with empty operand range");
break;
case 1:
if (inverts[0])
return aig::AndInverterOp::create(rewriter, op.getLoc(), operands[0],
true);
else
return operands[0];
case 2:
return aig::AndInverterOp::create(rewriter, op.getLoc(), operands[0],
operands[1], inverts[0], inverts[1]);
default:
auto firstHalf = operands.size() / 2;
auto lhs =
lowerVariadicAndInverterOp(op, operands.take_front(firstHalf),
inverts.take_front(firstHalf), rewriter);
auto rhs =
lowerVariadicAndInverterOp(op, operands.drop_front(firstHalf),
inverts.drop_front(firstHalf), rewriter);
return aig::AndInverterOp::create(rewriter, op.getLoc(), lhs, rhs);
}

return Value();
}
/// Helper class for delay-aware variadic operation lowering.
/// Stores a value along with its arrival time for priority queue ordering.
class ValueWithArrivalTime {
/// The value and an optional inversion flag packed together.
/// The inversion flag is used for AndInverterOp lowering.
llvm::PointerIntPair<Value, 1, bool> value;

struct VariadicOpConversion : OpRewritePattern<aig::AndInverterOp> {
using OpRewritePattern<aig::AndInverterOp>::OpRewritePattern;
LogicalResult matchAndRewrite(aig::AndInverterOp op,
PatternRewriter &rewriter) const override {
if (op.getInputs().size() <= 2)
return failure();
/// The arrival time (delay) of this value in the circuit.
int64_t arrivalTime;

// TODO: This is a naive implementation that creates a balanced binary tree.
// We can improve by analyzing the dataflow and creating a tree that
// improves the critical path or area.
rewriter.replaceOp(op,
lowerVariadicAndInverterOp(op, op.getOperands(),
op.getInverted(), rewriter));
return success();
}
};
/// Value numbering for deterministic ordering when arrival times are equal.
/// This ensures consistent results across runs when multiple values have
/// the same delay.
size_t valueNumbering = 0;

} // namespace
public:
ValueWithArrivalTime(Value value, int64_t arrivalTime, bool invert,
size_t valueNumbering)
: value(value, invert), arrivalTime(arrivalTime),
valueNumbering(valueNumbering) {}

static void populateLowerVariadicPatterns(RewritePatternSet &patterns) {
patterns.add<VariadicOpConversion>(patterns.getContext());
}
Value getValue() const { return value.getPointer(); }
bool isInverted() const { return value.getInt(); }

//===----------------------------------------------------------------------===//
// Lower Variadic pass
//===----------------------------------------------------------------------===//
/// Comparison operator for priority queue. Values with earlier arrival times
/// have higher priority. When arrival times are equal, use value numbering
/// for determinism.
bool operator>(const ValueWithArrivalTime &other) const {
return arrivalTime > other.arrivalTime ||
(arrivalTime == other.arrivalTime &&
valueNumbering > other.valueNumbering);
}
};

namespace {
struct LowerVariadicPass : public impl::LowerVariadicBase<LowerVariadicPass> {
using LowerVariadicBase::LowerVariadicBase;
void runOnOperation() override;
};

} // namespace

/// Construct a balanced binary tree from a variadic operation using a
/// delay-aware algorithm. This function builds the tree by repeatedly combining
/// the two values with the earliest arrival times, which minimizes the critical
/// path delay.
static LogicalResult replaceWithBalancedTree(
Comment on lines +83 to +87
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-blocking - could this also be used in CombToSynth as we perform variadic reduction there too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good points, yes. I think it makes sense to eventually move the utility to Synth header and use it from CombToSynth. Though it might be slightly weird because CombToSynth doesn't have timing info at this point. I'll find a way to unify them.

IncrementalLongestPathAnalysis *analysis, mlir::IRRewriter &rewriter,
Operation *op, llvm::function_ref<bool(OpOperand &)> isInverted,
llvm::function_ref<Value(ValueWithArrivalTime, ValueWithArrivalTime)>
createBinaryOp) {
// Min-heap priority queue ordered by arrival time.
// Values with earlier arrival times are processed first.
llvm::PriorityQueue<ValueWithArrivalTime, std::vector<ValueWithArrivalTime>,
std::greater<ValueWithArrivalTime>>
queue;

// Counter for deterministic ordering when arrival times are equal.
size_t valueNumber = 0;

auto push = [&](Value value, bool invert) {
int64_t delay = 0;
// If analysis is available, use it to compute the delay.
// If not available, use zero delay and `valueNumber` will be used instead.
if (analysis) {
auto result = analysis->getMaxDelay(value);
if (failed(result))
return failure();
delay = *result;
}
ValueWithArrivalTime entry(value, delay, invert, valueNumber++);
queue.push(entry);
return success();
};

// Enqueue all operands with their arrival times and inversion flags.
for (size_t i = 0, e = op->getNumOperands(); i < e; ++i)
if (failed(push(op->getOperand(i), isInverted(op->getOpOperand(i)))))
return failure();

// Build balanced tree by repeatedly combining the two earliest values.
// This greedy approach minimizes the maximum depth of late-arriving signals.
while (queue.size() >= 2) {
auto lhs = queue.top();
queue.pop();
auto rhs = queue.top();
queue.pop();
// Create and enqueue the combined value.
if (failed(push(createBinaryOp(lhs, rhs), /*inverted=*/false)))
return failure();
}

// Get the final result and replace the original operation.
auto result = queue.top().getValue();
rewriter.replaceOp(op, result);
return success();
}

void LowerVariadicPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateLowerVariadicPatterns(patterns);
mlir::FrozenRewritePatternSet frozen(std::move(patterns));
// Topologically sort operations in graph regions to ensure operands are
// defined before uses.
if (failed(synth::topologicallySortGraphRegionBlocks(
getOperation(), [](Value, Operation *op) -> bool {
return !isa_and_nonnull<comb::CombDialect, synth::SynthDialect>(
op->getDialect());
})))
return signalPassFailure();

// Get longest path analysis if timing-aware lowering is enabled.
synth::IncrementalLongestPathAnalysis *analysis = nullptr;
if (timingAware.getValue())
analysis = &getAnalysis<synth::IncrementalLongestPathAnalysis>();

auto moduleOp = getOperation();

// Build set of operation names to lower if specified.
SmallVector<OperationName> names;
for (const auto &name : opNames)
names.push_back(OperationName(name, &getContext()));

// Return true if the operation should be lowered.
auto shouldLower = [&](Operation *op) {
// If no names specified, lower all variadic ops.
if (names.empty())
return true;
return llvm::find(names, op->getName()) != names.end();
};

mlir::IRRewriter rewriter(&getContext());
rewriter.setListener(analysis);

auto result = moduleOp->walk([&](Operation *op) {
// Skip operations that don't need lowering or are already binary.
if (!shouldLower(op) || op->getNumOperands() <= 2)
return WalkResult::advance();

rewriter.setInsertionPoint(op);

// Handle AndInverterOp specially to preserve inversion flags.
if (auto andInverterOp = dyn_cast<aig::AndInverterOp>(op)) {
auto result = replaceWithBalancedTree(
analysis, rewriter, op,
// Check if each operand is inverted.
[&](OpOperand &operand) {
return andInverterOp.isInverted(operand.getOperandNumber());
},
// Create binary AndInverterOp with inversion flags.
[&](ValueWithArrivalTime lhs, ValueWithArrivalTime rhs) {
return rewriter.create<aig::AndInverterOp>(
op->getLoc(), lhs.getValue(), rhs.getValue(), lhs.isInverted(),
rhs.isInverted());
});
return result.succeeded() ? WalkResult::advance()
: WalkResult::interrupt();
}

// Handle commutative operations (and, or, xor, mul, add, etc.) using
// delay-aware lowering to minimize critical path.
if (isa_and_nonnull<comb::CombDialect>(op->getDialect()) &&
op->hasTrait<OpTrait::IsCommutative>()) {
auto result = replaceWithBalancedTree(
analysis, rewriter, op,
// No inversion flags for standard commutative operations.
[](OpOperand &) { return false; },
// Create binary operation with the same operation type.
[&](ValueWithArrivalTime lhs, ValueWithArrivalTime rhs) {
OperationState state(op->getLoc(), op->getName());
state.addOperands(ValueRange{lhs.getValue(), rhs.getValue()});
state.addTypes(op->getResult(0).getType());
auto *newOp = Operation::create(state);
rewriter.insert(newOp);
return newOp->getResult(0);
});
return result.succeeded() ? WalkResult::advance()
: WalkResult::interrupt();
}

return WalkResult::advance();
});

if (failed(mlir::applyPatternsGreedily(getOperation(), frozen)))
if (result.wasInterrupted())
return signalPassFailure();
}
33 changes: 26 additions & 7 deletions lib/Dialect/Synth/Transforms/SynthesisPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,22 @@ using namespace circt::synth;

/// Helper function to populate additional legal ops for partial legalization.
template <typename... AllowedOpTy>
static void partiallyLegalizeCombToSynth(SmallVectorImpl<std::string> &ops) {
static void addOpName(SmallVectorImpl<std::string> &ops) {
(ops.push_back(AllowedOpTy::getOperationName().str()), ...);
}

template <typename... OpToLowerTy>
static std::unique_ptr<Pass> createLowerVariadicPass(bool timingAware) {
LowerVariadicOptions options;
addOpName<OpToLowerTy...>(options.opNames);
options.timingAware = timingAware;
return createLowerVariadic(options);
}
void circt::synth::buildCombLoweringPipeline(
OpPassManager &pm, const CombLoweringPipelineOptions &options) {
{
if (!options.disableDatapath) {
// Lower variadic Mul into a binary op to enable datapath lowering.
pm.addPass(createLowerVariadicPass<comb::MulOp>(options.timingAware));
pm.addPass(createConvertCombToDatapath());
pm.addPass(createSimpleCanonicalizerPass());
if (options.synthesisStrategy == OptimizationStrategyTiming)
Expand All @@ -55,10 +63,9 @@ void circt::synth::buildCombLoweringPipeline(
}
// Partially legalize Comb, then run CSE and canonicalization.
circt::ConvertCombToSynthOptions convOptions;
partiallyLegalizeCombToSynth<comb::AndOp, comb::OrOp, comb::XorOp,
comb::MuxOp, comb::ICmpOp, hw::ArrayGetOp,
hw::ArraySliceOp, hw::ArrayCreateOp,
hw::ArrayConcatOp, hw::AggregateConstantOp>(
addOpName<comb::AndOp, comb::OrOp, comb::XorOp, comb::MuxOp, comb::ICmpOp,
hw::ArrayGetOp, hw::ArraySliceOp, hw::ArrayCreateOp,
hw::ArrayConcatOp, hw::AggregateConstantOp>(
convOptions.additionalLegalOps);
pm.addPass(circt::createConvertCombToSynth(convOptions));
}
Expand All @@ -69,6 +76,18 @@ void circt::synth::buildCombLoweringPipeline(
comb::BalanceMuxOptions balanceOptions{OptimizationStrategyTiming ? 16 : 64};
pm.addPass(comb::createBalanceMux(balanceOptions));

// Lower variadic ops before running full lowering to target IR.
if (options.targetIR.getValue() == TargetIR::AIG) {
// For AIG, lower variadic XoR since AIG cannot keep variadic
// representation.
pm.addPass(createLowerVariadicPass<comb::XorOp>(options.timingAware));
} else if (options.targetIR.getValue() == TargetIR::MIG) {
// For MIG, lower variadic And, Or, and Xor since MIG cannot keep variadic
// representation.
pm.addPass(createLowerVariadicPass<comb::AndOp, comb::OrOp, comb::XorOp>(
options.timingAware));
}

pm.addPass(circt::hw::createHWAggregateToComb());
circt::ConvertCombToSynthOptions convOptions;
convOptions.targetIR = options.targetIR.getValue() == TargetIR::AIG
Expand All @@ -83,7 +102,7 @@ void circt::synth::buildCombLoweringPipeline(
void circt::synth::buildSynthOptimizationPipeline(
OpPassManager &pm, const SynthOptimizationPipelineOptions &options) {

pm.addPass(synth::createLowerVariadic());
pm.addPass(createLowerVariadicPass(options.timingAware));

// LowerWordToBits may not be scalable for large designs so conditionally
// disable it. It's also worth considering keeping word-level representation
Expand Down
Loading