Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions include/circt/Dialect/Datapath/DatapathOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,38 @@ def PartialProductOp : DatapathOp<"partial_product",
];
}

def PosPartialProductOp : DatapathOp<"pos_partial_product",
[Pure, SameTypeOperands, SameOperandsAndResultType]> {
let summary = "Generate partial products for (a+b)*c";
let description = [{
A partial product array which when summed produces the result (a+b)*c.
An efficient circuit can be constructed that encodes a carry-save input
without performing a full carry-propagation. The number of results
corresponds to the rows of a partial product array, which by default is
equal to the width of the inputs.

Example using `datapath` dialect:
```mlir
%0:3 = datapath.pos_partial_product %a, %b, %c : (i3, i3, i3)
-> (i3, i3, i3)
```
}];
let arguments = (ins HWIntegerType:$addend0, HWIntegerType:$addend1,
HWIntegerType:$multiplicand);
let results = (outs Variadic<HWIntegerType>:$results);

let assemblyFormat = [{
$addend0 `,` $addend1 `,` $multiplicand attr-dict `:` functional-type(operands, results)
}];
let hasCanonicalizer = true;

let builders = [
OpBuilder<(ins "ValueRange":$lhs, "int32_t":$targetRows), [{
auto inputType = lhs.front().getType();
SmallVector<Type> resultTypes(targetRows, inputType);
return build($_builder, $_state, resultTypes, lhs);
}]>
];
}

#endif // CIRCT_DIALECT_DATAPATH_OPS_TD
21 changes: 21 additions & 0 deletions integration_test/circt-synth/datapath-lowering-lec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,27 @@ hw.module @partial_product_sext(in %a : i3, in %b : i3, out sum : i6) {
hw.output %7 : i6
}

// RUN: circt-lec %t.mlir %s -c1=pos_partial_product_4 -c2=pos_partial_product_4 --shared-libs=%libz3 | FileCheck %s --check-prefix=POS4
// POS4: c1 == c2
hw.module @pos_partial_product_4(in %a : i4, in %b : i4, in %c : i4, out sum : i4) {
%0:4 = datapath.pos_partial_product %a, %b, %c : (i4, i4, i4) -> (i4, i4, i4, i4)
%1 = comb.add bin %0#0, %0#1, %0#2, %0#3 : i4
hw.output %1 : i4
}

// RUN: circt-lec %t.mlir %s -c1=pos_partial_product_zext -c2=pos_partial_product_zext --shared-libs=%libz3 | FileCheck %s --check-prefix=POS_ZEXT
// POS_ZEXT: c1 == c2
hw.module @pos_partial_product_zext(in %a : i4, in %b : i3, in %c : i4, out sum : i8) {
%c0_i4 = hw.constant 0 : i4
%c0_i5 = hw.constant 0 : i5
%0 = comb.concat %c0_i4, %a : i4, i4
%1 = comb.concat %c0_i5, %b : i5, i3
%2 = comb.concat %c0_i4, %c : i4, i4
%3:8 = datapath.pos_partial_product %0, %1, %2 : (i8, i8, i8) -> (i8, i8, i8, i8, i8, i8, i8, i8)
%4 = comb.add %3#0, %3#1, %3#2, %3#3, %3#4, %3#5, %3#6, %3#7 : i8
hw.output %4 : i8
}

// RUN: circt-lec %t.mlir %s -c1=compress_3 -c2=compress_3 --shared-libs=%libz3 | FileCheck %s --check-prefix=COMP3
// COMP3: c1 == c2
hw.module @compress_3(in %a : i4, in %b : i4, in %c : i4, out sum : i4) {
Expand Down
118 changes: 116 additions & 2 deletions lib/Conversion/DatapathToComb/DatapathToComb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ struct DatapathPartialProductOpConversion : OpRewritePattern<PartialProductOp> {
a = rewriter.createOrFold<comb::ExtractOp>(loc, a, 0, rowWidth);
}
}

