Skip to content

Commit d087c2f

Browse files
[mlir][Transforms] Make lookup without type converter unambiguous
1 parent 16a0892 commit d087c2f

File tree

5 files changed

+220
-69
lines changed

5 files changed

+220
-69
lines changed

mlir/docs/DialectConversion.md

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -202,17 +202,29 @@ struct MyConversionPattern : public ConversionPattern {
202202
203203
#### Type Safety
204204
205-
The types of the remapped operands provided to a conversion pattern must be of a
206-
type expected by the pattern. The expected types of a pattern are determined by
207-
a provided [TypeConverter](#type-converter). If no type converter is provided,
208-
the types of the remapped operands are expected to match the types of the
209-
original operands. If a type converter is provided, the types of the remapped
210-
operands are expected to be legal as determined by the converter. If the
211-
remapped operand types are not of an expected type, and a materialization to the
212-
expected type could not be performed, the pattern fails application before the
213-
`matchAndRewrite` hook is invoked. This ensures that patterns do not have to
214-
explicitly ensure type safety, or sanitize the types of the incoming remapped
215-
operands. More information on type conversion is detailed in the
205+
The types of the remapped operands provided to a conversion pattern (through
206+
the adaptor or `ArrayRef` of operands) depend on type conversio rules.
207+
208+
If the pattern was initialized with a [type converter](#type-converter), the
209+
conversion driver passes values whose types match the legalized types of the
210+
operands of the matched operation as per the type converter. To that end, the
211+
conversion driver may insert target materializations to convert the most
212+
recently mapped values to the expected legalized types. The driver tries to
213+
reuse existing materializations on a best-effort basis, but this is not
214+
guaranteed by the infrastructure. If the operand types of the matched op could
215+
not be legalized, the pattern fails to apply before the `matchAndRewrite` hook
216+
is invoked.
217+
218+
If the pattern was initialized without a type converter, the conversion driver
219+
passes the most recently mapped values to the pattern, excluding any
220+
materializations. Materializations are intentionally excluded because their
221+
presence may depend on other patterns. If a value of the same type as an
222+
operand is desired, users can directly take the respective operand from the
223+
matched operation.
224+
225+
The above rules ensure that patterns do not have to explicitly ensure type
226+
safety, or sanitize the types of the incoming remapped operands. More
227+
information on type conversion is detailed in the
216228
[dedicated section](#type-conversion) below.
217229
218230
## Type Conversion

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 129 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -131,16 +131,25 @@ struct ConversionValueMapping {
131131
/// recently mapped values.
132132
/// - If there is no mapping for the given values at all, return the given
133133
/// value.
134-
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
134+
///
135+
/// If `skipMaterializations` is true, materializations are not considered.
136+
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
137+
bool skipMaterializations = false) const;
138+
139+
/// Lookup a value from the mapping. (Just once, not following the chain of
140+
/// potential mappings.) Look for actual replacements first, then for
141+
/// materializations. The materializations lookup can be skipped.
142+
ValueVector lookupSingleStep(const ValueVector &from,
143+
bool skipMaterializations = false) const;
135144

136145
template <typename T>
137146
struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
138147

139148
/// Map a value vector to the one provided.
140149
template <typename OldVal, typename NewVal>
141150
std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
142-
map(OldVal &&oldVal, NewVal &&newVal) {
143-
LLVM_DEBUG({
151+
map(OldVal &&oldVal, NewVal &&newVal, bool isOnlyTypeConversion = false) {
152+
auto checkCircularMapping = [&](auto &mapping) {
144153
ValueVector next(newVal);
145154
while (true) {
146155
assert(next != oldVal && "inserting cyclic mapping");
@@ -149,45 +158,117 @@ struct ConversionValueMapping {
149158
break;
150159
next = it->second;
151160
}
152-
});
161+
};
162+
(void)checkCircularMapping;
163+
153164
mappedTo.insert_range(newVal);
154165

155-
mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
166+
if (isOnlyTypeConversion) {
167+
// This is a materialization.
168+
LLVM_DEBUG({ checkCircularMapping(materializations); });
169+
materializations[std::forward<OldVal>(oldVal)] =
170+
std::forward<NewVal>(newVal);
171+
} else {
172+
// This is a regular value replacement.
173+
LLVM_DEBUG({ checkCircularMapping(mapping); });
174+
mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
175+
}
156176
}
157177

158178
/// Map a value vector or single value to the one provided.
159179
template <typename OldVal, typename NewVal>
160180
std::enable_if_t<!IsValueVector<OldVal>::value ||
161181
!IsValueVector<NewVal>::value>
162-
map(OldVal &&oldVal, NewVal &&newVal) {
182+
map(OldVal &&oldVal, NewVal &&newVal, bool isOnlyTypeConversion = false) {
163183
if constexpr (IsValueVector<OldVal>{}) {
164-
map(std::forward<OldVal>(oldVal), ValueVector{newVal});
184+
map(std::forward<OldVal>(oldVal), ValueVector{newVal},
185+
isOnlyTypeConversion);
165186
} else if constexpr (IsValueVector<NewVal>{}) {
166-
map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
187+
map(ValueVector{oldVal}, std::forward<NewVal>(newVal),
188+
isOnlyTypeConversion);
167189
} else {
168-
map(ValueVector{oldVal}, ValueVector{newVal});
190+
map(ValueVector{oldVal}, ValueVector{newVal}, isOnlyTypeConversion);
169191
}
170192
}
171193

172-
void map(Value oldVal, SmallVector<Value> &&newVal) {
194+
void map(Value oldVal, SmallVector<Value> &&newVal,
195+
bool isOnlyTypeConversion = false) {
173196
map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
174197
}
175198

176199
/// Drop the last mapping for the given values.
177-
void erase(const ValueVector &value) { mapping.erase(value); }
200+
void erase(const ValueVector &value) {
201+
mapping.erase(value);
202+
materializations.erase(value);
203+
}
178204

179205
private:
180-
/// Current value mappings.
206+
/// Mapping of actual replacements.
181207
DenseMap<ValueVector, ValueVector, ValueVectorMapInfo> mapping;
182208

209+
/// Mapping of materializations that are created only to resolve type
210+
/// mismatches.
211+
DenseMap<ValueVector, ValueVector, ValueVectorMapInfo> materializations;
212+
183213
/// All SSA values that are mapped to. May contain false positives.
184214
DenseSet<Value> mappedTo;
185215
};
186216
} // namespace
187217

188218
ValueVector
189-
ConversionValueMapping::lookupOrDefault(Value from,
190-
TypeRange desiredTypes) const {
219+
ConversionValueMapping::lookupSingleStep(const ValueVector &from,
220+
bool skipMaterializations) const {
221+
// Continue the lookup on each value separately. (Each value could have been
222+
// mapped to one or multiple other values.)
223+
ValueVector next;
224+
for (Value v : from) {
225+
// First check regular value replacements.
226+
auto it = mapping.find({v});
227+
if (it != mapping.end()) {
228+
llvm::append_range(next, it->second);
229+
continue;
230+
}
231+
if (skipMaterializations) {
232+
next.push_back(v);
233+
continue;
234+
}
235+
// Then check materializations.
236+
it = materializations.find({v});
237+
if (it != materializations.end()) {
238+
llvm::append_range(next, it->second);
239+
continue;
240+
}
241+
next.push_back(v);
242+
}
243+
244+
if (next != from)
245+
return next;
246+
247+
// Otherwise: Check if there is a mapping for the entire vector. Such
248+
// mappings are materializations. (N:M mapping are not supported for value
249+
// replacements.)
250+
//
251+
// Note: From a correctness point of view, materializations do not have to
252+
// be stored (and looked up) in the mapping. But for performance reasons,
253+
// we choose to reuse existing IR (when possible) instead of creating it
254+
// multiple times.
255+
//
256+
// First check regular value replacements.
257+
auto it = mapping.find(from);
258+
if (it != mapping.end())
259+
return it->second;
260+
if (skipMaterializations)
261+
return {};
262+
// Then check materializations.
263+
it = materializations.find(from);
264+
if (it != materializations.end())
265+
return it->second;
266+
return {};
267+
}
268+
269+
ValueVector
270+
ConversionValueMapping::lookupOrDefault(Value from, TypeRange desiredTypes,
271+
bool skipMaterializations) const {
191272
// Try to find the deepest values that have the desired types. If there is no
192273
// such mapping, simply return the deepest values.
193274
ValueVector desiredValue;
@@ -197,36 +278,13 @@ ConversionValueMapping::lookupOrDefault(Value from,
197278
if (TypeRange(ValueRange(current)) == desiredTypes)
198279
desiredValue = current;
199280

200-
// If possible, Replace each value with (one or multiple) mapped values.
201-
ValueVector next;
202-
for (Value v : current) {
203-
auto it = mapping.find({v});
204-
if (it != mapping.end()) {
205-
llvm::append_range(next, it->second);
206-
} else {
207-
next.push_back(v);
208-
}
209-
}
210-
if (next != current) {
211-
// If at least one value was replaced, continue the lookup from there.
212-
current = std::move(next);
213-
continue;
214-
}
215-
216-
// Otherwise: Check if there is a mapping for the entire vector. Such
217-
// mappings are materializations. (N:M mapping are not supported for value
218-
// replacements.)
219-
//
220-
// Note: From a correctness point of view, materializations do not have to
221-
// be stored (and looked up) in the mapping. But for performance reasons,
222-
// we choose to reuse existing IR (when possible) instead of creating it
223-
// multiple times.
224-
auto it = mapping.find(current);
225-
if (it == mapping.end()) {
281+
ValueVector next = lookupSingleStep(current, skipMaterializations);
282+
if (next.empty()) {
226283
// No mapping found: The lookup stops here.
227284
break;
228285
}
229-
current = it->second;
286+
287+
current = std::move(next);
230288
} while (true);
231289

232290
// If the desired values were found use them, otherwise default to the leaf
@@ -930,7 +988,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
930988
/// recently mapped values.
931989
/// - If there is no mapping for the given values at all, return the given
932990
/// value.
933-
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
991+
///
992+
/// If `skipMaterializations` is true, materializations are not considered.
993+
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
994+
bool skipMaterializations = false) const;
934995

935996
/// Lookup the given value within the map, or return an empty vector if the
936997
/// value is not mapped. If it is mapped, this follows the same behavior
@@ -993,11 +1054,18 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
9931054
/// If `valuesToMap` is set to a non-null Value, then that value is mapped to
9941055
/// the results of the unresolved materialization in the conversion value
9951056
/// mapping.
1057+
///
1058+
/// If `isOnlyTypeConversion` is "true", the materialization is created to
1059+
/// resolve a type mismatch, and not a regular value replacement issued by
1060+
/// the user. (Replacement values that are created "out of thin air" are
1061+
/// treated appear like unresolved materializations, but are not just type
1062+
/// conversions.)
9961063
ValueRange buildUnresolvedMaterialization(
9971064
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
9981065
ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
9991066
Type originalType, const TypeConverter *converter,
1000-
UnrealizedConversionCastOp *castOp = nullptr);
1067+
UnrealizedConversionCastOp *castOp = nullptr,
1068+
bool isOnlyTypeConversion = true);
10011069

10021070
/// Find a replacement value for the given SSA value in the conversion value
10031071
/// mapping. The replacement value must have the same type as the given SSA
@@ -1264,10 +1332,9 @@ void ConversionPatternRewriterImpl::applyRewrites() {
12641332
// State Management
12651333
//===----------------------------------------------------------------------===//
12661334

1267-
ValueVector
1268-
ConversionPatternRewriterImpl::lookupOrDefault(Value from,
1269-
TypeRange desiredTypes) const {
1270-
return mapping.lookupOrDefault(from, desiredTypes);
1335+
ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
1336+
Value from, TypeRange desiredTypes, bool skipMaterializations) const {
1337+
return mapping.lookupOrDefault(from, desiredTypes, skipMaterializations);
12711338
}
12721339

12731340
ValueVector
@@ -1324,10 +1391,13 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
13241391
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
13251392

13261393
if (!currentTypeConverter) {
1327-
// The current pattern does not have a type converter. I.e., it does not
1328-
// distinguish between legal and illegal types. For each operand, simply
1329-
// pass through the most recently mapped values.
1330-
remapped.push_back(lookupOrDefault(operand));
1394+
// The current pattern does not have a type converter. Pass the most
1395+
// recently mapped values, excluding materializations. Materializations
1396+
// are intentionally excluded because their presence may depend on other
1397+
// patterns. Including materializations would make the lookup fragile
1398+
// and unpredictable.
1399+
remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
1400+
/*skipMaterializations=*/true));
13311401
continue;
13321402
}
13331403

@@ -1356,7 +1426,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
13561426
}
13571427

13581428
// Create a materialization for the most recently mapped values.
1359-
repl = lookupOrDefault(operand);
1429+
repl = lookupOrDefault(operand, /*desiredTypes=*/{},
1430+
/*skipMaterializations=*/true);
13601431
ValueRange castValues = buildUnresolvedMaterialization(
13611432
MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
13621433
/*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
@@ -1482,7 +1553,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
14821553
OpBuilder::InsertPoint(newBlock, newBlock->begin()),
14831554
origArg.getLoc(),
14841555
/*valuesToMap=*/{}, /*inputs=*/ValueRange(),
1485-
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter)
1556+
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
1557+
/*castOp=*/nullptr, /*isOnlyTypeConversion=*/false)
14861558
.front();
14871559
replaceUsesOfBlockArgument(origArg, mat, converter);
14881560
continue;
@@ -1523,7 +1595,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
15231595
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
15241596
ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
15251597
Type originalType, const TypeConverter *converter,
1526-
UnrealizedConversionCastOp *castOp) {
1598+
UnrealizedConversionCastOp *castOp, bool isOnlyTypeConversion) {
15271599
assert((!originalType || kind == MaterializationKind::Target) &&
15281600
"original type is valid only for target materializations");
15291601
assert(TypeRange(inputs) != outputTypes &&
@@ -1536,7 +1608,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
15361608
auto convertOp =
15371609
UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
15381610
if (!valuesToMap.empty())
1539-
mapping.map(valuesToMap, convertOp.getResults());
1611+
mapping.map(valuesToMap, convertOp.getResults(), isOnlyTypeConversion);
15401612
if (castOp)
15411613
*castOp = convertOp;
15421614
unresolvedMaterializations[convertOp] =
@@ -1650,7 +1722,8 @@ void ConversionPatternRewriterImpl::replaceOp(
16501722
MaterializationKind::Source, computeInsertPoint(result),
16511723
result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
16521724
/*outputTypes=*/result.getType(), /*originalType=*/Type(),
1653-
currentTypeConverter);
1725+
currentTypeConverter, /*castOp=*/nullptr,
1726+
/*isOnlyTypeConversion=*/false);
16541727
continue;
16551728
}
16561729

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,3 +415,20 @@ func.func @test_multiple_1_to_n_replacement() {
415415
%0 = "test.multiple_1_to_n_replacement"() : () -> (f16)
416416
"test.invalid"(%0) : (f16) -> ()
417417
}
418+
419+
// -----
420+
421+
// CHECK-LABEL: func @test_lookup_without_converter
422+
// CHECK: %[[producer:.*]] = "test.valid_producer"() : () -> i16
423+
// CHECK: %[[cast:.*]] = "test.cast"(%[[producer]]) : (i16) -> f64
424+
// CHECK: "test.valid_consumer"(%[[cast]]) : (f64) -> ()
425+
// CHECK: "test.valid_consumer"(%[[producer]]) : (i16) -> ()
426+
func.func @test_lookup_without_converter() {
427+
%0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64)
428+
"test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> ()
429+
// Make sure that the second "replace_with_valid_consumer" lowering does not
430+
// lookup the materialization that was created for the above op.
431+
"test.replace_with_valid_consumer"(%0) : (i64) -> ()
432+
// expected-remark@+1 {{op 'func.return' is not legalizable}}
433+
return
434+
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2104,6 +2104,10 @@ def TestInvalidOp : TEST_Op<"invalid", [Terminator]>,
21042104
Arguments<(ins Variadic<AnyType>)>;
21052105
def TestTypeProducerOp : TEST_Op<"type_producer">,
21062106
Results<(outs AnyType)>;
2107+
def TestValidProducerOp : TEST_Op<"valid_producer">,
2108+
Results<(outs AnyType)>;
2109+
def TestValidConsumerOp : TEST_Op<"valid_consumer">,
2110+
Arguments<(ins AnyType)>;
21072111
def TestAnotherTypeProducerOp : TEST_Op<"another_type_producer">,
21082112
Results<(outs AnyType)>;
21092113
def TestTypeConsumerOp : TEST_Op<"type_consumer">,

0 commit comments

Comments
 (0)