Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions integration_test/circt-synth/datapath-lowering-lec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,31 @@ hw.module @partial_product_4(in %a : i4, in %b : i4, out sum : i4) {
hw.output %1 : i4
}

// RUN: circt-lec %t.mlir %s -c1=partial_product_zext -c2=partial_product_zext --shared-libs=%libz3 | FileCheck %s --check-prefix=AND3_ZEXT
// AND3_ZEXT: c1 == c2
// Multiplies two i3 inputs after zero-extending each of them to i6. Three
// partial product rows are requested and summed to form the i6 result.
hw.module @partial_product_zext(in %a : i3, in %b : i3, out sum : i6) {
  %zeros = hw.constant 0 : i3
  // Zero-extend each operand by prepending three constant-zero bits.
  %a_zext = comb.concat %zeros, %a : i3, i3
  %b_zext = comb.concat %zeros, %b : i3, i3
  %pp:3 = datapath.partial_product %a_zext, %b_zext : (i6, i6) -> (i6, i6, i6)
  // Accumulate the partial product rows.
  %product = comb.add %pp#0, %pp#1, %pp#2 : i6
  hw.output %product : i6
}

// RUN: circt-lec %t.mlir %s -c1=partial_product_sext -c2=partial_product_sext --shared-libs=%libz3 | FileCheck %s --check-prefix=AND3_SEXT
// AND3_SEXT: c1 == c2
// Multiplies two i3 inputs after sign-extending each of them to i6 (bit 2 is
// replicated into the upper three bits). Six partial product rows are
// requested and summed to form the i6 result.
hw.module @partial_product_sext(in %a : i3, in %b : i3, out sum : i6) {
  // Sign-extend %a: take its top bit, replicate it, and prepend it.
  %a_msb = comb.extract %a from 2 : (i3) -> i1
  %a_sign = comb.replicate %a_msb : (i1) -> i3
  %a_sext = comb.concat %a_sign, %a : i3, i3
  // Sign-extend %b the same way.
  %b_msb = comb.extract %b from 2 : (i3) -> i1
  %b_sign = comb.replicate %b_msb : (i1) -> i3
  %b_sext = comb.concat %b_sign, %b : i3, i3
  %pp:6 = datapath.partial_product %a_sext, %b_sext : (i6, i6) -> (i6, i6, i6, i6, i6, i6)
  // Accumulate the partial product rows.
  %product = comb.add %pp#0, %pp#1, %pp#2, %pp#3, %pp#4, %pp#5 : i6
  hw.output %product : i6
}

