Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 61 additions & 60 deletions include/circt/Dialect/Synth/Transforms/CutRewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include <memory>
#include <optional>
Expand Down Expand Up @@ -58,6 +59,41 @@ class CutRewriter;
struct CutRewritePattern;
struct CutRewriterOptions;

/// Represents a cut that has been successfully matched to a rewriting pattern.
///
/// This class encapsulates the result of matching a cut against a rewriting
/// pattern during optimization. It stores the matched pattern, the
/// cut that was matched, and timing information needed for optimization.
class MatchedPattern {
private:
const CutRewritePattern *pattern = nullptr; ///< The matched library pattern
SmallVector<DelayType, 1>
arrivalTimes; ///< Arrival times of outputs from this pattern

public:
/// Default constructor creates an invalid matched pattern.
MatchedPattern() = default;

/// Constructor for a valid matched pattern.
MatchedPattern(const CutRewritePattern *pattern,
SmallVector<DelayType, 1> arrivalTimes)
: pattern(pattern), arrivalTimes(std::move(arrivalTimes)) {}

/// Get the arrival time of signals through this pattern.
DelayType getArrivalTime(unsigned outputIndex) const;
ArrayRef<DelayType> getArrivalTimes() const;
DelayType getWorstOutputArrivalTime() const;

/// Get the library pattern that was matched.
const CutRewritePattern *getPattern() const;

/// Get the area cost of using this pattern.
double getArea() const;

/// Get the delay between specific input and output pins.
DelayType getDelay(unsigned inputIndex, unsigned outputIndex) const;
};

