Skip to content

Commit 81831ef

Browse files
authored
[flang][cuda] Correctly allocate descriptor in managed memory when reboxing (#120795)
Reboxing might create a new in memory descriptor. If this one was allocate with managed memory, allocate the new one in managed memory as well.
1 parent f82bb3d commit 81831ef

File tree

2 files changed

+78
-3
lines changed

2 files changed

+78
-3
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1725,13 +1725,17 @@ struct EmboxOpConversion : public EmboxCommonConversion<fir::EmboxOp> {
17251725
};
17261726

17271727
static bool isDeviceAllocation(mlir::Value val) {
1728+
if (auto loadOp = mlir::dyn_cast_or_null<fir::LoadOp>(val.getDefiningOp()))
1729+
return isDeviceAllocation(loadOp.getMemref());
17281730
if (auto convertOp =
17291731
mlir::dyn_cast_or_null<fir::ConvertOp>(val.getDefiningOp()))
17301732
val = convertOp.getValue();
17311733
if (auto callOp = mlir::dyn_cast_or_null<fir::CallOp>(val.getDefiningOp()))
17321734
if (callOp.getCallee() &&
1733-
callOp.getCallee().value().getRootReference().getValue().starts_with(
1734-
RTNAME_STRING(CUFMemAlloc)))
1735+
(callOp.getCallee().value().getRootReference().getValue().starts_with(
1736+
RTNAME_STRING(CUFMemAlloc)) ||
1737+
callOp.getCallee().value().getRootReference().getValue().starts_with(
1738+
RTNAME_STRING(CUFAllocDesciptor))))
17351739
return true;
17361740
return false;
17371741
}
@@ -2045,7 +2049,8 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
20452049
}
20462050
dest = insertBaseAddress(rewriter, loc, dest, base);
20472051
mlir::Value result =
2048-
placeInMemoryIfNotGlobalInit(rewriter, rebox.getLoc(), destBoxTy, dest);
2052+
placeInMemoryIfNotGlobalInit(rewriter, rebox.getLoc(), destBoxTy, dest,
2053+
isDeviceAllocation(rebox.getBox()));
20492054
rewriter.replaceOp(rebox, result);
20502055
return mlir::success();
20512056
}

