Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions integration_test/circt-synth/datapath-lowering-lec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,24 @@ hw.module @partial_product_zext(in %a : i3, in %b : i3, out sum : i6) {
hw.output %3 : i6
}

// RUN: circt-lec %t.mlir %s -c1=partial_product_square -c2=partial_product_square --shared-libs=%libz3 | FileCheck %s --check-prefix=SQR4
// SQR4: c1 == c2
// Squaring case: both operands of the partial product are the same value %a.
// The four i4 rows are summed with comb.add to form the module output.
hw.module @partial_product_square(in %a : i4, out sum : i4) {
%0:4 = datapath.partial_product %a, %a : (i4, i4) -> (i4, i4, i4, i4)
%1 = comb.add bin %0#0, %0#1, %0#2, %0#3 : i4
hw.output %1 : i4
}

// RUN: circt-lec %t.mlir %s -c1=partial_product_square_zext -c2=partial_product_square_zext --shared-libs=%libz3 | FileCheck %s --check-prefix=SQR3_ZEXT
// SQR3_ZEXT: c1 == c2
// Squaring case on a zero-extended operand: %a (i3) is widened to i6 by
// concatenating three zero bits above it, and the same widened value feeds
// both operands. Only three rows are requested for the six-bit operand.
hw.module @partial_product_square_zext(in %a : i3, out sum : i6) {
%c0_i3 = hw.constant 0 : i3
%0 = comb.concat %c0_i3, %a : i3, i3
%1:3 = datapath.partial_product %0, %0 : (i6, i6) -> (i6, i6, i6)
%2 = comb.add %1#0, %1#1, %1#2 : i6
hw.output %2 : i6
}