// TODO - replace with a concatenation to aid longest path analysis
auto oneRowWidth =
hw::ConstantOp::create(rewriter, loc, APInt(rowWidth, 1));
// Booth encoding will select each row from {-2a, -1a, 0, 1a, 2a}
Expand Down Expand Up @@ -333,6 +335,117 @@ struct DatapathPartialProductOpConversion : OpRewritePattern<PartialProductOp> {
return success();
}
};

struct DatapathPosPartialProductOpConversion
: OpRewritePattern<PosPartialProductOp> {
using OpRewritePattern<PosPartialProductOp>::OpRewritePattern;

DatapathPosPartialProductOpConversion(MLIRContext *context, bool forceBooth)
: OpRewritePattern<PosPartialProductOp>(context),
forceBooth(forceBooth){};

const bool forceBooth;

LogicalResult matchAndRewrite(PosPartialProductOp op,
PatternRewriter &rewriter) const override {

Value a = op.getAddend0();
Value b = op.getAddend1();
Value c = op.getMultiplicand();
unsigned width = a.getType().getIntOrFloatBitWidth();

// Skip a zero width value.
if (width == 0) {
rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, op.getType(0), 0);
return success();
}

// TODO: Implement Booth lowering
return lowerAndArray(rewriter, a, b, c, op, width);
}

