Skip to content

Commit 8f87f14

Browse files
committed
[compiler][conversion] Fix an issue in stablehlo.slice conversion
This PR fixes an issue in `stablehlo.slice` conversion which popped result type is dynamic. In the converter, we use `tensorrt.slice` op builder that infers result based on `static_size` (output shape is same as size on TensorRT side) and `size` was taken as shape of stablehlo op output. This caused two issues, - `truncateI64ToI32` failed for dynamic output because dynamic dim can't be truncated to i32 without loss. - A new builder needed which could set result to dynamic shape even when size is static. With this change, - `size` is computed as `ceil(limit-start)/stride` (these are stablehlo slice op attributes) and NOT as shape of output. - A new builder is added. MLIR test is added.
1 parent 7fc38c8 commit 8f87f14

File tree

4 files changed

+44
-8
lines changed

4 files changed

+44
-8
lines changed

mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/StablehloToTensorRT.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
4545
#include "stablehlo/dialect/StablehloOps.h"
4646
#include "llvm/ADT/StringExtras.h"
47+
#include "llvm/Support/Debug.h"
48+
#include <cmath>
4749
#include <functional>
4850
#include <numeric>
4951
#include <regex>
@@ -1947,19 +1949,25 @@ struct HloSliceConverter : public ConvertHloOpToTensorRTPattern<HloOpType> {
19471949
if (failed(startIndices))
19481950
return rewriter.notifyMatchFailure(
19491951
op, "could not convert i64 offsets to i32");
1952+
FailureOr<SmallVector<int32_t>> limitIndices =
1953+
truncateI64ToI32(loc, op.getLimitIndices());
1954+
if (failed(limitIndices))
1955+
return rewriter.notifyMatchFailure(
1956+
op, "could not convert i64 offsets to i32");
19501957
FailureOr<SmallVector<int32_t>> strides =
19511958
truncateI64ToI32(loc, op.getStrides());
19521959
if (failed(strides))
19531960
return rewriter.notifyMatchFailure(op,
19541961
"could not convert i64 stride to i32");
1955-
FailureOr<SmallVector<int32_t>> i32Shape =
1956-
truncateI64ToI32(loc, op.getType().getShape());
1957-
if (failed(i32Shape))
1958-
return rewriter.notifyMatchFailure(op,
1959-
"could not convert i64 shape to i32");
1962+
1963+
SmallVector<int32_t> i32Shape(limitIndices->size());
1964+
for (size_t i = 0; i < limitIndices->size(); i++) {
1965+
i32Shape[i] = std::ceil((((*limitIndices)[i] - (*startIndices)[i])) /
1966+
static_cast<float>((*strides)[i]));
1967+
}
19601968
auto sliceOp = trtRewriter.checkAndCreate<mlir::tensorrt::SliceOp>(
1961-
op.getLoc(), targetTrtMajorVersion, adaptor.getOperand(), *startIndices,
1962-
*i32Shape, *strides);
1969+
op.getLoc(), targetTrtMajorVersion, op.getType(), adaptor.getOperand(),
1970+
*startIndices, i32Shape, *strides);
19631971
if (!sliceOp)
19641972
return failure();
19651973

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1628,6 +1628,11 @@ def TensorRT_SliceOp : TensorRT_Op<"slice", [
16281628
"ArrayRef<int32_t>":$size, "ArrayRef<int32_t>":$stride,
16291629
CArg<"SliceMode", "SliceMode::kDEFAULT">:$sliceMode,
16301630
CArg<"Value", "Value()">:$fill)>,
1631+
// Same as above but result shape is provided and not inferred.
1632+
OpBuilder<(ins "Type":$result, "Value":$input, "ArrayRef<int32_t>":$start,
1633+
"ArrayRef<int32_t>":$size, "ArrayRef<int32_t>":$stride,
1634+
CArg<"SliceMode", "SliceMode::kDEFAULT">:$sliceMode,
1635+
CArg<"Value", "Value()">:$fill)>,
16311636
// Builder using static array for start/stride and Value for size.
16321637
OpBuilder<(ins "Value":$input, "ArrayRef<int32_t>":$start,
16331638
"Value":$size, "ArrayRef<int32_t>":$stride,

mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,6 +1136,17 @@ void tensorrt::SliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
11361136
toArrayAttr(size), toArrayAttr(stride), sliceMode, fill);
11371137
}
11381138

1139+
void tensorrt::SliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1140+
Type result, Value input, ArrayRef<int32_t> start,
1141+
ArrayRef<int32_t> size, ArrayRef<int32_t> stride,
1142+
SliceMode sliceMode, Value fill) {
1143+
auto toArrayAttr = [&](ArrayRef<int32_t> arr) {
1144+
return OpFoldResult(DenseI32ArrayAttr::get(odsBuilder.getContext(), arr));
1145+
};
1146+
SliceOp::build(odsBuilder, odsState, result, input, toArrayAttr(start),
1147+
toArrayAttr(size), toArrayAttr(stride), sliceMode, fill);
1148+
}
1149+
11391150
void tensorrt::SliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
11401151
Value input, ArrayRef<int32_t> start, Value size,
11411152
ArrayRef<int32_t> stride, SliceMode sliceMode,

mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt.mlir

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1927,4 +1927,16 @@ func.func @jnp_cumsum_2d_f16(%arg0: tensor<1x134xf16>) -> tensor<1x134xf16> {
19271927
// CHECK-SAME: post_padding = array<i64: 0, 0>
19281928
// CHECK-SAME: pre_padding = array<i64: 0, 133>
19291929
// CHECK-SAME: in(%[[v1]] : tensor<1x1x1x134xf16>) kernel(%[[v2]] : tensor<1x1x1x134xf16>) -> tensor<1x1x1x134xf16>
1930-
// CHECK: %[[v4:.+]] = tensorrt.reshape %[[v3]] : tensor<1x1x1x134xf16> to tensor<1x134xf16>
1930+
// CHECK: %[[v4:.+]] = tensorrt.reshape %[[v3]] : tensor<1x1x1x134xf16> to tensor<1x134xf16>
1931+
1932+
// -----
1933+
1934+
func.func @slice_conversion_dynamic(%arg0: tensor<1x?x256xf16>) -> tensor<1x?x256xf16>{
1935+
%16 = "stablehlo.slice"(%arg0) <{limit_indices = array<i64: 1, 6, 256>, start_indices = array<i64: 0, 2, 0>, strides = array<i64: 1, 1, 1>}> : (tensor<1x?x256xf16>) -> tensor<1x?x256xf16>
1936+
return %16: tensor<1x?x256xf16>
1937+
}
1938+
1939+
// CHECK-LABEL: @slice_conversion_dynamic
1940+
// CHECK-SAME: (%[[arg0:.+]]: tensor<1x?x256xf16>) -> tensor<1x?x256xf16>
1941+
// CHECK-NEXT: %[[v0:.+]] = tensorrt.slice %[[arg0]][0, 2, 0][1, 4, 256][1, 1, 1] : tensor<1x?x256xf16> to tensor<1x?x256xf16>
1942+
// CHECK-NEXT: return %[[v0]] : tensor<1x?x256xf16>

0 commit comments

Comments
 (0)