// RUN: circt-lec %t.mlir %s -c1=partial_product_sext -c2=partial_product_sext --shared-libs=%libz3 | FileCheck %s --check-prefix=AND3_SEXT
// AND3_SEXT: c1 == c2
hw.module @partial_product_sext(in %a : i3, in %b : i3, out sum : i6) {
Expand Down
84 changes: 82 additions & 2 deletions lib/Conversion/DatapathToComb/DatapathToComb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/KnownBits.h"
#include <algorithm>

#define DEBUG_TYPE "datapath-to-comb"

Expand Down Expand Up @@ -122,6 +123,21 @@ struct DatapathPartialProductOpConversion : OpRewritePattern<PartialProductOp> {
return success();
}

// Square partial product array can be reduced to upper triangular array.
// For example: AND array for a 4-bit squarer:
// 0 0 0 a0a3 a0a2 a0a1 a0a0
// 0 0 a1a3 a1a2 a1a1 a1a0 0
// 0 a2a3 a2a2 a2a1 a2a0 0 0
// a3a3 a3a2 a3a1 a3a0 0 0 0
//
// Can be reduced to:
// 0 0 a0a3 a0a2 a0a1 0 a0
// 0 a1a3 a1a2 0 a1 0 0
// a2a3 0 a2 0 0 0 0
// a3 0 0 0 0 0 0
if (a == b)
return lowerSqrAndArray(rewriter, a, op, width);

// Use result rows as a heuristic to guide partial product
// implementation
if (op.getNumResults() > 16 || forceBooth)
Expand Down Expand Up @@ -166,6 +182,70 @@ struct DatapathPartialProductOpConversion : OpRewritePattern<PartialProductOp> {
return success();
}

/// Lower a squaring partial product (both operands equal `a`) into an AND
/// array folded onto its upper triangle.
///
/// Row `i`, read from bit 0 upwards, is laid out as:
///   - `2 * i` zero bits,
///   - `a[i]` (the diagonal term, since a[i] & a[i] == a[i]),
///   - a single zero bit,
///   - `a[i] & a[j]` for j = i+1, i+2, ... until the row is `width` bits wide.
/// Rows whose diagonal bit would land at or beyond `width` are all-zero
/// constants. Bits `a[j]` with `j >= op.getNumResults()` are emitted as zero
/// (the result count is taken as the number of possibly non-zero input bits).
///
/// \param rewriter Pattern rewriter used to create the replacement ops.
/// \param a        The value being squared.
/// \param op       The partial product op that gets replaced.
/// \param width    Bit width of every produced row.
/// \returns success() after replacing `op` with the generated rows.
static LogicalResult lowerSqrAndArray(PatternRewriter &rewriter, Value a,
                                      PartialProductOp op, unsigned width) {
  Location loc = op.getLoc();
  // aBits[k] is used below as bit k of `a`.
  SmallVector<Value> aBits = extractBits(rewriter, a);

  const unsigned numRows = op.getNumResults();
  assert(numRows <= width &&
         "Cannot return more results than the operator width");

  SmallVector<Value> rows;
  rows.reserve(width);

  // Shared single-bit zero used for every padding position.
  auto zeroBit = hw::ConstantOp::create(rewriter, loc, APInt(1, 0));

  for (unsigned rowIdx = 0; rowIdx < numRows; ++rowIdx) {
    const unsigned diagPos = 2 * rowIdx;

    // The diagonal term falls outside the result: the whole row is zero.
    if (diagPos >= width) {
      auto zeroRow = hw::ConstantOp::create(rewriter, loc, APInt(width, 0));
      rows.push_back(zeroRow);
      continue;
    }

    // Collect the row pieces least-significant first; the order is flipped
    // just before concatenation.
    SmallVector<Value> pieces;
    pieces.reserve(width);

    // Low zero padding that shifts the diagonal term up to bit 2 * rowIdx.
    if (rowIdx != 0) {
      auto lowZeros = hw::ConstantOp::create(rewriter, loc, APInt(diagPos, 0));
      pieces.push_back(lowZeros);
    }

    // Diagonal term.
    pieces.push_back(aBits[rowIdx]);
    unsigned filled = diagPos + 1;

    // The position directly above the diagonal is always zero, because the
    // doubled cross terms start one bit higher.
    if (filled < width) {
      pieces.push_back(zeroBit);
      ++filled;
    }

    // Cross terms a[rowIdx] & a[col]; stop as soon as the row is full.
    for (unsigned col = rowIdx + 1; col < width && filled != width;
         ++col, ++filled) {
      if (col < numRows) {
        Value crossTerm = rewriter.createOrFold<comb::AndOp>(
            loc, aBits[rowIdx], aBits[col]);
        pieces.push_back(crossTerm);
      } else {
        // Input bits at or above the result count are treated as zero.
        pieces.push_back(zeroBit);
      }
    }

    // The concat below takes its operands in the opposite order to how the
    // pieces were collected.
    std::reverse(pieces.begin(), pieces.end());
    auto rowValue = comb::ConcatOp::create(rewriter, loc, pieces);
    rows.push_back(rowValue);
  }

  rewriter.replaceOp(op, rows);
  return success();
}

static LogicalResult lowerBoothArray(PatternRewriter &rewriter, Value a,
Value b, PartialProductOp op,
unsigned width) {
Expand Down Expand Up @@ -370,8 +450,8 @@ struct DatapathPosPartialProductOpConversion
unsigned width) {

Location loc = op.getLoc();
// Encode (a+b) by implementing a half-adder - then note the following fact
// carry[i] & save[i] == false
// Encode (a+b) by implementing a half-adder - then note the following
// fact carry[i] & save[i] == false
auto carry = rewriter.createOrFold<comb::AndOp>(loc, a, b);
auto save = rewriter.createOrFold<comb::XorOp>(loc, a, b);

Expand Down
38 changes: 38 additions & 0 deletions test/Conversion/DatapathToComb/datapath-to-comb.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,44 @@ hw.module @partial_product(in %a : i3, in %b : i3, out pp0 : i3, out pp1 : i3, o
hw.output %0#0, %0#1, %0#2 : i3, i3, i3
}

// CHECK-LABEL: @partial_product_square
// Squaring a 3-bit value with three requested rows. Expected rows
// (comb.concat lists the most significant operand first):
//   row 0: a0&a1 at bit 2, zero at bit 1, a0 at bit 0
//   row 1: a1 at bit 2 above two zero bits
//   row 2: all zero
hw.module @partial_product_square(in %a : i3, out pp0 : i3, out pp1 : i3, out pp2 : i3) {
// CHECK-NEXT: %c0_i3 = hw.constant 0 : i3
// CHECK-NEXT: %c0_i2 = hw.constant 0 : i2
// CHECK-NEXT: %false = hw.constant false
// CHECK-NEXT: %[[A0:.+]] = comb.extract %a from 0 : (i3) -> i1
// CHECK-NEXT: %[[A1:.+]] = comb.extract %a from 1 : (i3) -> i1
// CHECK-NEXT: %[[A01:.+]] = comb.and %[[A0]], %[[A1]] : i1
// CHECK-NEXT: %[[PP0:.+]] = comb.concat %[[A01]], %false, %[[A0]] : i1, i1, i1
// CHECK-NEXT: %[[PP1:.+]] = comb.concat %[[A1]], %c0_i2 : i1, i2
// CHECK-NEXT: hw.output %[[PP0]], %[[PP1]], %c0_i3 : i3, i3, i3
%0:3 = datapath.partial_product %a, %a : (i3, i3) -> (i3, i3, i3)
hw.output %0#0, %0#1, %0#2 : i3, i3, i3
}

// CHECK-LABEL: @partial_product_square_zext
// Squaring a 3-bit value zero-extended to 6 bits, with three requested rows.
// Expected rows (comb.concat lists the most significant operand first):
//   row 0: 0, 0, a0&a2, a0&a1, 0, a0
//   row 1: 0, a1&a2, 0, a1, then two zero bits
//   row 2: 0, a2, then four zero bits
// Positions that would pair with input bits 3..5 are emitted as zero.
hw.module @partial_product_square_zext(in %a : i3, out pp0 : i6, out pp1 : i6, out pp2 : i6) {
// CHECK-NEXT: %c0_i4 = hw.constant 0 : i4
// CHECK-NEXT: %c0_i2 = hw.constant 0 : i2
// CHECK-NEXT: %false = hw.constant false
// CHECK-NEXT: %c0_i3 = hw.constant 0 : i3
// CHECK-NEXT: %[[AEXT:.+]] = comb.concat %c0_i3, %a : i3, i3
// CHECK-NEXT: %[[A0:.+]] = comb.extract %[[AEXT]] from 0 : (i6) -> i1
// CHECK-NEXT: %[[A1:.+]] = comb.extract %[[AEXT]] from 1 : (i6) -> i1
// CHECK-NEXT: %[[A2:.+]] = comb.extract %[[AEXT]] from 2 : (i6) -> i1
// CHECK-NEXT: %[[A01:.+]] = comb.and %[[A0]], %[[A1]] : i1
// CHECK-NEXT: %[[A02:.+]] = comb.and %[[A0]], %[[A2]] : i1
// CHECK-NEXT: %[[PP0:.+]] = comb.concat %false, %false, %[[A02]], %[[A01]], %false, %[[A0]] : i1, i1, i1, i1, i1, i1
// CHECK-NEXT: %[[A12:.+]] = comb.and %[[A1]], %[[A2]] : i1
// CHECK-NEXT: %[[PP1:.+]] = comb.concat %false, %[[A12]], %false, %[[A1]], %c0_i2 : i1, i1, i1, i1, i2
// CHECK-NEXT: %[[PP2:.+]] = comb.concat %false, %[[A2]], %c0_i4 : i1, i1, i4
// CHECK-NEXT: hw.output %[[PP0]], %[[PP1]], %[[PP2]] : i6, i6, i6
%c0_i3 = hw.constant 0 : i3
%0 = comb.concat %c0_i3, %a : i3, i3
%1:3 = datapath.partial_product %0, %0 : (i6, i6) -> (i6, i6, i6)
hw.output %1#0, %1#1, %1#2 : i6, i6, i6
}

// CHECK-LABEL: @partial_product_booth
// FORCE-BOOTH-LABEL: @partial_product_booth
// Constants
Expand Down