@@ -51,6 +51,25 @@ func.func @transpose_pushdown_noop(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32
5151
5252// -----
5353
54+ #map = affine_map <(d0 , d1 , d2 ) -> (d0 , d2 , d1 )>
55+
56+ func.func @tranpose_pushdown_dynamic (%arg0: tensor <?x80 x80 xf32 >) -> tensor <?x80 x80 xf32 > {
57+ %cst_f32 = tensorrt.constant dense <1.000000e+00 > : tensor <1 x1 x1 xf32 >
58+ %1 = tensorrt.transpose {permutation = #map } %arg0 : tensor <?x80 x80 xf32 > to tensor <?x80 x80 xf32 >
59+ %2 = tensorrt.element_wise <kSUB >(%cst_f32 , %1 : tensor <1 x1 x1 xf32 >, tensor <?x80 x80 xf32 >) -> tensor <?x80 x80 xf32 >
60+ return %2 : tensor <?x80 x80 xf32 >
61+ }
62+
63+ // CHECK: #[[$map:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
64+ // CHECK-LABEL: func.func @tranpose_pushdown_dynamic
65+ // CHECK-SAME: (%[[arg0:.+]]: tensor<?x80x80xf32>) -> tensor<?x80x80xf32
66+ // CHECK-DAG: %[[cst_f32:.+]] = tensorrt.constant dense<1.000000e+00> : tensor<1x1x1xf32>
67+ // CHECK-DAG: %[[v0:.+]] = tensorrt.element_wise <kSUB>(%[[cst_f32]], %[[arg0]] : tensor<1x1x1xf32>, tensor<?x80x80xf32>) -> tensor<?x80x80xf32>
68+ // CHECK-DAG: %[[v1:.+]] = tensorrt.transpose {permutation = #[[$map]]} %[[v0]] : tensor<?x80x80xf32> to tensor<?x80x80xf32>
69+ // CHECK-DAG: return %[[v1]] : tensor<?x80x80xf32>
70+
71+ // -----
72+
5473func.func @transpose_pushdown_switch (%arg0: tensor <2 x2 xf32 >, %arg1: tensor <1 x2 xf32 >) -> tensor <2 x2 xf32 > {
5574 %1 = tensorrt.transpose {permutation = affine_map <(d0 , d1 )->(d1 , d0 )>} %arg0 : tensor <2 x2 xf32 > to tensor <2 x2 xf32 >
5675 %2 = tensorrt.element_wise <kSUM > (%1 , %arg1: tensor <2 x2 xf32 >, tensor <1 x2 xf32 >) -> tensor <2 x2 xf32 >
0 commit comments