diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp index e8b0d6b0364c..3508f1bc059e 100644 --- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp +++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp @@ -164,27 +164,51 @@ class AdjustCallingConventionForReturn public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, + matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector newOperands; - for (auto operand : adaptor.getOperands()) { - if (!operand) - continue; - if (isa(operand.getType())) - continue; - if (auto tuple = dyn_cast(operand.getType())) { - Location loc = op.getLoc(); - for (auto en : llvm::enumerate(tuple.getContainedTypes())) { - auto i = rewriter.create( - loc, rewriter.getI64IntegerAttr(en.index())); - newOperands.push_back( - rewriter.create(loc, en.value(), operand, i)); + for (const auto &vals : adaptor.getOperands()) { + if (vals.size() == 1) { + if (isa(vals[0].getType())) + continue; + newOperands.push_back(vals[0]); + } else if (vals.size() > 1) { + // The dialect conversion framework inserts unrealized conversion casts + // to materialize legal types from illegal types. For example, for input + // IR like + // %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, + // torch.tensor -> !torch.tuple + // return %1 : !torch.tuple + // at this stage in the conversion process we'll have something like + // %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, + // !torch.tensor -> !torch.tuple + // %2 = builtin.unrealized_conversion_cast %1 : + // !torch.tuple to !torch.tensor + // %3 = builtin.unrealized_conversion_cast %1 : + // !torch.tuple to !torch.tensor + // return %2, %3 : !torch.tensor, !torch.tensor + // + // Given (%2, %3) as operands, here we map back to the original + // torch.prim.TupleConstruct. + if (vals[0].getDefiningOp() && + isa(vals[0].getDefiningOp())) { + Value operand = vals[0].getDefiningOp()->getOperand(0); + if (auto tuple = dyn_cast(operand.getType())) { + Location loc = op.getLoc(); + for (auto en : llvm::enumerate(tuple.getContainedTypes())) { + auto i = rewriter.create( + loc, rewriter.getI64IntegerAttr(en.index())); + newOperands.push_back(rewriter.create( + loc, en.value(), operand, i)); + } + continue; + } } - continue; + + llvm::append_range(newOperands, vals); } - newOperands.push_back(operand); } + rewriter.replaceOpWithNewOp(op, newOperands); return success(); } diff --git a/test/Dialect/Torch/adjust-calling-conventions.mlir b/test/Dialect/Torch/adjust-calling-conventions.mlir index 455a8e847486..992b60271327 100644 --- a/test/Dialect/Torch/adjust-calling-conventions.mlir +++ b/test/Dialect/Torch/adjust-calling-conventions.mlir @@ -9,6 +9,8 @@ func.func @basic(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?] return %arg0 : !torch.tensor } +// ----- + // CHECK-LABEL: func.func @no_type_bound( // CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.tensor { // CHECK: return %[[ARG]] : !torch.tensor @@ -16,6 +18,8 @@ func.func @no_type_bound(%arg0: !torch.tensor) -> !torch.tensor { return %arg0 : !torch.tensor } +// ----- + // CHECK-LABEL: func.func @call( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor { // CHECK: %[[ARG_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[2,3,?],f32> to !torch.vtensor @@ -29,71 +33,68 @@ func.func @call(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?], return %arg0 : !torch.tensor } -// COM: func.func @none_return() { -// COM: %[[NONE:.*]] = torch.constant.none -// COM: return -// func.func @none_return() -> !torch.none { -// %1 = torch.constant.none -// return %1 : !torch.none -// } +// ----- + +// CHECK-LABEL: func.func @none_return() { +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: return +func.func @none_return() -> !torch.none { + %1 = torch.constant.none + return %1 : !torch.none +} + +// CHECK-LABEL: func.func @none_call_return() { +// CHECK: call @none_return() : () -> () +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: "test.use"(%[[NONE]]) : (!torch.none) -> () +// CHECK: return +func.func @none_call_return() { + %0 = call @none_return() : () -> !torch.none + "test.use"(%0) : (!torch.none) -> () + return +} -// COM: func.func @none_call_return() { -// COM: call @none_return() : () -> () -// COM: %[[NONE:.*]] = torch.constant.none -// COM: "test.use"(%[[NONE]]) : (!torch.none) -> () -// COM: return -// func.func @none_call_return() { -// %0 = call @none_return() : () -> !torch.none -// "test.use"(%0) : (!torch.none) -> () -// return -// } +// ----- -// COM: func.func @tuple_return( -// COM: %[[ARG0:.*]]: !torch.vtensor<[?],f32>, -// COM: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) { -// COM: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor -// COM: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor -// COM: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor -// COM: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor -// COM: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0_NONVAL]], %[[ARG1_NONVAL]] : -// COM: !torch.tensor, !torch.tensor -> !torch.tuple -// COM: %[[CST0:.*]] = torch.constant.int 0 -// COM: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] : -// COM: !torch.tuple, !torch.int -> !torch.tensor -// COM: %[[CST1:.*]] = torch.constant.int 1 -// COM: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] : -// COM: !torch.tuple, !torch.int -> !torch.tensor -// COM: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor -// func.func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}, -// %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple { -// %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple -// return %1 : !torch.tuple -// } +// CHECK-LABEL: func.func @tuple_return( +// CHECK: %[[ARG0:.*]]: !torch.vtensor<[?],f32>, +// CHECK: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) { +// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor +// CHECK: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor +// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor +// CHECK: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor +// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0_NONVAL]], %[[ARG1_NONVAL]] : !torch.tensor, !torch.tensor -> !torch.tuple +// CHECK: %[[CST0:.*]] = torch.constant.int 0 +// CHECK: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] : !torch.tuple, !torch.int -> !torch.tensor +// CHECK: %[[CST1:.*]] = torch.constant.int 1 +// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] : !torch.tuple, !torch.int -> !torch.tensor +// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor +func.func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}, + %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple { + %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple + return %1 : !torch.tuple +} -// COM: func.func @call_tuple_return( -// COM: %[[ARG0:.*]]: !torch.vtensor<[?],f32>, -// COM: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) { -// COM: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor -// COM: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor -// COM: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor -// COM: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor -// COM: %[[ARG0_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG0_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32> -// COM: %[[ARG0_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG0_NONVAL_SHAPED]] : !torch.vtensor<[?],f32> -// COM: %[[ARG1_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG1_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32> -// COM: %[[ARG1_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG1_NONVAL_SHAPED]] : !torch.vtensor<[?],f32> -// COM: %[[RETS:.*]]:2 = call @tuple_return(%[[ARG0_VAL_SHAPED]], %[[ARG1_VAL_SHAPED]]) : -// COM: (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) -// COM: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[RETS]]#0, %[[RETS]]#1 : -// COM: !torch.tensor, !torch.tensor -> !torch.tuple -// COM: %[[CST0:.*]] = torch.constant.int 0 -// COM: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] : -// COM: !torch.tuple, !torch.int -> !torch.tensor -// COM: %[[CST1:.*]] = torch.constant.int 1 -// COM: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] : -// COM: !torch.tuple, !torch.int -> !torch.tensor -// COM: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor -// func.func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}, -// %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple { -// %0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple -// return %0 : !torch.tuple -// } +// CHECK-LABEL: func.func @call_tuple_return( +// CHECK: %[[ARG0:.*]]: !torch.vtensor<[?],f32>, +// CHECK: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) { +// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor +// CHECK: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor +// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor +// CHECK: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor +// CHECK: %[[ARG0_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG0_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32> +// CHECK: %[[ARG0_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG0_NONVAL_SHAPED]] : !torch.vtensor<[?],f32> +// CHECK: %[[ARG1_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG1_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32> +// CHECK: %[[ARG1_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG1_NONVAL_SHAPED]] : !torch.vtensor<[?],f32> +// CHECK: %[[RETS:.*]]:2 = call @tuple_return(%[[ARG0_VAL_SHAPED]], %[[ARG1_VAL_SHAPED]]) : (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) +// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[RETS]]#0, %[[RETS]]#1 : !torch.tensor, !torch.tensor -> !torch.tuple +// CHECK: %[[CST0:.*]] = torch.constant.int 0 +// CHECK: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] : !torch.tuple, !torch.int -> !torch.tensor +// CHECK: %[[CST1:.*]] = torch.constant.int 1 +// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] : !torch.tuple, !torch.int -> !torch.tensor +// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor +func.func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}, + %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple { + %0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple + return %0 : !torch.tuple +}