Skip to content

Commit 959e0f1

Browse files
committed
[FIRRTL] FlattenMemories: code cleanup, nfci
1 parent b4da671 commit 959e0f1

File tree

1 file changed

+50
-38
lines changed

1 file changed

+50
-38
lines changed

lib/Dialect/FIRRTL/Transforms/FlattenMemory.cpp

Lines changed: 50 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -34,93 +34,107 @@ using namespace firrtl;
3434
namespace {
3535
struct FlattenMemoryPass
3636
: public circt::firrtl::impl::FlattenMemoryBase<FlattenMemoryPass> {
37+
38+
/// Returns true if the the memory has annotations on a subfield of any of the
39+
/// ports.
40+
static bool hasSubAnno(MemOp op) {
41+
for (size_t portIdx = 0, e = op.getNumResults(); portIdx < e; ++portIdx)
42+
for (auto attr : op.getPortAnnotation(portIdx))
43+
if (cast<DictionaryAttr>(attr).get("circt.fieldID"))
44+
return true;
45+
return false;
46+
};
47+
3748
/// This pass flattens the aggregate data of memory into a UInt, and inserts
3849
/// appropriate bitcasts to access the data.
3950
void runOnOperation() override {
4051
LLVM_DEBUG(llvm::dbgs() << "\n Running lower memory on module:"
4152
<< getOperation().getName());
4253
SmallVector<Operation *> opsToErase;
43-
auto hasSubAnno = [&](MemOp op) -> bool {
44-
for (size_t portIdx = 0, e = op.getNumResults(); portIdx < e; ++portIdx)
45-
for (auto attr : op.getPortAnnotation(portIdx))
46-
if (cast<DictionaryAttr>(attr).get("circt.fieldID"))
47-
return true;
48-
49-
return false;
50-
};
5154
getOperation().getBodyBlock()->walk([&](MemOp memOp) {
5255
LLVM_DEBUG(llvm::dbgs() << "\n Memory:" << memOp);
53-
// The vector of leaf elements type after flattening the data.
54-
SmallVector<IntType> flatMemType;
55-
// MaskGranularity : how many bits each mask bit controls.
56-
size_t maskGran = 1;
57-
// Total mask bitwidth after flattening.
58-
uint32_t totalmaskWidths = 0;
59-
// How many mask bits each field type requires.
60-
SmallVector<unsigned> maskWidths;
6156

6257
// Cannot flatten a memory if it has debug ports, because debug port
6358
// implies a memtap and we cannot transform the datatype for a memory that
6459
// is tapped.
6560
for (auto res : memOp.getResults())
6661
if (isa<RefType>(res.getType()))
6762
return;
63+
6864
// If subannotations present on aggregate fields, we cannot flatten the
6965
// memory. It must be split into one memory per aggregate field.
7066
// Do not overwrite the pass flag!
71-
if (hasSubAnno(memOp) || !flattenType(memOp.getDataType(), flatMemType))
67+
if (hasSubAnno(memOp))
7268
return;
7369

74-
SmallVector<Operation *, 8> flatData;
75-
SmallVector<int32_t> memWidths;
70+
// The vector of leaf elements type after flattening the data. If any of
71+
// the datatypes cannot be flattened, then we cannot flatten the memory.
72+
SmallVector<FIRRTLBaseType> flatMemType;
73+
if (!flattenType(memOp.getDataType(), flatMemType))
74+
return;
75+
76+
// Calculate the width of the memory data type, and the width of
77+
// each individual aggregate leaf elements.
7678
size_t memFlatWidth = 0;
77-
// Get the width of individual aggregate leaf elements.
79+
SmallVector<int32_t> memWidths;
7880
for (auto f : flatMemType) {
7981
LLVM_DEBUG(llvm::dbgs() << "\n field type:" << f);
80-
auto w = *f.getWidth();
82+
auto w = f.getBitWidthOrSentinel();
8183
memWidths.push_back(w);
8284
memFlatWidth += w;
8385
}
8486
// If all the widths are zero, ignore the memory.
8587
if (!memFlatWidth)
8688
return;
87-
maskGran = memWidths[0];
88-
// Compute the GCD of all data bitwidths.
89-
for (auto w : memWidths) {
89+
90+
// Calculate the mask granularity of this memory, which is how many bits
91+
// of the data each mask bit controls. This is the greatest common
92+
// denominator of the widths of the flattened data types.
93+
auto maskGran = memWidths.front();
94+
for (auto w : ArrayRef(memWidths).drop_front())
9095
maskGran = std::gcd(maskGran, w);
91-
}
96+
97+
// Total mask bitwidth after flattening.
98+
uint32_t totalmaskWidths = 0;
99+
// How many mask bits each field type requires.
100+
SmallVector<unsigned> maskWidths;
92101
for (auto w : memWidths) {
93102
// How many mask bits required for each flattened field.
94103
auto mWidth = w / maskGran;
95104
maskWidths.push_back(mWidth);
96105
totalmaskWidths += mWidth;
97106
}
107+
98108
// Now create a new memory of type flattened data.
99109
// ----------------------------------------------
100110
SmallVector<Type, 8> ports;
101111
SmallVector<Attribute, 8> portNames;
102112

103113
auto *context = memOp.getContext();
104114
ImplicitLocOpBuilder builder(memOp.getLoc(), memOp);
105-
// Create a new memoty data type of unsigned and computed width.
115+
// Create a new memory data type of unsigned and computed width.
106116
auto flatType = UIntType::get(context, memFlatWidth);
107-
auto opPorts = memOp.getPorts();
108-
for (size_t portIdx = 0, e = opPorts.size(); portIdx < e; ++portIdx) {
109-
auto port = opPorts[portIdx];
117+
for (auto port : memOp.getPorts()) {
110118
ports.push_back(MemOp::getTypeForPort(memOp.getDepth(), flatType,
111119
port.second, totalmaskWidths));
112120
portNames.push_back(port.first);
113121
}
114122

123+
// Create the new flattened memory.
115124
auto flatMem = builder.create<MemOp>(
116125
ports, memOp.getReadLatency(), memOp.getWriteLatency(),
117126
memOp.getDepth(), memOp.getRuw(), builder.getArrayAttr(portNames),
118127
memOp.getNameAttr(), memOp.getNameKind(), memOp.getAnnotations(),
119128
memOp.getPortAnnotations(), memOp.getInnerSymAttr(),
120129
memOp.getInitAttr(), memOp.getPrefixAttr());
130+
121131
// Hook up the new memory to the wires the old memory was replaced with.
122132
for (size_t index = 0, rend = memOp.getNumResults(); index < rend;
123133
++index) {
134+
135+
// Create a wire with the original type, and replace all uses of the old
136+
// memory with the wire. We will be reconstructing the original type
137+
// in the wire from the bitvector of the flattened memory.
124138
auto result = memOp.getResult(index);
125139
auto wire = builder
126140
.create<WireOp>(result.getType(),
@@ -134,7 +148,7 @@ struct FlattenMemoryPass
134148
auto rType = type_cast<BundleType>(result.getType());
135149
for (size_t fieldIndex = 0, fend = rType.getNumElements();
136150
fieldIndex != fend; ++fieldIndex) {
137-
auto name = rType.getElement(fieldIndex).name.getValue();
151+
auto name = rType.getElement(fieldIndex).name;
138152
auto oldField = builder.create<SubfieldOp>(result, fieldIndex);
139153
FIRRTLBaseValue newField =
140154
builder.create<SubfieldOp>(newResult, fieldIndex);
@@ -153,7 +167,6 @@ struct FlattenMemoryPass
153167
// Write the aggregate read data.
154168
emitConnect(builder, realOldField, castField);
155169
} else {
156-
// Cast the input aggregate write data to flat type.
157170
// Cast the input aggregate write data to flat type.
158171
auto newFieldType = newField.getType();
159172
auto oldFieldBitWidth = getBitWidth(oldField.getType());
@@ -197,10 +210,11 @@ struct FlattenMemoryPass
197210
}
198211

199212
private:
200-
// Convert an aggregate type into a flat list of fields.
201-
// This is used to flatten the aggregate memory datatype.
202-
// Recursively populate the results with each ground type field.
203-
static bool flattenType(FIRRTLType type, SmallVectorImpl<IntType> &results) {
213+
// Convert an aggregate type into a flat list of fields. This is used to
214+
// flatten the aggregate memory datatype. Recursively populate the results
215+
// with each ground type field.
216+
static bool flattenType(FIRRTLType type,
217+
SmallVectorImpl<FIRRTLBaseType> &results) {
204218
std::function<bool(FIRRTLType)> flatten = [&](FIRRTLType type) -> bool {
205219
return FIRRTLTypeSwitch<FIRRTLType, bool>(type)
206220
.Case<BundleType>([&](auto bundle) {
@@ -226,9 +240,7 @@ struct FlattenMemoryPass
226240
.Default([&](auto) { return false; });
227241
};
228242
// Return true only if this is an aggregate with more than one element.
229-
if (flatten(type) && results.size() > 1)
230-
return true;
231-
return false;
243+
return flatten(type) && results.size() > 1;
232244
}
233245

234246
Value getSubWhatever(ImplicitLocOpBuilder *builder, Value val, size_t index) {

0 commit comments

Comments
 (0)