private:
static LogicalResult lowerAndArray(PatternRewriter &rewriter, Value a,
Value b, Value c, PosPartialProductOp op,
unsigned width) {

Location loc = op.getLoc();
// Encode (a+b) by implementing a half-adder - then note the following fact
// carry[i] & save[i] == false
auto carry = rewriter.createOrFold<comb::AndOp>(loc, a, b);
auto save = rewriter.createOrFold<comb::XorOp>(loc, a, b);

SmallVector<Value> carryBits = extractBits(rewriter, carry);
SmallVector<Value> saveBits = extractBits(rewriter, save);

// Reduce c width based on leading zeros
auto rowWidth = width;
auto knownBitsC = comb::computeKnownBits(c);
if (!knownBitsC.Zero.isZero()) {
if (knownBitsC.Zero.countLeadingOnes() > 1) {
// Retain one leading zero to represent 2*{1'b0, c} = {c, 1'b0}
// {'0, c} -> {1'b0, c}
rowWidth -= knownBitsC.Zero.countLeadingOnes() - 1;
c = rewriter.createOrFold<comb::ExtractOp>(loc, c, 0, rowWidth);
}
}

// Compute 2*c for use in array construction
Value zero = hw::ConstantOp::create(rewriter, loc, APInt(1, 0));
Value twoCWider = rewriter.create<comb::ConcatOp>(loc, ValueRange{c, zero});
Value twoC = rewriter.create<comb::ExtractOp>(loc, twoCWider, 0, rowWidth);

// AND Array Construction:
// pp[i] = ( (carry[i] * (c<<1)) | (save[i] * c) ) << i
SmallVector<Value> partialProducts;
partialProducts.reserve(width);

assert(op.getNumResults() <= width &&
"Cannot return more results than the operator width");

for (unsigned i = 0; i < op.getNumResults(); ++i) {
auto replSave =
rewriter.createOrFold<comb::ReplicateOp>(loc, saveBits[i], rowWidth);
auto replCarry =
rewriter.createOrFold<comb::ReplicateOp>(loc, carryBits[i], rowWidth);

auto ppRowSave = rewriter.createOrFold<comb::AndOp>(loc, replSave, c);
auto ppRowCarry =
rewriter.createOrFold<comb::AndOp>(loc, replCarry, twoC);
auto ppRow =
rewriter.createOrFold<comb::OrOp>(loc, ppRowSave, ppRowCarry);
auto ppAlign = ppRow;
if (i > 0) {
auto shiftBy = hw::ConstantOp::create(rewriter, loc, APInt(i, 0));
ppAlign =
comb::ConcatOp::create(rewriter, loc, ValueRange{ppRow, shiftBy});
}

// May need to truncate shifted value
if (rowWidth + i > width) {
auto ppAlignTrunc =
rewriter.createOrFold<comb::ExtractOp>(loc, ppAlign, 0, width);
partialProducts.push_back(ppAlignTrunc);
continue;
}
// May need to zero pad to approriate width
if (rowWidth + i < width) {
auto padding = width - rowWidth - i;
Value zeroPad =
hw::ConstantOp::create(rewriter, loc, APInt(padding, 0));
partialProducts.push_back(rewriter.createOrFold<comb::ConcatOp>(
loc, ValueRange{zeroPad, ppAlign})); // Pad to full width
continue;
}

partialProducts.push_back(ppAlign);
}

rewriter.replaceOp(op, partialProducts);
return success();
}
};

} // namespace

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -370,8 +483,9 @@ static LogicalResult applyPatternsGreedilyWithTimingInfo(
void ConvertDatapathToCombPass::runOnOperation() {
RewritePatternSet patterns(&getContext());

patterns.add<DatapathPartialProductOpConversion>(patterns.getContext(),
forceBooth);
patterns.add<DatapathPartialProductOpConversion,
DatapathPosPartialProductOpConversion>(patterns.getContext(),
forceBooth);
synth::IncrementalLongestPathAnalysis *analysis = nullptr;
if (timingAware)
analysis = &getAnalysis<synth::IncrementalLongestPathAnalysis>();
Expand Down
56 changes: 54 additions & 2 deletions lib/Conversion/DatapathToSMT/DatapathToSMT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,58 @@ struct PartialProductOpConversion : OpConversionPattern<PartialProductOp> {
return success();
}
};

// Lower to an SMT assertion that summing the results is equivalent to the
// product of the sum of the pos_partial_product inputs
// c:<N> = pos_partial_product(a, b, c) ->
// assert(c#0 + ... + c#<N-1> == (a + b) * c)
struct PosPartialProductOpConversion
: OpConversionPattern<PosPartialProductOp> {
using OpConversionPattern<PosPartialProductOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(PosPartialProductOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

ValueRange operands = adaptor.getOperands();
ValueRange results = op.getResults();

// (a+b)
auto addResult =
smt::BVAddOp::create(rewriter, op.getLoc(), operands[0], operands[1]);
// (a+b)*c
auto mulResult =
smt::BVMulOp::create(rewriter, op.getLoc(), addResult, operands[2]);

// Create free variables
SmallVector<Value, 2> newResults;
newResults.reserve(results.size());
for (Value result : results) {
auto declareFunOp = smt::DeclareFunOp::create(
rewriter, op.getLoc(), typeConverter->convertType(result.getType()));
newResults.push_back(declareFunOp.getResult());
}

// Sum the free variables
Value resultRunner = newResults.front();
for (auto freeVar : llvm::drop_begin(newResults, 1))
resultRunner =
smt::BVAddOp::create(rewriter, op.getLoc(), resultRunner, freeVar);

// Assert product of operands == sum results (free variables)
auto premise =
smt::EqOp::create(rewriter, op.getLoc(), mulResult, resultRunner);
// Encode via an assertion (could be relaxed to an assumption).
smt::AssertOp::create(rewriter, op.getLoc(), premise);

if (newResults.size() != results.size())
return rewriter.notifyMatchFailure(op, "expected same number of results");

rewriter.replaceOp(op, newResults);
return success();
}
};

} // namespace

//===----------------------------------------------------------------------===//
Expand All @@ -138,8 +190,8 @@ struct ConvertDatapathToSMTPass

void circt::populateDatapathToSMTConversionPatterns(
TypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<CompressOpConversion, PartialProductOpConversion>(
converter, patterns.getContext());
patterns.add<CompressOpConversion, PartialProductOpConversion,
PosPartialProductOpConversion>(converter, patterns.getContext());
}

void ConvertDatapathToSMTPass::runOnOperation() {
Expand Down
Loading