1313#include " circt/Dialect/HW/HWOps.h"
1414#include " mlir/Analysis/TopologicalSortUtils.h"
1515#include " mlir/Dialect/Func/IR/FuncOps.h"
16+ #include " mlir/IR/PatternMatch.h"
1617#include " mlir/Pass/Pass.h"
17- #include " mlir/Transforms/DialectConversion .h"
18+ #include " mlir/Transforms/GreedyPatternRewriteDriver .h"
1819#include " llvm/Support/Debug.h"
1920
2021#define DEBUG_TYPE " datapath-to-comb"
@@ -42,11 +43,11 @@ namespace {
4243// Replace compressor by an adder of the inputs and zero for the other results:
4344// compress(a,b,c,d) -> {a+b+c+d, 0}
4445// Facilitates use of downstream compression algorithms e.g. Yosys
45- struct DatapathCompressOpAddConversion : OpConversionPattern <CompressOp> {
46- using OpConversionPattern <CompressOp>::OpConversionPattern ;
46+ struct DatapathCompressOpAddConversion : mlir::OpRewritePattern <CompressOp> {
47+ using mlir::OpRewritePattern <CompressOp>::OpRewritePattern ;
4748 LogicalResult
48- matchAndRewrite (CompressOp op, OpAdaptor adaptor,
49- ConversionPatternRewriter &rewriter) const override {
49+ matchAndRewrite (CompressOp op,
50+ mlir::PatternRewriter &rewriter) const override {
5051 Location loc = op.getLoc ();
5152 auto inputs = op.getOperands ();
5253 unsigned width = inputs[0 ].getType ().getIntOrFloatBitWidth ();
@@ -62,15 +63,14 @@ struct DatapathCompressOpAddConversion : OpConversionPattern<CompressOp> {
6263};
6364
6465// Replace compressor by a Wallace tree of full-adders
65- struct DatapathCompressOpConversion : OpConversionPattern<CompressOp> {
66- using OpConversionPattern<CompressOp>::OpConversionPattern;
66+ struct DatapathCompressOpConversion : mlir::OpRewritePattern<CompressOp> {
6767 DatapathCompressOpConversion (MLIRContext *context,
6868 aig::IncrementalLongestPathAnalysis *analysis)
69- : OpConversionPattern <CompressOp>(context), analysis(analysis) {}
69+ : mlir::OpRewritePattern <CompressOp>(context), analysis(analysis) {}
7070
7171 LogicalResult
72- matchAndRewrite (CompressOp op, OpAdaptor adaptor,
73- ConversionPatternRewriter &rewriter) const override {
72+ matchAndRewrite (CompressOp op,
73+ mlir::PatternRewriter &rewriter) const override {
7474 Location loc = op.getLoc ();
7575 auto inputs = op.getOperands ();
7676 unsigned width = inputs[0 ].getType ().getIntOrFloatBitWidth ();
@@ -79,30 +79,6 @@ struct DatapathCompressOpConversion : OpConversionPattern<CompressOp> {
7979 for (auto input : inputs) {
8080 addends.push_back (
8181 extractBits (rewriter, input)); // Extract bits from each input
82-
83- // NOTE: Following change will be splitted into a separate PR.
84- if (analysis) {
85- auto delay = analysis->getOrComputePaths (input, 0 );
86- if (failed (delay))
87- return op.emitError (" Failed to get delay for input" );
88- // TODO: Use the delay information to sort the inputs.
89- }
90-
91- LLVM_DEBUG ({
92- llvm::dbgs () << " Input: " << input << " delay: " ;
93- assert (analysis && " Expected analysis to be set" );
94- if (analysis) {
95- auto delay = analysis->getOrComputeDelay (
96- input, 0 ); // Query delay for each input
97- if (llvm::succeeded (delay))
98- llvm::dbgs () << *delay;
99- else
100- llvm::dbgs () << " N/A" ;
101- } else {
102- llvm::dbgs () << " N/A(analysis not set)" ;
103- }
104- llvm::dbgs () << " \n " ;
105- });
10682 }
10783
10884 // Wallace tree reduction
@@ -111,6 +87,26 @@ struct DatapathCompressOpConversion : OpConversionPattern<CompressOp> {
11187 // sort the inputs according to arrival time.
11288 // TODO: Use the listener to get arrival time information.
11389 auto targetAddends = op.getNumResults ();
90+ if (analysis) {
91+ // For each bit position, sort the bits across the addend rows by ascending delay.
92+ for (size_t j = 0 ; j < addends[0 ].size (); ++j) {
93+ SmallVector<std::pair<int64_t , Value>> delays;
94+ for (size_t i = 0 ; i < addends.size (); ++i) {
95+ auto delay = analysis->getOrComputeMaxDelay (addends[i][j], 0 );
96+ if (failed (delay))
97+ return rewriter.notifyMatchFailure (op,
98+ " Failed to get delay for input" );
99+ delays.push_back (std::make_pair (*delay, addends[i][j]));
100+ }
101+ std::stable_sort (delays.begin (), delays.end (),
102+ [](const std::pair<int64_t , Value> &a,
103+ const std::pair<int64_t , Value> &b) {
104+ return a.first < b.first ;
105+ });
106+ for (size_t i = 0 ; i < addends.size (); ++i)
107+ addends[i][j] = delays[i].second ;
108+ }
109+ }
114110 rewriter.replaceOp (op, comb::wallaceReduction (rewriter, loc, width,
115111 targetAddends, addends));
116112 return success ();
@@ -120,19 +116,16 @@ struct DatapathCompressOpConversion : OpConversionPattern<CompressOp> {
120116 aig::IncrementalLongestPathAnalysis *analysis = nullptr ;
121117};
122118
123- struct DatapathPartialProductOpConversion
124- : OpConversionPattern<PartialProductOp> {
125- using OpConversionPattern<PartialProductOp>::OpConversionPattern;
119+ struct DatapathPartialProductOpConversion : OpRewritePattern<PartialProductOp> {
120+ using OpRewritePattern<PartialProductOp>::OpRewritePattern;
126121
127122 DatapathPartialProductOpConversion (MLIRContext *context, bool forceBooth)
128- : OpConversionPattern<PartialProductOp>(context),
129- forceBooth (forceBooth){};
123+ : OpRewritePattern<PartialProductOp>(context), forceBooth(forceBooth){};
130124
131125 const bool forceBooth;
132126
133- LogicalResult
134- matchAndRewrite (PartialProductOp op, OpAdaptor adaptor,
135- ConversionPatternRewriter &rewriter) const override {
127+ LogicalResult matchAndRewrite (PartialProductOp op,
128+ PatternRewriter &rewriter) const override {
136129
137130 Value a = op.getLhs ();
138131 Value b = op.getRhs ();
@@ -152,8 +145,8 @@ struct DatapathPartialProductOpConversion
152145 }
153146
154147private:
155- static LogicalResult lowerAndArray (ConversionPatternRewriter &rewriter,
156- Value a, Value b, PartialProductOp op,
148+ static LogicalResult lowerAndArray (PatternRewriter &rewriter, Value a ,
149+ Value b, PartialProductOp op,
157150 unsigned width) {
158151
159152 Location loc = op.getLoc ();
@@ -179,8 +172,8 @@ struct DatapathPartialProductOpConversion
179172 return success ();
180173 }
181174
182- static LogicalResult lowerBoothArray (ConversionPatternRewriter &rewriter,
183- Value a, Value b, PartialProductOp op,
175+ static LogicalResult lowerBoothArray (PatternRewriter &rewriter, Value a ,
176+ Value b, PartialProductOp op,
184177 unsigned width) {
185178 Location loc = op.getLoc ();
186179 auto zeroFalse = hw::ConstantOp::create (rewriter, loc, APInt (1 , 0 ));
@@ -288,42 +281,40 @@ struct ConvertDatapathToCombPass
288281};
289282} // namespace
290283
291- static LogicalResult
292- applyConversionWithTimingInfo (Operation *op, const ConversionTarget &target,
293- RewritePatternSet &&patterns,
294- aig::IncrementalLongestPathAnalysis *analysis) {
284+ static LogicalResult applyPatternsGreedilyWithTimingInfo (
285+ Operation *op, RewritePatternSet &&patterns,
286+ aig::IncrementalLongestPathAnalysis *analysis) {
295287 // TODO: Topologically sort the operations in the module to ensure that all
296288 // dependencies are processed before their users.
297- mlir::ConversionConfig config;
298- config.listener = analysis;
299-
300- // Apply the conversion patterns
301- if (failed (mlir::applyPartialConversion (op, target, std::move (patterns))))
289+ mlir::GreedyRewriteConfig config;
290+ // Set the listener to update timing information
291+ // HACK: Setting max iterations to 2 to ensure that the patterns are one-shot,
292+ // making sure that the targeted datapath operations are replaced.
293+ config.setMaxIterations (2 ).setListener (analysis).setUseTopDownTraversal (true );
294+
295+ // Apply the patterns greedily
296+ if (failed (mlir::applyPatternsGreedily (op, std::move (patterns), config)))
302297 return failure ();
303298
304299 return success ();
305300}
306301
307302void ConvertDatapathToCombPass::runOnOperation () {
308- ConversionTarget target (getContext ());
309-
310- target.addLegalDialect <comb::CombDialect, hw::HWDialect>();
311- target.addIllegalDialect <DatapathDialect>();
312-
313303 RewritePatternSet patterns (&getContext ());
314304
315305 patterns.add <DatapathPartialProductOpConversion>(patterns.getContext (),
316306 forceBooth);
317- auto &analysis = getAnalysis<aig::IncrementalLongestPathAnalysis>();
307+ aig::IncrementalLongestPathAnalysis *analysis = nullptr ;
308+ if (timingAware)
309+ analysis = &getAnalysis<aig::IncrementalLongestPathAnalysis>();
318310 if (lowerCompressToAdd)
319- // Lower compressors to simple add operations for downstream optimisations
311+ // Lower compressors to simple add operations for downstream optimizations
320312 patterns.add <DatapathCompressOpAddConversion>(patterns.getContext ());
321313 else
322314 // Lower compressors to a complete gate-level implementation
323- patterns.add <DatapathCompressOpConversion>(patterns.getContext (),
324- &analysis);
315+ patterns.add <DatapathCompressOpConversion>(patterns.getContext (), analysis);
325316
326- if (failed (applyConversionWithTimingInfo ( getOperation (), target,
327- std::move (patterns), & analysis)))
317+ if (failed (applyPatternsGreedilyWithTimingInfo (
318+ getOperation (), std::move (patterns), analysis)))
328319 return signalPassFailure ();
329320}
0 commit comments