
Commit c2180a6

Copybara Bot authored and christopherbate committed
Move internal changes
GitOrigin-RevId: 99e6260ee0c2679e1e6539e60b7840d9a90845b9
1 parent ee30299 · commit c2180a6

File tree

2 files changed: +23 −1 lines changed


mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeElimination.cpp

Lines changed: 4 additions & 1 deletion
@@ -44,7 +44,10 @@ namespace tensorrt {
 using namespace mlir;
 using namespace mlir::tensorrt;
 
-static int64_t memoryCost(TensorType type) {
+static int64_t memoryCost(RankedTensorType type) {
+  // If the type is dynamic, then return max.
+  if (!type.hasStaticShape())
+    return std::numeric_limits<int64_t>::max();
   ArrayRef<int64_t> shape = type.getShape();
   return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>());
 }
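
The new guard makes any tensor without a fully static shape maximally expensive, presumably so the cost heuristic never folds the dynamic-dimension sentinel into the element-count product. The snippet below is a minimal standalone sketch, not code from this commit: the includes, the main() driver, the int64_t{1} accumulator seed, and the use of ShapedType::kDynamic (the dynamic-dimension marker in recent MLIR) are illustrative assumptions.

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <limits>
    #include <numeric>

    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/IR/MLIRContext.h"

    using namespace mlir;

    // Same shape of logic as the patched helper: a tensor with any dynamic
    // dimension is treated as maximally expensive; a fully static tensor
    // costs its element count.
    static int64_t memoryCost(RankedTensorType type) {
      if (!type.hasStaticShape())
        return std::numeric_limits<int64_t>::max();
      ArrayRef<int64_t> shape = type.getShape();
      return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                             std::multiplies<>());
    }

    int main() {
      MLIRContext ctx;
      auto f32 = Float32Type::get(&ctx);
      // Static 2x80x80 tensor: cost = 2 * 80 * 80 = 12800.
      auto staticType = RankedTensorType::get({2, 80, 80}, f32);
      // tensor<?x80x80xf32>: leading dimension is dynamic, so cost saturates.
      auto dynamicType =
          RankedTensorType::get({ShapedType::kDynamic, 80, 80}, f32);
      std::cout << "static cost:  " << memoryCost(staticType) << "\n";
      std::cout << "dynamic cost: " << memoryCost(dynamicType) << "\n";
      return 0;
    }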

mlir-tensorrt/tensorrt/test/Dialect/TensorRT/transpose-elimination.mlir

Lines changed: 19 additions & 0 deletions
@@ -51,6 +51,25 @@ func.func @transpose_pushdown_noop(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32
 
 // -----
 
+#map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+
+func.func @tranpose_pushdown_dynamic(%arg0: tensor<?x80x80xf32>) -> tensor<?x80x80xf32> {
+  %cst_f32 = tensorrt.constant dense<1.000000e+00> : tensor<1x1x1xf32>
+  %1 = tensorrt.transpose {permutation = #map} %arg0 : tensor<?x80x80xf32> to tensor<?x80x80xf32>
+  %2 = tensorrt.element_wise <kSUB>(%cst_f32, %1 : tensor<1x1x1xf32>, tensor<?x80x80xf32>) -> tensor<?x80x80xf32>
+  return %2 : tensor<?x80x80xf32>
+}
+
+// CHECK: #[[$map:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+// CHECK-LABEL: func.func @tranpose_pushdown_dynamic
+// CHECK-SAME: (%[[arg0:.+]]: tensor<?x80x80xf32>) -> tensor<?x80x80xf32
+// CHECK-DAG: %[[cst_f32:.+]] = tensorrt.constant dense<1.000000e+00> : tensor<1x1x1xf32>
+// CHECK-DAG: %[[v0:.+]] = tensorrt.element_wise <kSUB>(%[[cst_f32]], %[[arg0]] : tensor<1x1x1xf32>, tensor<?x80x80xf32>) -> tensor<?x80x80xf32>
+// CHECK-DAG: %[[v1:.+]] = tensorrt.transpose {permutation = #[[$map]]} %[[v0]] : tensor<?x80x80xf32> to tensor<?x80x80xf32>
+// CHECK-DAG: return %[[v1]] : tensor<?x80x80xf32>
+
+// -----
+
 func.func @transpose_pushdown_switch(%arg0: tensor<2x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2xf32> {
   %1 = tensorrt.transpose {permutation = affine_map<(d0, d1)->(d1, d0)>} %arg0 : tensor<2x2xf32> to tensor<2x2xf32>
   %2 = tensorrt.element_wise <kSUM> (%1, %arg1: tensor<2x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
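
Reading the CHECK-DAG lines together, the new test expects the transpose to be pushed below the element-wise subtraction even though the leading dimension of %arg0 is dynamic and the other operand is a broadcast 1x1x1 constant. Reconstructed from those patterns (the SSA value names %0 and %1 are illustrative, not taken from the commit), the IR after the pass should look roughly like:

    #map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>

    func.func @tranpose_pushdown_dynamic(%arg0: tensor<?x80x80xf32>) -> tensor<?x80x80xf32> {
      %cst_f32 = tensorrt.constant dense<1.000000e+00> : tensor<1x1x1xf32>
      %0 = tensorrt.element_wise <kSUB>(%cst_f32, %arg0 : tensor<1x1x1xf32>, tensor<?x80x80xf32>) -> tensor<?x80x80xf32>
      %1 = tensorrt.transpose {permutation = #map} %0 : tensor<?x80x80xf32> to tensor<?x80x80xf32>
      return %1 : tensor<?x80x80xf32>
    }

Since the 1x1x1 constant broadcasts the same value to every element, subtracting before or after the transpose produces identical results, so the rewrite is value-preserving.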