flang/test/Fir/CUDA/cuda-code-gen.mlir

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,73 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<f80 = dense<128> : vector<2xi64>
5656
// CHECK-LABEL: llvm.func @_QQmain()
5757
// CHECK: llvm.call @_FortranACUFMemAlloc
5858
// CHECK: llvm.call @_FortranACUFAllocDesciptor
59+
60+
// -----
61+
62+
module attributes {dlti.dl_spec = #dlti.dl_spec<f80 = dense<128> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, i64 = dense<64> : vector<2xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, !llvm.ptr<270> = dense<32> : vector<4xi64>, f128 = dense<128> : vector<2xi64>, f64 = dense<64> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i1 = dense<8> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, "dlti.endianness" = "little", "dlti.stack_alignment" = 128 : i64>} {
63+
func.func @_QQmain() attributes {fir.bindc_name = "p1"} {
64+
%c1_i32 = arith.constant 1 : i32
65+
%c0_i32 = arith.constant 0 : i32
66+
%c16_i32 = arith.constant 16 : i32
67+
%c1 = arith.constant 1 : index
68+
%c0 = arith.constant 0 : index
69+
%0 = fir.alloca i32 {bindc_name = "iblk", uniq_name = "_QFEiblk"}
70+
%1 = fir.alloca i32 {bindc_name = "ithr", uniq_name = "_QFEithr"}
71+
%2 = fir.address_of(@_QQclX64756D6D792E6D6C697200) : !fir.ref<!fir.char<1,11>>
72+
%c14_i32 = arith.constant 14 : i32
73+
%c72 = arith.constant 72 : index
74+
%3 = fir.convert %c72 : (index) -> i64
75+
%4 = fir.convert %2 : (!fir.ref<!fir.char<1,11>>) -> !fir.ref<i8>
76+
%5 = fir.call @_FortranACUFAllocDesciptor(%3, %4, %c14_i32) : (i64, !fir.ref<i8>, i32) -> !fir.ref<!fir.box<none>>
77+
%6 = fir.convert %5 : (!fir.ref<!fir.box<none>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
78+
%7 = fir.zero_bits !fir.heap<!fir.array<?x?xf32>>
79+
%8 = fircg.ext_embox %7(%c0, %c0) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?x?xf32>>, index, index) -> !fir.box<!fir.heap<!fir.array<?x?xf32>>>
80+
fir.store %8 to %6 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
81+
%9 = fir.address_of(@_QQclX64756D6D792E6D6C697200) : !fir.ref<!fir.char<1,11>>
82+
%c20_i32 = arith.constant 20 : i32
83+
%c48 = arith.constant 48 : index
84+
%10 = fir.convert %c48 : (index) -> i64
85+
%11 = fir.convert %9 : (!fir.ref<!fir.char<1,11>>) -> !fir.ref<i8>
86+
%12 = fir.call @_FortranACUFAllocDesciptor(%10, %11, %c20_i32) : (i64, !fir.ref<i8>, i32) -> !fir.ref<!fir.box<none>>
87+
%13 = fir.convert %12 : (!fir.ref<!fir.box<none>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
88+
%14 = fir.zero_bits !fir.heap<!fir.array<?xf32>>
89+
%15 = fircg.ext_embox %14(%c0) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?xf32>>, index) -> !fir.box<!fir.heap<!fir.array<?xf32>>>
90+
fir.store %15 to %13 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
91+
%16 = fir.convert %6 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<none>>
92+
%17 = fir.convert %c1 : (index) -> i64
93+
%18 = fir.convert %c16_i32 : (i32) -> i64
94+
%19 = fir.call @_FortranAAllocatableSetBounds(%16, %c0_i32, %17, %18) fastmath<contract> : (!fir.ref<!fir.box<none>>, i32, i64, i64) -> none
95+
%20 = fir.call @_FortranAAllocatableSetBounds(%16, %c1_i32, %17, %18) fastmath<contract> : (!fir.ref<!fir.box<none>>, i32, i64, i64) -> none
96+
%21 = fir.address_of(@_QQclX64756D6D792E6D6C697200) : !fir.ref<!fir.char<1,11>>
97+
%c31_i32 = arith.constant 31 : i32
98+
%false = arith.constant false
99+
%22 = fir.absent !fir.box<none>
100+
%c-1_i64 = arith.constant -1 : i64
101+
%23 = fir.convert %6 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<none>>
102+
%24 = fir.convert %21 : (!fir.ref<!fir.char<1,11>>) -> !fir.ref<i8>
103+
%25 = fir.call @_FortranACUFAllocatableAllocate(%23, %c-1_i64, %false, %22, %24, %c31_i32) : (!fir.ref<!fir.box<none>>, i64, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
104+
%26 = fir.convert %13 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
105+
%27 = fir.call @_FortranAAllocatableSetBounds(%26, %c0_i32, %17, %18) fastmath<contract> : (!fir.ref<!fir.box<none>>, i32, i64, i64) -> none
106+
%28 = fir.address_of(@_QQclX64756D6D792E6D6C697200) : !fir.ref<!fir.char<1,11>>
107+
%c34_i32 = arith.constant 34 : i32
108+
%false_0 = arith.constant false
109+
%29 = fir.absent !fir.box<none>
110+
%c-1_i64_1 = arith.constant -1 : i64
111+
%30 = fir.convert %13 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
112+
%31 = fir.convert %28 : (!fir.ref<!fir.char<1,11>>) -> !fir.ref<i8>
113+
%32 = fir.call @_FortranACUFAllocatableAllocate(%30, %c-1_i64_1, %false_0, %29, %31, %c34_i32) : (!fir.ref<!fir.box<none>>, i64, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
114+
%33 = fir.load %6 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>
115+
%34 = fircg.ext_rebox %33 : (!fir.box<!fir.heap<!fir.array<?x?xf32>>>) -> !fir.box<!fir.array<?x?xf32>>
116+
return
117+
}
118+
func.func private @_FortranAAllocatableSetBounds(!fir.ref<!fir.box<none>>, i32, i64, i64) -> none attributes {fir.runtime}
119+
fir.global linkonce @_QQclX64756D6D792E6D6C697200 constant : !fir.char<1,11> {
120+
%0 = fir.string_lit "dummy.mlir\00"(11) : !fir.char<1,11>
121+
fir.has_value %0 : !fir.char<1,11>
122+
}
123+
func.func private @_FortranACUFAllocDesciptor(i64, !fir.ref<i8>, i32) -> !fir.ref<!fir.box<none>> attributes {fir.runtime}
124+
func.func private @_FortranACUFAllocatableAllocate(!fir.ref<!fir.box<none>>, i64, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32 attributes {fir.runtime}
125+
}
126+
127+
// CHECK-LABEL: llvm.func @_QQmain()
128+
// CHECK-COUNT-4: llvm.call @_FortranACUFAllocDesciptor

0 commit comments

Comments
 (0)