@@ -34,93 +34,107 @@ using namespace firrtl;
3434namespace {
3535struct FlattenMemoryPass
3636 : public circt::firrtl::impl::FlattenMemoryBase<FlattenMemoryPass> {
37+
38+ // / Returns true if the the memory has annotations on a subfield of any of the
39+ // / ports.
40+ static bool hasSubAnno (MemOp op) {
41+ for (size_t portIdx = 0 , e = op.getNumResults (); portIdx < e; ++portIdx)
42+ for (auto attr : op.getPortAnnotation (portIdx))
43+ if (cast<DictionaryAttr>(attr).get (" circt.fieldID" ))
44+ return true ;
45+ return false ;
46+ };
47+
3748 // / This pass flattens the aggregate data of memory into a UInt, and inserts
3849 // / appropriate bitcasts to access the data.
3950 void runOnOperation () override {
4051 LLVM_DEBUG (llvm::dbgs () << " \n Running lower memory on module:"
4152 << getOperation ().getName ());
4253 SmallVector<Operation *> opsToErase;
43- auto hasSubAnno = [&](MemOp op) -> bool {
44- for (size_t portIdx = 0 , e = op.getNumResults (); portIdx < e; ++portIdx)
45- for (auto attr : op.getPortAnnotation (portIdx))
46- if (cast<DictionaryAttr>(attr).get (" circt.fieldID" ))
47- return true ;
48-
49- return false ;
50- };
5154 getOperation ().getBodyBlock ()->walk ([&](MemOp memOp) {
5255 LLVM_DEBUG (llvm::dbgs () << " \n Memory:" << memOp);
53- // The vector of leaf elements type after flattening the data.
54- SmallVector<IntType> flatMemType;
55- // MaskGranularity : how many bits each mask bit controls.
56- size_t maskGran = 1 ;
57- // Total mask bitwidth after flattening.
58- uint32_t totalmaskWidths = 0 ;
59- // How many mask bits each field type requires.
60- SmallVector<unsigned > maskWidths;
6156
6257 // Cannot flatten a memory if it has debug ports, because debug port
6358 // implies a memtap and we cannot transform the datatype for a memory that
6459 // is tapped.
6560 for (auto res : memOp.getResults ())
6661 if (isa<RefType>(res.getType ()))
6762 return ;
63+
6864 // If subannotations present on aggregate fields, we cannot flatten the
6965 // memory. It must be split into one memory per aggregate field.
7066 // Do not overwrite the pass flag!
71- if (hasSubAnno (memOp) || ! flattenType (memOp. getDataType (), flatMemType) )
67+ if (hasSubAnno (memOp))
7268 return ;
7369
74- SmallVector<Operation *, 8 > flatData;
75- SmallVector<int32_t > memWidths;
70+ // The vector of leaf elements type after flattening the data. If any of
71+ // the datatypes cannot be flattened, then we cannot flatten the memory.
72+ SmallVector<FIRRTLBaseType> flatMemType;
73+ if (!flattenType (memOp.getDataType (), flatMemType))
74+ return ;
75+
76+ // Calculate the width of the memory data type, and the width of
77+ // each individual aggregate leaf elements.
7678 size_t memFlatWidth = 0 ;
77- // Get the width of individual aggregate leaf elements.
79+ SmallVector< int32_t > memWidths;
7880 for (auto f : flatMemType) {
7981 LLVM_DEBUG (llvm::dbgs () << " \n field type:" << f);
80- auto w = *f. getWidth ();
82+ auto w = f. getBitWidthOrSentinel ();
8183 memWidths.push_back (w);
8284 memFlatWidth += w;
8385 }
8486 // If all the widths are zero, ignore the memory.
8587 if (!memFlatWidth)
8688 return ;
87- maskGran = memWidths[0 ];
88- // Compute the GCD of all data bitwidths.
89- for (auto w : memWidths) {
89+
90+ // Calculate the mask granularity of this memory, which is how many bits
91+ // of the data each mask bit controls. This is the greatest common
92+ // denominator of the widths of the flattened data types.
93+ auto maskGran = memWidths.front ();
94+ for (auto w : ArrayRef (memWidths).drop_front ())
9095 maskGran = std::gcd (maskGran, w);
91- }
96+
97+ // Total mask bitwidth after flattening.
98+ uint32_t totalmaskWidths = 0 ;
99+ // How many mask bits each field type requires.
100+ SmallVector<unsigned > maskWidths;
92101 for (auto w : memWidths) {
93102 // How many mask bits required for each flattened field.
94103 auto mWidth = w / maskGran;
95104 maskWidths.push_back (mWidth );
96105 totalmaskWidths += mWidth ;
97106 }
107+
98108 // Now create a new memory of type flattened data.
99109 // ----------------------------------------------
100110 SmallVector<Type, 8 > ports;
101111 SmallVector<Attribute, 8 > portNames;
102112
103113 auto *context = memOp.getContext ();
104114 ImplicitLocOpBuilder builder (memOp.getLoc (), memOp);
105- // Create a new memoty data type of unsigned and computed width.
115+ // Create a new memory data type of unsigned and computed width.
106116 auto flatType = UIntType::get (context, memFlatWidth);
107- auto opPorts = memOp.getPorts ();
108- for (size_t portIdx = 0 , e = opPorts.size (); portIdx < e; ++portIdx) {
109- auto port = opPorts[portIdx];
117+ for (auto port : memOp.getPorts ()) {
110118 ports.push_back (MemOp::getTypeForPort (memOp.getDepth (), flatType,
111119 port.second , totalmaskWidths));
112120 portNames.push_back (port.first );
113121 }
114122
123+ // Create the new flattened memory.
115124 auto flatMem = builder.create <MemOp>(
116125 ports, memOp.getReadLatency (), memOp.getWriteLatency (),
117126 memOp.getDepth (), memOp.getRuw (), builder.getArrayAttr (portNames),
118127 memOp.getNameAttr (), memOp.getNameKind (), memOp.getAnnotations (),
119128 memOp.getPortAnnotations (), memOp.getInnerSymAttr (),
120129 memOp.getInitAttr (), memOp.getPrefixAttr ());
130+
121131 // Hook up the new memory to the wires the old memory was replaced with.
122132 for (size_t index = 0 , rend = memOp.getNumResults (); index < rend;
123133 ++index) {
134+
135+ // Create a wire with the original type, and replace all uses of the old
136+ // memory with the wire. We will be reconstructing the original type
137+ // in the wire from the bitvector of the flattened memory.
124138 auto result = memOp.getResult (index);
125139 auto wire = builder
126140 .create <WireOp>(result.getType (),
@@ -134,7 +148,7 @@ struct FlattenMemoryPass
134148 auto rType = type_cast<BundleType>(result.getType ());
135149 for (size_t fieldIndex = 0 , fend = rType.getNumElements ();
136150 fieldIndex != fend; ++fieldIndex) {
137- auto name = rType.getElement (fieldIndex).name . getValue () ;
151+ auto name = rType.getElement (fieldIndex).name ;
138152 auto oldField = builder.create <SubfieldOp>(result, fieldIndex);
139153 FIRRTLBaseValue newField =
140154 builder.create <SubfieldOp>(newResult, fieldIndex);
@@ -153,7 +167,6 @@ struct FlattenMemoryPass
153167 // Write the aggregate read data.
154168 emitConnect (builder, realOldField, castField);
155169 } else {
156- // Cast the input aggregate write data to flat type.
157170 // Cast the input aggregate write data to flat type.
158171 auto newFieldType = newField.getType ();
159172 auto oldFieldBitWidth = getBitWidth (oldField.getType ());
@@ -197,10 +210,11 @@ struct FlattenMemoryPass
197210 }
198211
199212private:
200- // Convert an aggregate type into a flat list of fields.
201- // This is used to flatten the aggregate memory datatype.
202- // Recursively populate the results with each ground type field.
203- static bool flattenType (FIRRTLType type, SmallVectorImpl<IntType> &results) {
213+ // Convert an aggregate type into a flat list of fields. This is used to
214+ // flatten the aggregate memory datatype. Recursively populate the results
215+ // with each ground type field.
216+ static bool flattenType (FIRRTLType type,
217+ SmallVectorImpl<FIRRTLBaseType> &results) {
204218 std::function<bool (FIRRTLType)> flatten = [&](FIRRTLType type) -> bool {
205219 return FIRRTLTypeSwitch<FIRRTLType, bool >(type)
206220 .Case <BundleType>([&](auto bundle) {
@@ -226,9 +240,7 @@ struct FlattenMemoryPass
226240 .Default ([&](auto ) { return false ; });
227241 };
228242 // Return true only if this is an aggregate with more than one element.
229- if (flatten (type) && results.size () > 1 )
230- return true ;
231- return false ;
243+ return flatten (type) && results.size () > 1 ;
232244 }
233245
234246 Value getSubWhatever (ImplicitLocOpBuilder *builder, Value val, size_t index) {
0 commit comments