Skip to content

Commit 702a6b7

Browse files
[compiler] Improve consistency of how 'target' attribute on grouping ops is handled
- Target attribute is now optional on inline_group and inline_closed_group ops. - Add printer/parser tests - Use the attribute interface when declaring the attribute in the ODS. GitOrigin-RevId: 824b46c7379a3e8a4e9a8d33e65ec4146ac0abb7
1 parent 9cf058d commit 702a6b7

File tree

5 files changed

+72
-44
lines changed

5 files changed

+72
-44
lines changed

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanOps.td

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,15 @@ def Plan_InlineGroupOp : Plan_GroupOpBase<"inline_group", [
105105
more specialized operation(s) and/or have the body outlined to a
106106
function-like operation such that the reuslts are replaced with a
107107
call-like operation.
108-
109108
}];
110109

111-
let arguments = (ins CompilerBackendAttrInterface:$target);
110+
let arguments = (ins OptionalAttr<CompilerBackendAttrInterface>:$target);
112111

113112
let results = (outs Variadic<AnyType>:$results);
114113

115114
let assemblyFormat = [{
116-
`target` `(` $target `)` attr-dict-with-keyword (`->` type($results)^)? $body
115+
( `target` `(` $target^ `)` )?
116+
attr-dict-with-keyword (`->` type($results)^)? $body
117117
}];
118118