/// Represents a cut in the combinational logic network.
///
/// A cut is a subset of nodes in the combinational logic that forms a complete
Expand All @@ -80,7 +116,7 @@ class Cut {
/// Computed lazily from the truth table when first accessed.
mutable std::optional<NPNClass> npnClass;

unsigned depth = 0; ///< Depth of this cut in the logic network
std::optional<MatchedPattern> matchedPattern;

public:
/// External inputs to this cut (cut boundary).
Expand All @@ -95,9 +131,6 @@ class Cut {
/// A trivial cut has no internal operations and exactly one input.
bool isTrivialCut() const;

/// Get the root operation of this cut.
unsigned getDepth() const;

/// Get the root operation of this cut.
/// The root operation produces the output of the cut.
mlir::Operation *getRoot() const;
Expand Down Expand Up @@ -129,44 +162,16 @@ class Cut {
/// Get the permutated inputs for this cut based on the given pattern NPN.
void getPermutatedInputs(const NPNClass &patternNPN,
SmallVectorImpl<Value> &permutedInputs) const;
};

/// Represents a cut that has been successfully matched to a rewriting pattern.
///
/// This class encapsulates the result of matching a cut against a rewriting
/// pattern during optimization. It stores the matched pattern, the
/// cut that was matched, and timing information needed for optimization.
class MatchedPattern {
private:
const CutRewritePattern *pattern = nullptr; ///< The matched library pattern
Cut *cut = nullptr; ///< The cut that was matched
SmallVector<DelayType, 2>
arrivalTimes; ///< Arrival times of outputs from this pattern

public:
/// Default constructor creates an invalid matched pattern.
MatchedPattern() = default;

/// Constructor for a valid matched pattern.
MatchedPattern(const CutRewritePattern *pattern, Cut *cut,
SmallVector<DelayType, 2> arrivalTimes)
: pattern(pattern), cut(cut), arrivalTimes(std::move(arrivalTimes)) {}

/// Get the arrival time of signals through this pattern.
DelayType getArrivalTime(unsigned outputIndex) const;
ArrayRef<DelayType> getArrivalTimes() const;

/// Get the library pattern that was matched.
const CutRewritePattern *getPattern() const;

/// Get the cut that was matched to the pattern.
Cut *getCut() const;

/// Get the area cost of using this pattern.
double getArea() const;
/// Matched pattern for this cut.
void setMatchedPattern(MatchedPattern pattern) {
matchedPattern = std::move(pattern);
}

/// Get the delay between specific input and output pins.
DelayType getDelay(unsigned inputIndex, unsigned outputIndex) const;
/// Get the matched pattern for this cut.
const std::optional<MatchedPattern> &getMatchedPattern() const {
return matchedPattern;
}
};

/// Manages a collection of cuts for a single logic node using priority cuts
Expand All @@ -183,19 +188,15 @@ class MatchedPattern {
class CutSet {
private:
llvm::SmallVector<Cut, 4> cuts; ///< Collection of cuts for this node
std::optional<MatchedPattern> matchedPattern; ///< Best matched pattern found
Cut *bestCut = nullptr;
bool isFrozen = false; ///< Whether cut set is finalized

public:
/// Get the best matched pattern for this cut set.
std::optional<MatchedPattern> getBestMatchedPattern() const;

/// Check if this cut set has a valid matched pattern.
bool isMatched() const;
bool isMatched() const { return bestCut; }

/// Get the cut associated with the best matched pattern.
/// NOTE: isMatched() must be true
Cut *getMatchedCut();
Cut *getBestMatchedCut() const;

/// Finalize the cut set by removing duplicates and selecting the best
/// pattern.
Expand All @@ -205,9 +206,9 @@ class CutSet {
/// 2. Limits the number of cuts to prevent exponential growth
/// 3. Matches each cut against available patterns
/// 4. Selects the best pattern based on the optimization strategy
void
finalize(const CutRewriterOptions &options,
llvm::function_ref<std::optional<MatchedPattern>(Cut &)> matchCut);
void finalize(
const CutRewriterOptions &options,
llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut);

/// Get the number of cuts in this set.
unsigned size() const;
Expand Down Expand Up @@ -243,6 +244,9 @@ struct CutRewriterOptions {

/// Put arrival times to rewritten operations.
bool attachDebugTiming = false;

/// Run priority cuts enumeration and dump the cut sets.
bool testPriorityCuts = false;
};

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -272,10 +276,8 @@ class CutEnumerator {
/// for combinational logic operations.
LogicalResult enumerateCuts(
Operation *topOp,
llvm::function_ref<std::optional<MatchedPattern>(Cut &)> matchCut =
[](Cut &) {
return std::nullopt; // Default no-op matcher
});
llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut =
[](const Cut &) { return std::nullopt; });

/// Create a new cut set for a value.
/// The value must not already have a cut set.
Expand All @@ -293,6 +295,8 @@ class CutEnumerator {
/// Clear all cut sets and reset the enumerator.
void clear();

void dump() const;

private:
/// Visit a single operation and generate cuts for it.
LogicalResult visit(Operation *op);
Expand All @@ -310,7 +314,7 @@ class CutEnumerator {

/// Function to match cuts against available patterns.
/// Set during enumeration and used when finalizing cut sets.
llvm::function_ref<std::optional<MatchedPattern>(Cut &)> matchCut;
llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut;
};

/// Base class for cut rewriting patterns used in combinational logic
Expand Down Expand Up @@ -360,17 +364,14 @@ struct CutRewritePattern {
Cut &cut) const = 0;

/// Get the area cost of this pattern.
virtual double getArea(const Cut &cut) const = 0;
virtual double getArea() const = 0;

/// Get the delay between specific input and output.
/// NOTE: The input index is already permuted according to the pattern's
/// input permutation, so it's not necessary to account for it here.
virtual DelayType getDelay(unsigned inputIndex,
unsigned outputIndex) const = 0;

/// Get the number of inputs this pattern expects.
virtual unsigned getNumInputs() const = 0;

/// Get the number of outputs this pattern produces.
virtual unsigned getNumOutputs() const = 0;

Expand Down Expand Up @@ -478,7 +479,7 @@ class CutRewriter {
getMatchingPatternsFromTruthTable(const Cut &cut) const;

/// Match a cut against available patterns and compute arrival time.
std::optional<MatchedPattern> patternMatchCut(Cut &cut);
std::optional<MatchedPattern> patternMatchCut(const Cut &cut);

/// Perform the actual circuit rewriting using selected patterns.
LogicalResult runBottomUpRewrite(Operation *topOp);
Expand Down
Loading