Skip to content

Re-enable torch-adjust-calling-conventions tests #4034

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 40 additions & 16 deletions lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,27 +164,51 @@ class AdjustCallingConventionForReturn
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

SmallVector<Value> newOperands;
for (auto operand : adaptor.getOperands()) {
if (!operand)
continue;
if (isa<Torch::NoneType>(operand.getType()))
continue;
if (auto tuple = dyn_cast<Torch::TupleType>(operand.getType())) {
Location loc = op.getLoc();
for (auto en : llvm::enumerate(tuple.getContainedTypes())) {
auto i = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(en.index()));
newOperands.push_back(
rewriter.create<PrimTupleIndexOp>(loc, en.value(), operand, i));
for (const auto &vals : adaptor.getOperands()) {
if (vals.size() == 1) {
if (isa<Torch::NoneType>(vals[0].getType()))
continue;
newOperands.push_back(vals[0]);
} else if (vals.size() > 1) {
// The dialect conversion framework inserts unrealized conversion casts
// to materialize legal types from illegal types. For example, for input
// IR like
// %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor,
// torch.tensor -> !torch.tuple<tensor, tensor>
// return %1 : !torch.tuple<tensor, tensor>
// at this stage in the conversion process we'll have something like
// %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor,
// !torch.tensor -> !torch.tuple<tensor, tensor>
// %2 = builtin.unrealized_conversion_cast %1 :
// !torch.tuple<tensor, tensor> to !torch.tensor
// %3 = builtin.unrealized_conversion_cast %1 :
// !torch.tuple<tensor, tensor> to !torch.tensor
// return %2, %3 : !torch.tensor, !torch.tensor
//
// Given (%2, %3) as operands, here we map back to the original
// torch.prim.TupleConstruct.
if (vals[0].getDefiningOp() &&
isa<mlir::UnrealizedConversionCastOp>(vals[0].getDefiningOp())) {
Value operand = vals[0].getDefiningOp()->getOperand(0);
if (auto tuple = dyn_cast<Torch::TupleType>(operand.getType())) {
Location loc = op.getLoc();
for (auto en : llvm::enumerate(tuple.getContainedTypes())) {
auto i = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(en.index()));
newOperands.push_back(rewriter.create<PrimTupleIndexOp>(
loc, en.value(), operand, i));
}
continue;
}
}
continue;

llvm::append_range(newOperands, vals);
}
newOperands.push_back(operand);
}

rewriter.replaceOpWithNewOp<func::ReturnOp>(op, newOperands);
return success();
}
Expand Down
131 changes: 66 additions & 65 deletions test/Dialect/Torch/adjust-calling-conventions.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@ func.func @basic(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?]
return %arg0 : !torch.tensor
}

// -----

// CHECK-LABEL: func.func @no_type_bound(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: return %[[ARG]] : !torch.tensor
func.func @no_type_bound(%arg0: !torch.tensor) -> !torch.tensor {
return %arg0 : !torch.tensor
}

// -----

// CHECK-LABEL: func.func @call(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor {
// CHECK: %[[ARG_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor
Expand All @@ -29,71 +33,68 @@ func.func @call(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?],
return %arg0 : !torch.tensor
}

// COM: func.func @none_return() {
// COM: %[[NONE:.*]] = torch.constant.none
// COM: return
// func.func @none_return() -> !torch.none {
// %1 = torch.constant.none
// return %1 : !torch.none
// }
// -----

// CHECK-LABEL: func.func @none_return() {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: return
func.func @none_return() -> !torch.none {
%1 = torch.constant.none
return %1 : !torch.none
}

// CHECK-LABEL: func.func @none_call_return() {
// CHECK: call @none_return() : () -> ()
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: "test.use"(%[[NONE]]) : (!torch.none) -> ()
// CHECK: return
func.func @none_call_return() {
%0 = call @none_return() : () -> !torch.none
"test.use"(%0) : (!torch.none) -> ()
return
}

// COM: func.func @none_call_return() {
// COM: call @none_return() : () -> ()
// COM: %[[NONE:.*]] = torch.constant.none
// COM: "test.use"(%[[NONE]]) : (!torch.none) -> ()
// COM: return
// func.func @none_call_return() {
// %0 = call @none_return() : () -> !torch.none
// "test.use"(%0) : (!torch.none) -> ()
// return
// }
// -----

// COM: func.func @tuple_return(
// COM: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
// COM: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
// COM: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
// COM: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor
// COM: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
// COM: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
// COM: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0_NONVAL]], %[[ARG1_NONVAL]] :
// COM: !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
// COM: %[[CST0:.*]] = torch.constant.int 0
// COM: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] :
// COM: !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
// COM: %[[CST1:.*]] = torch.constant.int 1
// COM: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] :
// COM: !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
// COM: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
// func.func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
// %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<tensor, tensor> {
// %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
// return %1 : !torch.tuple<tensor, tensor>
// }
// CHECK-LABEL: func.func @tuple_return(
// CHECK: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
// CHECK: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
// CHECK: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
// CHECK: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor
// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0_NONVAL]], %[[ARG1_NONVAL]] : !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] : !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] : !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
func.func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
%arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<tensor, tensor> {
%1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
return %1 : !torch.tuple<tensor, tensor>
}

// COM: func.func @call_tuple_return(
// COM: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
// COM: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
// COM: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
// COM: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor
// COM: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
// COM: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
// COM: %[[ARG0_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG0_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32>
// COM: %[[ARG0_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG0_NONVAL_SHAPED]] : !torch.vtensor<[?],f32>
// COM: %[[ARG1_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG1_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32>
// COM: %[[ARG1_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG1_NONVAL_SHAPED]] : !torch.vtensor<[?],f32>
// COM: %[[RETS:.*]]:2 = call @tuple_return(%[[ARG0_VAL_SHAPED]], %[[ARG1_VAL_SHAPED]]) :
// COM: (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor)
// COM: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[RETS]]#0, %[[RETS]]#1 :
// COM: !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
// COM: %[[CST0:.*]] = torch.constant.int 0
// COM: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] :
// COM: !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
// COM: %[[CST1:.*]] = torch.constant.int 1
// COM: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] :
// COM: !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
// COM: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
// func.func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
// %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<tensor, tensor> {
// %0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple<tensor, tensor>
// return %0 : !torch.tuple<tensor, tensor>
// }
// CHECK-LABEL: func.func @call_tuple_return(
// CHECK: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
// CHECK: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
// CHECK: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
// CHECK: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor
// CHECK: %[[ARG0_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG0_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32>
// CHECK: %[[ARG0_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG0_NONVAL_SHAPED]] : !torch.vtensor<[?],f32>
// CHECK: %[[ARG1_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG1_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32>
// CHECK: %[[ARG1_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG1_NONVAL_SHAPED]] : !torch.vtensor<[?],f32>
// CHECK: %[[RETS:.*]]:2 = call @tuple_return(%[[ARG0_VAL_SHAPED]], %[[ARG1_VAL_SHAPED]]) : (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor)
// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[RETS]]#0, %[[RETS]]#1 : !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] : !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] : !torch.tuple<tensor, tensor>, !torch.int -> !torch.tensor
// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
func.func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
%arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<tensor, tensor> {
%0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple<tensor, tensor>
return %0 : !torch.tuple<tensor, tensor>
}
Loading