
Commit d4a3058

shelkesagar29 and Copybara Bot authored
Move internal changes (#380)
This PR moves the following internal changes to the OSS repo.

## [plan][transforms] Fix an issue in the shape materialization pass

This change fixes an issue in the `SimplifyExtractOfReshape` pattern of the shape materialization pass. The pattern was being applied even when the reshape op's operand is dynamically shaped. However, with a dynamic operand, mapping the extract index onto the reshape operand does not work. With this change, we return failure if the reshape op's operand is dynamic. An MLIR test is added for the scenario in which the pattern should return failure.

## [tensorrt] Fix incorrect handling of dynamic shape in `tensorrt-broadcast-elimination`

Fixes an issue where `tensorrt-broadcast-elimination` would improperly handle dynamically shaped tensors when attempting to reshape them. In certain cases (when more than one dynamic dimension is present), the target shape must be explicitly computed in the IR and a dynamic reshape must be created.

Co-authored-by: Copybara Bot <[email protected]>
1 parent 64c780f · commit d4a3058

File tree: 12 files changed, +138 −10 lines

mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/MaterializeShapeCalculations.cpp

Lines changed: 3 additions & 0 deletions

@@ -361,6 +361,9 @@ struct SimplifyExtractOfReshape : public OpRewritePattern<tensor::ExtractOp> {
     if (!reshapeOp)
       return failure();
 
+    if (!reshapeOp.getOperand().getType().hasStaticShape())
+      return failure();
+
     std::optional<SmallVector<int64_t>> coords =
         getConstantIntValues(getAsOpFoldResult(op.getIndices()));
     if (!coords)
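Why the dynamic-operand guard is needed: folding `tensor.extract` through a reshape means remapping the extract coordinates from the result shape onto the operand shape, which requires row-major strides for both shapes, and a dynamic extent makes those strides unknowable at pattern-application time. Below is a minimal standalone sketch of that remapping, not the pass's actual helper; the function name `mapExtractIndexThroughReshape` and the plain `std::vector` interface are invented for illustration.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Stand-in for a dynamic ('?') extent, mirroring ShapedType::kDynamic.
constexpr int64_t kDynamic = -1;

// Map coordinates of an element in the reshape result back to coordinates in
// the reshape operand: linearize against the result shape, then delinearize
// against the operand shape (both row-major). If either shape has a dynamic
// extent the strides are unknown, so we bail out -- the same reason
// SimplifyExtractOfReshape now returns failure() for dynamic operands.
std::optional<std::vector<int64_t>>
mapExtractIndexThroughReshape(const std::vector<int64_t> &resultShape,
                              const std::vector<int64_t> &operandShape,
                              const std::vector<int64_t> &resultCoords) {
  for (int64_t d : resultShape)
    if (d == kDynamic)
      return std::nullopt;
  for (int64_t d : operandShape)
    if (d == kDynamic)
      return std::nullopt;

  // Linearize the coordinates in the result shape.
  int64_t linear = 0;
  for (size_t i = 0; i < resultShape.size(); ++i)
    linear = linear * resultShape[i] + resultCoords[i];

  // Delinearize the offset into the operand shape.
  std::vector<int64_t> operandCoords(operandShape.size());
  for (size_t i = operandShape.size(); i-- > 0;) {
    operandCoords[i] = linear % operandShape[i];
    linear /= operandShape[i];
  }
  return operandCoords;
}
```

For the `tensor<1x6x4xf32>` result in the new negative test, extracting at `[0, 1, 2]` linearizes to offset 6; if the operand's dynamic dimension happened to be 2 at runtime (shape `1x2x3x4`), that offset would delinearize to `[0, 0, 1, 2]`. Since that extent is unknown when the pattern runs, the sketch (like the pattern) gives up instead.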

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/BroadcastElimination.cpp

Lines changed: 64 additions & 2 deletions

@@ -120,6 +120,65 @@ struct PushDownBroadcastReduceRankOp : public OpRewritePattern<CollapseRankOp> {
 };
 } // namespace
 
+static Value expandRank(RewriterBase &rewriter, Location loc,
+                        TypedValue<RankedTensorType> input,
+                        ArrayRef<int64_t> reorderedBroadcastDims,
+                        RankedTensorType resultType) {
+  RankedTensorType inputType = input.getType();
+  // For <= 1 dynamic dims, no need to do dynamic reshape.
+  if (input.getType().getNumDynamicDims() <= 1) {
+    SmallVector<int64_t> staticShape(resultType.getRank());
+
+    unsigned inputIdx = 0;
+    for (unsigned i = 0, e = staticShape.size(); i < e; i++) {
+      if (inputIdx < reorderedBroadcastDims.size() &&
+          i == reorderedBroadcastDims[inputIdx]) {
+        staticShape[i] = inputType.getDimSize(inputIdx++);
+        continue;
+      }
+      staticShape[i] = 1;
+    }
+    return rewriter.create<ReshapeOp>(loc, resultType.clone(staticShape),
+                                      input);
+  }
+
+  // Otherwise, we need to do dynamic reshape.
+  auto shape = rewriter.create<tensorrt::ShapeOp>(loc, input);
+  SmallVector<Value> shapeComponents(resultType.getRank());
+  SmallVector<int64_t> staticShape(resultType.getRank());
+  unsigned inputIdx = 0;
+  for (unsigned i = 0, e = shapeComponents.size(); i < e; i++) {
+    if (inputIdx < reorderedBroadcastDims.size() &&
+        i == reorderedBroadcastDims[inputIdx]) {
+      if (!inputType.isDynamicDim(inputIdx)) {
+        staticShape[i] = inputType.getDimSize(inputIdx);
+        shapeComponents[i] = rewriter.create<tensorrt::ConstantOp>(
+            loc, rewriter.getI32TensorAttr(
+                     {static_cast<int32_t>(inputType.getDimSize(inputIdx++))}));
+        continue;
+      }
+      shapeComponents[i] = rewriter.create<tensorrt::SliceOp>(
+          loc, shape,
+          /*offset=*/ArrayRef<int32_t>{static_cast<int32_t>(inputIdx++)},
+          ArrayRef<int32_t>{1}, ArrayRef<int32_t>{1});
+      staticShape[i] = ShapedType::kDynamic;
+      continue;
+    }
+    staticShape[i] = 1;
+    shapeComponents[i] = rewriter.create<tensorrt::ConstantOp>(
+        loc, rewriter.getI32TensorAttr(
+                 {static_cast<int32_t>(inputType.getDimSize(1))}));
+  }
+  auto newShape = rewriter.create<tensorrt::ConcatenationOp>(
+      loc,
+      RankedTensorType::get(static_cast<int64_t>(shapeComponents.size()),
+                            rewriter.getI32Type()),
+      shapeComponents, /*axis=*/0);
+
+  return rewriter.create<ReshapeOp>(loc, resultType.clone(staticShape), input,
+                                    newShape);
+}
+
 namespace {
 /// Create transpose + expand_rank on the input of a `tensorrt.broadcast` so
 /// that the result has the same rank as the `tensorrt.broadcast` result and the
@@ -157,8 +216,9 @@ struct SimplifyBroadcast : public OpRewritePattern<BroadcastOp> {
       }
       expandedShape[i] = 1;
     }
-    Value expanded = rewriter.create<ExpandRankOp>(
-        loc, resultType.clone(expandedShape), transposeOp);
+
+    Value expanded = expandRank(rewriter, loc, transposeOp,
+                                reorderedBroadcastDims, resultType);
     rewriter.replaceOpWithNewOp<BroadcastOp>(
         op, op.getType(), expanded, op.getShape(),
         llvm::to_vector(llvm::seq<int64_t>(0, resultType.getRank())));
@@ -341,6 +401,8 @@ class BroadcastEliminationPass
     patterns.add<SimplifyBroadcast, ElementwiseAbsorbBroadcast,
                  PushDownBroadcastReduceRankOp, SelectAbsorbBroadcast,
                  MatMulAbsorbBroadcast>(&getContext());
+    tensorrt::ReshapeOp::getCanonicalizationPatterns(patterns,
+                                                     patterns.getContext());
     if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                             std::move(patterns)))) {
       emitError(getOperation()->getLoc())
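The core decision in the new `expandRank` helper is whether a plain static reshape is enough, or whether the target shape has to be materialized in the IR with `tensorrt.shape`, `tensorrt.slice`, and `tensorrt.concatenation` and fed to a dynamic reshape. Per the code's own comment, the static form is used when the (transposed) input has at most one dynamic dimension; with two or more, the extents must be assembled at runtime. The sketch below mirrors only that planning step outside of MLIR; the names `ExpandRankPlan` and `planExpandRank` and the use of `-1` for a dynamic extent are conventions invented for this illustration, not part of the pass.

```cpp
#include <cstdint>
#include <vector>

// Stand-in for a dynamic ('?') extent, mirroring ShapedType::kDynamic.
constexpr int64_t kDynamic = -1;

struct ExpandRankPlan {
  std::vector<int64_t> expandedShape;   // target shape, kDynamic where unknown
  std::vector<bool> needsRuntimeExtent; // positions read via shape + slice
  bool needsDynamicReshape = false;     // true when > 1 input dim is dynamic
};

// Plan the rank expansion: place input dimension j at position
// reorderedBroadcastDims[j] of the result and fill every other position
// with 1. Only when the input has more than one dynamic extent does the
// reshape need an explicit, runtime-computed shape operand.
ExpandRankPlan planExpandRank(const std::vector<int64_t> &inputShape,
                              const std::vector<int64_t> &reorderedBroadcastDims,
                              int64_t resultRank) {
  ExpandRankPlan plan;
  plan.expandedShape.assign(resultRank, 1);
  plan.needsRuntimeExtent.assign(resultRank, false);

  int64_t numDynamic = 0;
  for (int64_t d : inputShape)
    numDynamic += (d == kDynamic);
  plan.needsDynamicReshape = numDynamic > 1;

  size_t inputIdx = 0;
  for (int64_t i = 0; i < resultRank; ++i) {
    if (inputIdx < reorderedBroadcastDims.size() &&
        i == reorderedBroadcastDims[inputIdx]) {
      plan.expandedShape[i] = inputShape[inputIdx];
      plan.needsRuntimeExtent[i] =
          plan.needsDynamicReshape && inputShape[inputIdx] == kDynamic;
      ++inputIdx;
    }
  }
  return plan;
}
```

For the second regression test added below, the transposed input is `tensor<?x1x?xf16>` with reordered broadcast dims `{1, 2, 3}` and result rank 4: the plan is an expanded shape of `1x?x1x?` with runtime extents at positions 1 and 3 and `needsDynamicReshape` set, which matches the `tensorrt.shape`/`tensorrt.slice`/`tensorrt.concatenation`/`tensorrt.reshape` sequence the test checks for.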

mlir-tensorrt/tensorrt/test/Dialect/TensorRT/broadcast-elimination.mlir

Lines changed: 39 additions & 0 deletions

@@ -236,3 +236,42 @@ func.func @broadcast_elim_matmul_vector(%arg0: tensor<?x?x128xf32>, %arg1: tenso
 // CHECK: return %[[v0]] : tensor<?x?x100xf32>
 
 
+// -----
+
+func.func @broadcast_dynamic_expand_shape_regression(%arg0: tensor<?x?x1x1xi1>, %arg1: tensor<?x1xf16>, %arg2: tensor<?x?x256x256xf16>, %arg3: tensor<4xi32>) -> tensor<?x?x256x256xf16> {
+  %0 = tensorrt.broadcast %arg0 broadcast_dims<0, 1, 2, 3> shape(%arg3 : tensor<4xi32>) : tensor<?x?x1x1xi1> to tensor<?x?x256x256xi1>
+  %1 = tensorrt.broadcast %arg1 broadcast_dims<2, 3> shape(%arg3 : tensor<4xi32>) : tensor<?x1xf16> to tensor<?x?x256x256xf16>
+  %2 = tensorrt.select ins(%0, %arg2, %1 : tensor<?x?x256x256xi1>, tensor<?x?x256x256xf16>, tensor<?x?x256x256xf16>)
+       -> tensor<?x?x256x256xf16>
+  return %2 : tensor<?x?x256x256xf16>
+}
+
+// CHECK-LABEL: func.func @broadcast_dynamic_expand_shape_regression
+// CHECK-SAME: (%[[arg0:.+]]: tensor<?x?x1x1xi1>, %[[arg1:.+]]: tensor<?x1xf16>, %[[arg2:.+]]: tensor<?x?x256x256xf16>, %[[arg3:.+]]: tensor<4xi32>) -> tensor<?x?x256x256xf16> {
+// CHECK: %[[v0:.+]] = tensorrt.reshape %[[arg1]] : tensor<?x1xf16> to tensor<1x1x?x1xf16>
+// CHECK: %[[v1:.+]] = tensorrt.select ins(%[[arg0]], %[[arg2]], %[[v0]] : tensor<?x?x1x1xi1>, tensor<?x?x256x256xf16>, tensor<1x1x?x1xf16>) -> tensor<?x?x256x256xf16>
+// CHECK: return %[[v1]] : tensor<?x?x256x256xf16>
+
+// -----
+
+func.func @broadcast_dynamic_expand_shape_regression(%arg0: tensor<?x?x1x1xi1>, %arg1: tensor<?x1x?xf16>, %arg2: tensor<?x?x256x256xf16>, %arg3: tensor<4xi32>) -> tensor<?x?x256x256xf16> {
+  %0 = tensorrt.broadcast %arg0 broadcast_dims<0, 1, 2, 3> shape(%arg3 : tensor<4xi32>) : tensor<?x?x1x1xi1> to tensor<?x?x256x256xi1>
+  %1 = tensorrt.broadcast %arg1 broadcast_dims<3, 2, 1> shape(%arg3 : tensor<4xi32>) : tensor<?x1x?xf16> to tensor<?x?x256x256xf16>
+  %2 = tensorrt.select ins(%0, %arg2, %1 : tensor<?x?x256x256xi1>, tensor<?x?x256x256xf16>, tensor<?x?x256x256xf16>)
+       -> tensor<?x?x256x256xf16>
+  return %2 : tensor<?x?x256x256xf16>
+}
+
+// CHECK: #[[$map:.+]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
+// CHECK: module {
+// CHECK-LABEL: func.func @broadcast_dynamic_expand_shape_regression
+// CHECK-SAME: (%[[arg0:.+]]: tensor<?x?x1x1xi1>, %[[arg1:.+]]: tensor<?x1x?xf16>, %[[arg2:.+]]: tensor<?x?x256x256xf16>, %[[arg3:.+]]: tensor<4xi32>) -> tensor<?x?x256x256xf16> {
+// CHECK: %[[cst_i32:.+]] = tensorrt.constant dense<1> : tensor<1xi32>
+// CHECK: %[[v0:.+]] = tensorrt.transpose {permutation = #[[$map]]} %[[arg1]] : tensor<?x1x?xf16> to tensor<?x1x?xf16>
+// CHECK: %[[v1:.+]] = tensorrt.shape %[[v0]] : tensor<?x1x?xf16> -> tensor<3xi32>
+// CHECK: %[[v2:.+]] = tensorrt.slice %[[v1]][0][1][1] : tensor<3xi32> to tensor<1xi32>
+// CHECK: %[[v3:.+]] = tensorrt.slice %[[v1]][2][1][1] : tensor<3xi32> to tensor<1xi32>
+// CHECK: %[[v4:.+]] = tensorrt.concatenation {axis = 0 : i32} ins(%[[cst_i32]], %[[v2]], %[[cst_i32]], %[[v3]] : tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
+// CHECK: %[[v5:.+]] = tensorrt.reshape %[[v0]] shape(%[[v4]]: tensor<4xi32>) : tensor<?x1x?xf16> to tensor<1x?x1x?xf16>
+// CHECK: %[[v6:.+]] = tensorrt.select ins(%[[arg0]], %[[arg2]], %[[v5]] : tensor<?x?x1x1xi1>, tensor<?x?x256x256xf16>, tensor<1x?x1x?xf16>) -> tensor<?x?x256x256xf16>
+// CHECK: return %[[v6]] : tensor<?x?x256x256xf16>

mlir-tensorrt/test/Dialect/Plan/materialize-shape-calculations.mlir

Lines changed: 24 additions & 0 deletions

@@ -1088,3 +1088,27 @@ func.func @reduce_window_dynamic_input(%arg0: tensor<?x?x?x?xf32> {tensorrt.shap
 // CHECK-DAG: %[[v2:.+]] = arith.maxsi %[[dim]], %[[c0]] : index
 // CHECK-DAG: %[[v3:.+]] = plan.with_shape %[[v1]](%[[v2]], %[[c3]], %[[c512]], %[[c512]]) :
 // CHECK-DAG: return %[[v3]]
+
+// -----
+
+func.func @simplify_extract_of_reshape_negative(%arg0: tensor<1x?x3x4xf32>) -> f32 {
+  %c0 = arith.constant 0: index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %1 = stablehlo.reshape %arg0 : (tensor<1x?x3x4xf32>) -> tensor<1x6x4xf32>
+  %2 = tensor.extract %1[%c0, %c1, %c2] : tensor<1x6x4xf32>
+  return %2 : f32
+}
+
+// CHECK-LABEL: simplify_extract_of_reshape_negative
+// CHECK-SAME: (%[[arg0:.+]]: tensor<1x?x3x4xf32>)
+// CHECK-NEXT: %[[c4:.+]] = arith.constant 4 : index
+// CHECK-NEXT: %[[c3:.+]] = arith.constant 3 : index
+// CHECK-NEXT: %[[c2:.+]] = arith.constant 2 : index
+// CHECK-NEXT: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-NEXT: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-NEXT: %[[dim:.+]] = tensor.dim %[[arg0]], %[[c1]] : tensor<1x?x3x4xf32>
+// CHECK-NEXT: %[[v0:.+]] = plan.with_shape %[[arg0]](%[[c1]], %[[dim]], %[[c3]], %[[c4]])
+// CHECK-NEXT: %[[v1:.+]] = stablehlo.reshape %[[v0]]
+// CHECK-NEXT: %[[extracted:.+]] = tensor.extract %[[v1]][%[[c0]], %[[c1]], %[[c2]]]
+// CHECK-NEXT: return %extracted

mlir-tensorrt/test/models/bert.stablehlo.elided.mlir

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+module @bert attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
   func.func public @main(%arg0: tensor<32x8xi32> {mhlo.layout_mode = "default"}) -> (tensor<32x8x768xf16> {mhlo.layout_mode = "default"}, tensor<32x768xf16> {mhlo.layout_mode = "default"}) {
     %0 = stablehlo.constant dense_resource<__elided__> : tensor<30522x768xf32>
     %1 = stablehlo.constant dense_resource<__elided__> : tensor<512x768xf32>

mlir-tensorrt/test/models/gpt2.stablehlo.bs2.elided.mlir

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-module @jit_generate attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+module @gpt2_bs2 attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
   func.func public @main(%arg0: tensor<2x6xi32> {mhlo.sharding = "{replicated}"}, %arg1: tensor<2x6xi32> {mhlo.sharding = "{replicated}"}) -> (tensor<2x20xi32> {jax.result_info = ""}) {
     %0 = stablehlo.constant dense<0> : tensor<1xi32>
     %1 = stablehlo.constant dense<768> : tensor<i32>

mlir-tensorrt/test/models/gpt2.stablehlo.elided.mlir

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-module @jit_generate attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+module @gpt_bs1 attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
   func.func public @main(%arg0: tensor<1x7xi32> {jax.arg_info = "inputs['attention_mask']", mhlo.sharding = "{replicated}"}, %arg1: tensor<1x7xi32> {jax.arg_info = "inputs['input_ids']", mhlo.sharding = "{replicated}"}) -> (tensor<1x20xi32> {jax.result_info = ""}) {
     %0 = stablehlo.constant dense_resource<__elided__> : tensor<50257x768xf16>
     %1 = stablehlo.constant dense_resource<__elided__> : tensor<1024x768xf16>

mlir-tensorrt/test/models/llama-68m.stablehlo.elided.mlir

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-module @jit_generate attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32, mhlo.use_auto_spmd_partitioning = false} {
+module @llama_68m attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32, mhlo.use_auto_spmd_partitioning = false} {
   func.func @main(%arg0: tensor<1x9xi32> {mhlo.sharding = "{replicated}"}, %arg1: tensor<1x9xi32> {mhlo.sharding = "{replicated}"}) -> tensor<1x20xi32> {
     %0 = stablehlo.constant dense<1.000000e+00> : tensor<1x1x3072xf32>
     %1 = stablehlo.constant dense<-3.40282347E+38> : tensor<1x1x1x20xf32>

mlir-tensorrt/test/models/llama-v2.stablehlo.elided.mlir

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+module @llama_v2 attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
   func.func public @main(%arg0: tensor<1x27xf32> {mhlo.layout_mode = "default"}) -> (tensor<1x27x32000xf32> {mhlo.layout_mode = "default"}) {
     %0 = stablehlo.constant dense_resource<__elided__> : tensor<32000x4096xf16>
     %1 = stablehlo.constant dense_resource<__elided__> : tensor<4096xf16>

mlir-tensorrt/test/models/resnet50.stablehlo.elided.mlir

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+module @resnet50 attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
   func.func public @main(%arg0: tensor<16x3x224x224xf16> {mhlo.layout_mode = "default"}) -> (tensor<16x1000xf16> {mhlo.layout_mode = "default"}) {
     %0 = stablehlo.constant dense_resource<__elided__> : tensor<7x7x3x64xf32>
     %1 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32>