// RUN: circt-lec %t.mlir %s -c1=compress_3 -c2=compress_3 --shared-libs=%libz3 | FileCheck %s --check-prefix=COMP3
// COMP3: c1 == c2
hw.module @compress_3(in %a : i4, in %b : i4, in %c : i4, out sum : i4) {
Expand All @@ -41,6 +66,12 @@ hw.module @compress_6(in %a : i4, in %b : i4, in %c : i4, in %d : i4, in %e : i4
// RUN: circt-lec %t.mlir %s -c1=partial_product_4 -c2=partial_product_4 --shared-libs=%libz3 | FileCheck %s --check-prefix=BOOTH4
// BOOTH4: c1 == c2

// RUN: circt-lec %t.mlir %s -c1=partial_product_zext -c2=partial_product_zext --shared-libs=%libz3 | FileCheck %s --check-prefix=BOOTH3_ZEXT
// BOOTH3_ZEXT: c1 == c2

// RUN: circt-lec %t.mlir %s -c1=partial_product_sext -c2=partial_product_sext --shared-libs=%libz3 | FileCheck %s --check-prefix=BOOTH3_SEXT
// BOOTH3_SEXT: c1 == c2

// RUN: circt-lec %t.mlir %s -c1=compress_3 -c2=compress_3 --shared-libs=%libz3 | FileCheck %s --check-prefix=COMPADD3
// COMPADD3: c1 == c2

Expand Down
128 changes: 102 additions & 26 deletions lib/Conversion/DatapathToComb/DatapathToComb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/KnownBits.h"

#define DEBUG_TYPE "datapath-to-comb"

Expand Down Expand Up @@ -134,8 +135,9 @@ struct DatapathPartialProductOpConversion
"Cannot return more results than the operator width");

for (unsigned i = 0; i < op.getNumResults(); ++i) {
auto repl = comb::ReplicateOp::create(rewriter, loc, bBits[i], width);
auto ppRow = comb::AndOp::create(rewriter, loc, repl, a);
auto repl =
rewriter.createOrFold<comb::ReplicateOp>(loc, bBits[i], width);
auto ppRow = rewriter.createOrFold<comb::AndOp>(loc, repl, a);
auto shiftBy = hw::ConstantOp::create(rewriter, loc, APInt(width, i));
auto ppAlign = comb::ShlOp::create(rewriter, loc, ppRow, shiftBy);
partialProducts.push_back(ppAlign);
Expand All @@ -151,11 +153,37 @@ struct DatapathPartialProductOpConversion
Location loc = op.getLoc();
auto zeroFalse = hw::ConstantOp::create(rewriter, loc, APInt(1, 0));
auto zeroWidth = hw::ConstantOp::create(rewriter, loc, APInt(width, 0));
auto oneWidth = hw::ConstantOp::create(rewriter, loc, APInt(width, 1));
Value twoA = comb::ShlOp::create(rewriter, loc, a, oneWidth);

// Detect leading zeros in multiplicand due to zero-extension
// and truncate to reduce partial product bits
// {'0, a} * {'0, b}
auto rowWidth = width;
auto knownBitsA = comb::computeKnownBits(a);
if (!knownBitsA.Zero.isZero()) {
if (knownBitsA.Zero.countLeadingOnes() > 1) {
// Retain one leading zero to represent 2*{1'b0, a} = {a, 1'b0}
// {'0, a} -> {1'b0, a}
rowWidth -= knownBitsA.Zero.countLeadingOnes() - 1;
a = rewriter.createOrFold<comb::ExtractOp>(loc, a, 0, rowWidth);
}
}
auto oneRowWidth =
hw::ConstantOp::create(rewriter, loc, APInt(rowWidth, 1));
// Booth encoding will select each row from {-2a, -1a, 0, 1a, 2a}
Value twoA = rewriter.createOrFold<comb::ShlOp>(loc, a, oneRowWidth);

// Encode based on the bits of b
// TODO: sort a and b based on non-zero bits to encode the smaller input
SmallVector<Value> bBits = extractBits(rewriter, b);

// Identify zero bits of b to reduce height of partial product array
auto knownBitsB = comb::computeKnownBits(b);
if (!knownBitsB.Zero.isZero()) {
for (unsigned i = 0; i < width; ++i)
if (knownBitsB.Zero[i])
bBits[i] = zeroFalse;
}

SmallVector<Value> partialProducts;
partialProducts.reserve(width);

Expand All @@ -176,33 +204,80 @@ struct DatapathPartialProductOpConversion
// Is the encoding zero or negative (an approximation)
Value encNeg = bip1;
// Is the encoding one = b[i] xor b[i-1]
Value encOne = comb::XorOp::create(rewriter, loc, bi, bim1, true);
Value encOne = rewriter.createOrFold<comb::XorOp>(loc, bi, bim1, true);
// Is the encoding two = (bip1 & ~bi & ~bim1) | (~bip1 & bi & bim1)
Value constOne = hw::ConstantOp::create(rewriter, loc, APInt(1, 1));
Value biInv = comb::XorOp::create(rewriter, loc, bi, constOne, true);
Value bip1Inv = comb::XorOp::create(rewriter, loc, bip1, constOne, true);
Value bim1Inv = comb::XorOp::create(rewriter, loc, bim1, constOne, true);

Value andLeft = comb::AndOp::create(rewriter, loc,
ValueRange{bip1Inv, bi, bim1}, true);
Value andRight = comb::AndOp::create(
rewriter, loc, ValueRange{bip1, biInv, bim1Inv}, true);
Value encTwo = comb::OrOp::create(rewriter, loc, andLeft, andRight, true);
Value biInv = rewriter.createOrFold<comb::XorOp>(loc, bi, constOne, true);
Value bip1Inv =
rewriter.createOrFold<comb::XorOp>(loc, bip1, constOne, true);
Value bim1Inv =
rewriter.createOrFold<comb::XorOp>(loc, bim1, constOne, true);

Value andLeft = rewriter.createOrFold<comb::AndOp>(
loc, ValueRange{bip1Inv, bi, bim1}, true);
Value andRight = rewriter.createOrFold<comb::AndOp>(
loc, ValueRange{bip1, biInv, bim1Inv}, true);
Value encTwo =
rewriter.createOrFold<comb::OrOp>(loc, andLeft, andRight, true);

Value encNegRepl =
comb::ReplicateOp::create(rewriter, loc, encNeg, width);
rewriter.createOrFold<comb::ReplicateOp>(loc, encNeg, rowWidth);
Value encOneRepl =
comb::ReplicateOp::create(rewriter, loc, encOne, width);
rewriter.createOrFold<comb::ReplicateOp>(loc, encOne, rowWidth);
Value encTwoRepl =
comb::ReplicateOp::create(rewriter, loc, encTwo, width);
rewriter.createOrFold<comb::ReplicateOp>(loc, encTwo, rowWidth);

// Select between 2*a or 1*a or 0*a
Value selTwoA = comb::AndOp::create(rewriter, loc, encTwoRepl, twoA);
Value selOneA = comb::AndOp::create(rewriter, loc, encOneRepl, a);
Value magA = comb::OrOp::create(rewriter, loc, selTwoA, selOneA, true);
Value selTwoA = rewriter.createOrFold<comb::AndOp>(loc, encTwoRepl, twoA);
Value selOneA = rewriter.createOrFold<comb::AndOp>(loc, encOneRepl, a);
Value magA =
rewriter.createOrFold<comb::OrOp>(loc, selTwoA, selOneA, true);

// Conditionally invert the row
Value ppRow = comb::XorOp::create(rewriter, loc, magA, encNegRepl, true);
Value ppRow =
rewriter.createOrFold<comb::XorOp>(loc, magA, encNegRepl, true);

// Sign-extension Optimisation:
// Section 7.2.2 of "Application Specific Arithmetic" by Dinechin & Kumm
// Handle sign-extension and padding to full width
// s = encNeg (sign-bit)
// {s, s, s, s, s, pp} = {1, 1, 1, 1, 1, pp}
// + {0, 0, 0, 0,!s, '0}
// Applying this to every row we create an upper-triangle of 1s that can
// be optimised away since they will not affect the final sum.
// {!s3, 0,!s2, 0,!s1, 0}
// { 1, 1, 1, 1, 1, p1}
// { 1, 1, 1, p2 }
// { 1, p3 }
if (rowWidth < width) {
auto padding = width - rowWidth;
auto encNegInv = bip1Inv;

// Sign-extension trick not worth it for padding < 3
if (padding < 3) {
Value encNegPad =
rewriter.createOrFold<comb::ReplicateOp>(loc, encNeg, padding);
ppRow = rewriter.createOrFold<comb::ConcatOp>(
loc, ValueRange{encNegPad, ppRow}); // Pad to full width
} else if (i == 0) {
// First row = {!encNeg, encNeg, encNeg, ppRow}
ppRow = rewriter.createOrFold<comb::ConcatOp>(
loc, ValueRange{encNegInv, encNeg, encNeg, ppRow});
} else {
// Remaining rows = {1, !encNeg, ppRow}
ppRow = rewriter.createOrFold<comb::ConcatOp>(
loc, ValueRange{constOne, encNegInv, ppRow});
}

// Zero pad to full width
auto rowWidth = ppRow.getType().getIntOrFloatBitWidth();
if (rowWidth < width) {
auto zeroPad =
hw::ConstantOp::create(rewriter, loc, APInt(width - rowWidth, 0));
ppRow = rewriter.createOrFold<comb::ConcatOp>(
loc, ValueRange{zeroPad, ppRow});
}
}

// No sign-correction in the first row
if (i == 0) {
Expand All @@ -214,13 +289,14 @@ struct DatapathPartialProductOpConversion
// Insert a sign-correction from the previous row
assert(i >= 2 && "Expected i to be at least 2 for sign correction");
// {ppRow, 0, encNegPrev} << 2*(i-1)
Value withSignCorrection = comb::ConcatOp::create(
rewriter, loc, ValueRange{ppRow, zeroFalse, encNegPrev});
Value ppAlignPre =
comb::ExtractOp::create(rewriter, loc, withSignCorrection, 0, width);
Value withSignCorrection = rewriter.createOrFold<comb::ConcatOp>(
loc, ValueRange{ppRow, zeroFalse, encNegPrev});
Value ppAlignPre = rewriter.createOrFold<comb::ExtractOp>(
loc, withSignCorrection, 0, width);
Value shiftBy =
hw::ConstantOp::create(rewriter, loc, APInt(width, i - 2));
Value ppAlign = comb::ShlOp::create(rewriter, loc, ppAlignPre, shiftBy);
Value ppAlign =
rewriter.createOrFold<comb::ShlOp>(loc, ppAlignPre, shiftBy);
partialProducts.push_back(ppAlign);
encNegPrev = encNeg;

Expand Down
38 changes: 38 additions & 0 deletions test/Conversion/DatapathToComb/datapath-to-comb.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,44 @@ hw.module @partial_product_booth(in %a : i3, in %b : i3, out pp0 : i3, out pp1 :
hw.output %0#0, %0#1, %0#2 : i3, i3, i3
}

// CHECK-LABEL: @partial_product_booth_zext
// FORCE-BOOTH-LABEL: @partial_product_booth_zext
// Booth-lowering test for a partial product whose i6 operands are both built
// by zero-extending i3 inputs. The expected output below works on i4 rows
// rather than full i6 rows, and the third result is the constant zero.
hw.module @partial_product_booth_zext(in %a : i3, in %b : i3, out pp0 : i6, out pp1 : i6, out pp2 : i6) {
// Constants, then the narrowed multiplicand {1'b0, a} and its double {a, 1'b0}.
// FORCE-BOOTH-NEXT: %true = hw.constant true
// FORCE-BOOTH-NEXT: %c0_i6 = hw.constant 0 : i6
// FORCE-BOOTH-NEXT: %false = hw.constant false
// FORCE-BOOTH-NEXT: %[[A:.*]] = comb.concat %false, %a : i1, i3
// FORCE-BOOTH-NEXT: %[[TWOA:.*]] = comb.concat %a, %false : i3, i1
// Only the three bits of %b are extracted; no bits of the zero padding appear.
// FORCE-BOOTH-NEXT: %[[B0:.*]] = comb.extract %b from 0 : (i3) -> i1
// FORCE-BOOTH-NEXT: %[[B1:.*]] = comb.extract %b from 1 : (i3) -> i1
// FORCE-BOOTH-NEXT: %[[B2:.*]] = comb.extract %b from 2 : (i3) -> i1
// Row 0: select 2a when (b1 & ~b0), select a when b0, then xor the row with
// replicated b1 and pad it to i6 with two more copies of b1.
// FORCE-BOOTH-NEXT: %[[NB0:.*]] = comb.xor bin %[[B0]], %true : i1
// FORCE-BOOTH-NEXT: %[[B1NB0:.*]] = comb.and %[[B1]], %[[NB0]] : i1
// FORCE-BOOTH-NEXT: %[[RB1:.*]] = comb.replicate %[[B1]] : (i1) -> i4
// FORCE-BOOTH-NEXT: %[[RB0:.*]] = comb.replicate %[[B0]] : (i1) -> i4
// FORCE-BOOTH-NEXT: %[[RB1NB0:.*]] = comb.replicate %[[B1NB0]] : (i1) -> i4
// FORCE-BOOTH-NEXT: %[[ROW02A:.*]] = comb.and %[[RB1NB0]], %[[TWOA]] : i4
// FORCE-BOOTH-NEXT: %[[ROW0A:.*]] = comb.and %[[RB0]], %[[A]] : i4
// FORCE-BOOTH-NEXT: %[[ROW0:.*]] = comb.or bin %[[ROW02A]], %[[ROW0A]] : i4
// FORCE-BOOTH-NEXT: %[[NROW0:.*]] = comb.xor bin %[[ROW0]], %[[RB1]] : i4
// FORCE-BOOTH-NEXT: %[[SEXTB1:.*]] = comb.replicate %[[B1]] : (i1) -> i2
// FORCE-BOOTH-NEXT: %[[PP0:.*]] = comb.concat %[[SEXTB1]], %[[NROW0]] : i2, i4
// Row 1: select 2a when (b2 & b1), select a when (b2 ^ b1). The row is placed
// above a zero bit and b1, which occupies the least significant bit.
// FORCE-BOOTH-NEXT: %[[B2XORB1:.*]] = comb.xor bin %[[B2]], %[[B1]] : i1
// FORCE-BOOTH-NEXT: %[[B2B1:.*]] = comb.and %[[B2]], %[[B1]] : i1
// FORCE-BOOTH-NEXT: %[[RB2XORB1:.*]] = comb.replicate %[[B2XORB1]] : (i1) -> i4
// FORCE-BOOTH-NEXT: %[[RB2B1:.*]] = comb.replicate %[[B2B1]] : (i1) -> i4
// FORCE-BOOTH-NEXT: %[[ROW12A:.*]] = comb.and %[[RB2B1]], %[[TWOA]] : i4
// FORCE-BOOTH-NEXT: %[[ROW1A:.*]] = comb.and %[[RB2XORB1]], %[[A]] : i4
// FORCE-BOOTH-NEXT: %[[ROW1:.*]] = comb.or bin %[[ROW12A]], %[[ROW1A]] : i4
// FORCE-BOOTH-NEXT: %[[PP1:.*]] = comb.concat %[[ROW1]], %false, %[[B1]] : i4, i1, i1
// The third result is expected to be the i6 zero constant.
// FORCE-BOOTH-NEXT: hw.output %[[PP0]], %[[PP1]], %c0_i6 : i6, i6, i6
// Input IR: zero-extend both i3 operands to i6 and request three rows.
%c0_i3 = hw.constant 0 : i3
%0 = comb.concat %c0_i3, %a : i3, i3
%1 = comb.concat %c0_i3, %b : i3, i3
%2:3 = datapath.partial_product %0, %1 : (i6, i6) -> (i6, i6, i6)
hw.output %2#0, %2#1, %2#2 : i6, i6, i6
}

// CHECK-LABEL: @partial_product_24
hw.module @partial_product_24(in %a : i24, in %b : i24, out sum : i24) {
%0:24 = datapath.partial_product %a, %b : (i24, i24) -> (i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24)
Expand Down