119119
let hasVerifier = 1;
@@ -236,16 +236,17 @@ def Plan_InlineClosedGroupOp : Plan_InlineClosedGroupBase<"inline_closed_group",
236236
```
237237

238238
}];
239-
let arguments = (ins Variadic<AnyTypeOf<[AnyRankedTensor, AnySignlessIntegerOrIndex]>>:$inputs,
240-
Variadic<AnyRankedTensor>:$outs,
241-
BoundsAttrArray:$input_attrs,
242-
BoundsAttrArray:$res_attrs,
243-
AnyAttr:$target);
239+
let arguments =
240+
(ins Variadic<
241+
AnyTypeOf<[AnyRankedTensor, AnySignlessIntegerOrIndex]>>:$inputs,
242+
Variadic<AnyRankedTensor>:$outs, BoundsAttrArray:$input_attrs,
243+
BoundsAttrArray:$res_attrs,
244+
OptionalAttr<CompilerBackendAttrInterface>:$target);
244245

245246
let results = (outs Variadic<AnyType>:$results);
246247

247248
let assemblyFormat = [{
248-
`target` `(` $target `)` `\n`
249+
(`target` `(` $target^ `)` `\n`)?
249250
`inputs` `(` ( $inputs^ `:` type($inputs) `)` ) : ( `)` ) ? `\n`
250251
`outs` `(` $outs `:` type($outs) `)` `\n`
251252
`in_attrs` $input_attrs `\n`
@@ -257,12 +258,10 @@ def Plan_InlineClosedGroupOp : Plan_InlineClosedGroupBase<"inline_closed_group",
257258

258259
let skipDefaultBuilders = 1;
259260

260-
let builders = [
261-
OpBuilder<(ins "Attribute":$target,
262-
"ValueRange":$inputs, "ValueRange":$outs,
263-
CArg<"ArrayRef<BoundsAttr>", "{}">:$input_attrs,
264-
CArg<"ArrayRef<BoundsAttr>", "{}">:$res_attrs)>
265-
];
261+
let builders = [OpBuilder<(ins "CompilerBackendAttrInterface":$target,
262+
"ValueRange":$inputs, "ValueRange":$outs,
263+
CArg<"ArrayRef<BoundsAttr>", "{}">:$input_attrs,
264+
CArg<"ArrayRef<BoundsAttr>", "{}">:$res_attrs)>];
266265

267266
let extraClassDeclaration = baseInlineClosedExtraClassDeclaration # [{
268267

@@ -326,14 +325,13 @@ def Plan_InlineClosedAllocGroupOp : Plan_InlineClosedGroupBase<"inline_closed_al
326325
}
327326

328327
}];
329-
let arguments = (ins Variadic<AnyType>:$inputs,
330-
BoundsAttrArray:$input_attrs,
331-
AnyAttr:$target);
328+
let arguments = (ins Variadic<AnyType>:$inputs, BoundsAttrArray:$input_attrs,
329+
OptionalAttr<CompilerBackendAttrInterface>:$target);
332330

333331
let results = (outs Variadic<AnyType>:$results);
334332

335333
let assemblyFormat = [{
336-
`target` `(` $target `)` `\n`
334+
(`target` `(` $target^ `)` `\n`)?
337335
`inputs` `(` ( $inputs^ `:` type($inputs) `)` ) : ( `)` ) ? `\n`
338336
`in_attrs` $input_attrs `\n`
339337
attr-dict-with-keyword `->` type($results)
@@ -344,11 +342,10 @@ def Plan_InlineClosedAllocGroupOp : Plan_InlineClosedGroupBase<"inline_closed_al
344342

345343
let skipDefaultBuilders = 1;
346344

347-
let builders = [
348-
OpBuilder<(ins "TypeRange":$results,
349-
"Attribute":$target,
350-
"ValueRange":$inputs,
351-
CArg<"ArrayRef<BoundsAttr>", "{}">:$input_attrs)>,
345+
let builders = [OpBuilder<(ins "TypeRange":$results,
346+
"CompilerBackendAttrInterface":$target,
347+
"ValueRange":$inputs,
348+
CArg<"ArrayRef<BoundsAttr>", "{}">:$input_attrs)>,
352349
];
353350

354351
let extraClassDeclaration = baseInlineClosedExtraClassDeclaration;

mlir-tensorrt/compiler/lib/Dialect/Plan/IR/PlanOps.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,13 +147,14 @@ void InlineClosedGroupOp::getAsmBlockArgumentNames(
147147
}
148148

149149
void InlineClosedGroupOp::build(OpBuilder &b, OperationState &state,
150-
Attribute target, ValueRange inputs,
151-
ValueRange outs,
150+
CompilerBackendAttrInterface target,
151+
ValueRange inputs, ValueRange outs,
152152
ArrayRef<BoundsAttr> input_attrs,
153153
ArrayRef<BoundsAttr> result_attrs) {
154154
state.addOperands(inputs);
155155
state.addOperands(outs);
156-
state.getOrAddProperties<Properties>().target = target;
156+
if (target)
157+
state.getOrAddProperties<Properties>().target = target;
157158
state.getOrAddProperties<Properties>().setInputAttrs(b.getArrayAttr(
158159
SmallVector<Attribute>(input_attrs.begin(), input_attrs.end())));
159160
state.getOrAddProperties<Properties>().setResAttrs(b.getArrayAttr(
@@ -212,12 +213,14 @@ void InlineClosedAllocGroupOp::getAsmBlockArgumentNames(
212213
}
213214

214215
void InlineClosedAllocGroupOp::build(OpBuilder &b, OperationState &state,
215-
TypeRange resultTypes, Attribute target,
216+
TypeRange resultTypes,
217+
CompilerBackendAttrInterface target,
216218
ValueRange inputs,
217219
ArrayRef<BoundsAttr> input_attrs) {
218220
state.addTypes(resultTypes);
219221
state.addOperands(inputs);
220-
state.getOrAddProperties<Properties>().target = target;
222+
if (target)
223+
state.getOrAddProperties<Properties>().target = target;
221224
state.getOrAddProperties<Properties>().setInputAttrs(b.getArrayAttr(
222225
SmallVector<Attribute>(input_attrs.begin(), input_attrs.end())));
223226
Region *body = state.addRegion();

mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CreateClosedRegions.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ static LogicalResult createInlineClosedGroupOp(
436436
const ValueRange &inputs,
437437
ArrayRef<DestinationOperandMaterializationResult> destinationOperands) {
438438
InlineClosedGroupOp closedGroupOp = rewriter.create<InlineClosedGroupOp>(
439-
op.getLoc(), /*target=*/op.getTarget(),
439+
op.getLoc(), /*target=*/op.getTargetAttr(),
440440
/*inputs=*/inputs,
441441
/*outs=*/
442442
llvm::map_to_vector(destinationOperands,
@@ -509,7 +509,7 @@ createInlineClosedAllocGroupOp(RewriterBase &rewriter, plan::InlineGroupOp op,
509509
InlineClosedAllocGroupOp closedGroupOp =
510510
rewriter.create<InlineClosedAllocGroupOp>(
511511
op.getLoc(), /*result type*/ op->getResultTypes(),
512-
/*target=*/op.getTarget(),
512+
/*target=*/op.getTargetAttr(),
513513
/*inputs=*/inputs);
514514

515515
rewriter.inlineBlockBefore(
@@ -543,7 +543,10 @@ createClosedGroupOp(RewriterBase &rewriter, plan::InlineGroupOp op,
543543
bool disableDestinationStyleCallingConvention) {
544544
OpBuilder::InsertionGuard g(rewriter);
545545

546-
CompilerBackendAttrInterface backend = op.getTarget();
546+
CompilerBackendAttrInterface backend = op.getTargetAttr();
547+
if (!backend)
548+
return op.emitError("missing target attribute");
549+
547550
bool useDestinationStyleCallingConvention =
548551
!disableDestinationStyleCallingConvention &&
549552
backend.supportsDestinationStyleCallingConvention(op) &&

mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineClusters.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@ using namespace mlir::plan;
4545

4646
static CompilerBackendAttrInterface getClusterTargetForRegionOp(Operation *op) {
4747
if (auto regionOp = dyn_cast<plan::InlineGroupOp>(op))
48-
return cast<CompilerBackendAttrInterface>(regionOp.getTarget());
48+
return regionOp.getTargetAttr();
4949
if (auto regionOp = dyn_cast<plan::InlineClosedGroupOp>(op))
50-
return cast<CompilerBackendAttrInterface>(regionOp.getTarget());
50+
return regionOp.getTargetAttr();
5151
if (auto regionOp = dyn_cast<plan::InlineClosedAllocGroupOp>(op))
52-
return cast<CompilerBackendAttrInterface>(regionOp.getTarget());
52+
return regionOp.getTargetAttr();
5353
llvm_unreachable("unknown cluster region op kind");
5454
}
5555

@@ -58,6 +58,8 @@ static FailureOr<OutlineRegionOptions>
5858
getOutliningParam(InputKind inputKind, Operation *op,
5959
SymbolTable &moduleSymbolTable) {
6060
CompilerBackendAttrInterface target = getClusterTargetForRegionOp(op);
61+
if (!target)
62+
return op->emitError("missing target attribute");
6163
std::optional<OutlineRegionOptions> opts = target.getClusterOutliningOptions(
6264
inputKind, op->getContext(), moduleSymbolTable);
6365
if (!opts)
@@ -131,7 +133,7 @@ class OutlineClustersPass
131133
for (FunctionOpInterface func : funcs) {
132134
SmallVector<plan::InlineGroupOp> clusters;
133135
func->walk([&](plan::InlineGroupOp clusterOp) {
134-
if (!isa<CompilerBackendAttrInterface>(clusterOp.getTarget()))
136+
if (!clusterOp.getTargetAttr())
135137
return WalkResult::advance();
136138
clusters.push_back(clusterOp);
137139
return WalkResult::skip();

mlir-tensorrt/compiler/test/Dialect/Plan/roundtrip.mlir

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func.func @plan_inline_group(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> te
2323
yield %1 : tensor<10xf32>
2424
}
2525

26-
%1 = plan.inline_group target(#plan.host_backend<benefit = 1>) -> tensor<10xf32> {
26+
%1 = plan.inline_group -> tensor<10xf32> {
2727
%2 = stablehlo.add %0, %0 : tensor<10xf32>
2828
yield %2 : tensor<10xf32>
2929
}
@@ -43,7 +43,7 @@ func.func @plan_inline_group(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> te
4343
// CHECK-NEXT: %[[v2:.+]] = stablehlo.add %[[arg0]], %[[arg1]] : tensor<10xf32>
4444
// CHECK-NEXT: yield %[[v2]] : tensor<10xf32>
4545
// CHECK-NEXT: }
46-
// CHECK-NEXT: %[[v1:.+]] = plan.inline_group target(#plan.host_backend<benefit = 1>) -> tensor<10xf32> {
46+
// CHECK-NEXT: %[[v1:.+]] = plan.inline_group -> tensor<10xf32> {
4747
// CHECK-NEXT: %[[v2:.+]] = stablehlo.add %[[v0]], %[[v0]] : tensor<10xf32>
4848
// CHECK-NEXT: yield %[[v2]] : tensor<10xf32>
4949
// CHECK-NEXT: }
@@ -55,7 +55,8 @@ func.func @plan_inline_group(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> te
5555

5656
// -----
5757

58-
func.func @inline_closed_group(%arg0: tensor<?xf32>, %arg1: index, %arg2: tensor<?xf32>) -> tensor<?xf32> {
58+
func.func @inline_closed_group(%arg0: tensor<?xf32>, %arg1: index, %arg2: tensor<?xf32>)
59+
-> (tensor<?xf32>, tensor<10xf32>) {
5960
%2 = plan.inline_closed_group target(#plan.tensorrt_backend<disallow_shape_tensor_calculations = false, benefit = 1>)
6061
inputs(%arg0, %arg1 : tensor<?xf32>, index)
6162
outs(%arg2 : tensor<?xf32>)
@@ -66,7 +67,16 @@ func.func @inline_closed_group(%arg0: tensor<?xf32>, %arg1: index, %arg2: tensor
6667
%res = stablehlo.exponential %2 : tensor<?xf32>
6768
yield %res : tensor<?xf32>
6869
}
69-
return %2 : tensor<?xf32>
70+
71+
%empty = tensor.empty() : tensor<10xf32>
72+
%3 = plan.inline_closed_group inputs()
73+
outs(%empty : tensor<10xf32>)
74+
in_attrs []
75+
res_attrs [#plan.bounds<none>] -> tensor<10xf32> {
76+
^bb0(%out: tensor<10xf32>):
77+
yield %out : tensor<10xf32>
78+
}
79+
return %2, %3 : tensor<?xf32>, tensor<10xf32>
7080
}
7181

7282
// CHECK-LABEL: func.func @inline_closed_group
@@ -80,11 +90,16 @@ func.func @inline_closed_group(%arg0: tensor<?xf32>, %arg1: index, %arg2: tensor
8090
// CHECK-NEXT: %{{.+}} = stablehlo.exponential %{{.+}} : tensor<?xf32>
8191
// CHECK-NEXT: yield %{{.+}} : tensor<?xf32>
8292
// CHECK-NEXT: }
83-
// CHECK-NEXT: return
93+
// CHECK: plan.inline_closed_group inputs()
94+
// CHECK-NEXT: outs(%{{.+}} : tensor<10xf32>)
95+
// CHECK-NEXT: in_attrs []
96+
// CHECK-NEXT: res_attrs [#plan.bounds<none>] -> tensor<10xf32> {
97+
// CHECK-NEXT: ^bb0(%out{{.*}}: tensor<10xf32>):
98+
// CHECK-NEXT: yield %out{{.*}} : tensor<10xf32>
8499

85100
// -----
86101

87-
func.func @inline_closed_alloc_group(%arg0: tensor<?xf32>, %arg1: index) -> tensor<?xf32> {
102+
func.func @inline_closed_alloc_group(%arg0: tensor<?xf32>, %arg1: index) -> (tensor<?xf32>, tensor<10xf32>) {
88103
%2 = plan.inline_closed_alloc_group target(#plan.host_backend<benefit=1>)
89104
inputs(%arg0, %arg1 : tensor<?xf32>, index)
90105
in_attrs [#plan.bounds<shape, [10], [20]>, #plan.bounds<none>] -> tensor<?xf32> {
@@ -93,7 +108,13 @@ func.func @inline_closed_alloc_group(%arg0: tensor<?xf32>, %arg1: index) -> tens
93108
%res = stablehlo.exponential %2 : tensor<?xf32>
94109
yield %res : tensor<?xf32>
95110
}
96-
return %2 : tensor<?xf32>
111+
%3 = plan.inline_closed_alloc_group inputs()
112+
in_attrs [] -> tensor<10xf32> {
113+
^bb0():
114+
%out = tensor.empty() : tensor<10xf32>
115+
yield %out : tensor<10xf32>
116+
}
117+
return %2, %3 : tensor<?xf32>, tensor<10xf32>
97118
}
98119

99120
// CHECK-LABEL: func.func @inline_closed_alloc_group
@@ -106,7 +127,9 @@ func.func @inline_closed_alloc_group(%arg0: tensor<?xf32>, %arg1: index) -> tens
106127
// CHECK-NEXT: %{{.+}} = stablehlo.exponential %{{.+}} : tensor<?xf32>
107128
// CHECK-NEXT: yield %{{.+}} : tensor<?xf32>
108129
// CHECK-NEXT: }
109-
// CHECK-NEXT: return
130+
// CHECK: plan.inline_closed_alloc_group inputs()
131+
// CHECK-NEXT: in_attrs []
132+
// CHECK-NEXT: -> tensor<10xf32> {
110133

111134
// -----
112135

0 commit comments

Comments
 (0)