Skip to content

Commit 7605e21

Browse files
Copybara Botchristopherbate
authored andcommitted
Move internal changes
GitOrigin-RevId: de640057c4adcfee8deb6a6f06405c595d48ac56
1 parent f3ec678 commit 7605e21

File tree

2 files changed

+31
-9
lines changed

2 files changed

+31
-9
lines changed

mlir-tensorrt/tensorrt/lib/Utils/ShapeUtils.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -109,32 +109,33 @@ tensorrt::getBroadcastedShape(ArrayRef<ArrayRef<int64_t>> shapes) {
109109
// shapes are broadcastable. Don't fail because we can't say for sure it's
110110
// invalid.
111111
const bool allEqual = llvm::all_equal(dimSizes);
112-
if (allEqual && dimSizes.front() == ShapedType::kDynamic)
113-
return ShapedType::kDynamic;
114112

115-
// Dimensions are all equal to a static size.
113+
// Dimensions are all equal to a fixed value or dynamic.
116114
if (allEqual)
117115
return dimSizes.front();
118116

119-
// Some dims are '1', all other dims are equal to another fixed number or
120-
// dynamic.
117+
// Mixture of fixed or unkown extents.
121118
std::optional<int64_t> nonUnitSize{};
122119
for (int64_t dimSize : dimSizes) {
120+
// Extent of 1 is always valid.
123121
if (dimSize == 1)
124122
continue;
123+
// Dynamic extent is always valid.
125124
if (ShapedType::isDynamic(dimSize))
126125
continue;
126+
// If a extent > 1 is present, check that it matches any previously seen
127+
// static >1 extent.
127128
if (nonUnitSize && dimSize == *nonUnitSize)
128129
continue;
129130
if (nonUnitSize && dimSize != *nonUnitSize)
130131
return failure();
131132
nonUnitSize = dimSize;
132133
}
133-
if (nonUnitSize)
134-
return *nonUnitSize;
135134

136-
// No other case is valid.
137-
return failure();
135+
// Return the size >1 is seen, otherwise return dynamic indicator. An
136+
// inferred size of 1 is only possible if all extents are 1; this case is
137+
// captured by the check before the loop.
138+
return nonUnitSize ? *nonUnitSize : ShapedType::kDynamic;
138139
};
139140

140141
for (auto dim : llvm::seq<unsigned>(0, rank)) {

mlir-tensorrt/tensorrt/test/Dialect/TensorRT/invalid.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,27 @@ func.func @trt_select(%arg0: tensor<10x10xi1>, %arg1: tensor<1x10xf32>, %arg2: t
857857

858858
// -----
859859

860+
func.func @valid_select_ds_infer(%arg0: tensor<?x?xi1>, %arg1: tensor<?x?xf16>, %arg2: tensor<1x1xf16>) -> tensor<?x?xf16> {
861+
%0 = tensorrt.select ins(%arg0, %arg1, %arg2 : tensor<?x?xi1>, tensor<?x?xf16>, tensor<1x1xf16>) -> tensor<?x?xf16>
862+
return %0 : tensor<?x?xf16>
863+
}
864+
865+
// -----
866+
867+
func.func @valid_select_ds_infer2(%arg0: tensor<1x?xi1>, %arg1: tensor<1x?xf16>, %arg2: tensor<1x1xf16>) -> tensor<?x?xf16> {
868+
%0 = tensorrt.select ins(%arg0, %arg1, %arg2 : tensor<1x?xi1>, tensor<1x?xf16>, tensor<1x1xf16>) -> tensor<?x?xf16>
869+
return %0 : tensor<?x?xf16>
870+
}
871+
872+
// -----
873+
874+
func.func @valid_select_ds_infer3(%arg0: tensor<1x?xi1>, %arg1: tensor<1x?xf16>, %arg2: tensor<1x1xf16>) -> tensor<1x1xf16> {
875+
%0 = tensorrt.select ins(%arg0, %arg1, %arg2 : tensor<1x?xi1>, tensor<1x?xf16>, tensor<1x1xf16>) -> tensor<1x1xf16>
876+
return %0 : tensor<1x1xf16>
877+
}
878+
879+
// -----
880+
860881
func.func @trt_softmax(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
861882
// expected-error @below {{'tensorrt.softmax' op expected axis to be non-negative and less than 2}}
862883
%0 = tensorrt.softmax {axis = 2 : i64} %arg0 : tensor<10x10xf32>

0 commit comments

Comments
 (0)