Skip to content

[flang][cuda] Support data transfer from pointer to a descriptor #114892

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 5, 2024

Conversation

clementval
Copy link
Contributor

@clementval clementval commented Nov 4, 2024

When source is a pointer to an array or a scalar, embox it and use the CUFDataTransferDescDesc or CUFDataTransferGlobalDescDesc entry points. The runtime is already able to deal with all the corner cases like non contiguous arrays and so on so we exploit this.

Memset might still be used for simple case where we want to initialize to 0 for example. This will come in a follow up patch.

@llvmbot llvmbot added flang:runtime flang Flang issues not falling into any other category flang:fir-hlfir labels Nov 4, 2024
@llvmbot
Copy link
Member

llvmbot commented Nov 4, 2024

@llvm/pr-subscribers-flang-runtime

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

When source is a pointer to an array or a scalar, emboss it and use the CUFDataTransferDescDesc or CUFDataTransferGlobalDescDesc entry points. The runtime is already able to deal with all the corner cases like non contiguous arrays and so on so we exploit this.

Memset might still be used for simple case where we want to initialize to 0 for example. This will come in a follow up patch.


Full diff: https://github.com/llvm/llvm-project/pull/114892.diff

4 Files Affected:

  • (modified) flang/include/flang/Runtime/CUDA/memory.h (-5)
  • (modified) flang/lib/Optimizer/Transforms/CUFOpConversion.cpp (+39-12)
  • (modified) flang/runtime/CUDA/memory.cpp (-7)
  • (modified) flang/test/Fir/CUDA/cuda-data-transfer.fir (+52-18)
diff --git a/flang/include/flang/Runtime/CUDA/memory.h b/flang/include/flang/Runtime/CUDA/memory.h
index 51d6b8d4545f09..4ac2528c1aedbc 100644
--- a/flang/include/flang/Runtime/CUDA/memory.h
+++ b/flang/include/flang/Runtime/CUDA/memory.h
@@ -35,11 +35,6 @@ void RTDECL(CUFMemsetDescriptor)(Descriptor *desc, void *value,
 void RTDECL(CUFDataTransferPtrPtr)(void *dst, void *src, std::size_t bytes,
     unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
 
-/// Data transfer from a pointer to a descriptor.
-void RTDECL(CUFDataTransferDescPtr)(Descriptor *dst, void *src,
-    std::size_t bytes, unsigned mode, const char *sourceFile = nullptr,
-    int sourceLine = 0);
-
 /// Data transfer from a descriptor to a pointer.
 void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
     std::size_t bytes, unsigned mode, const char *sourceFile = nullptr,
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index a28d0a562f2f0b..9b8bcf2f719281 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -23,6 +23,7 @@
 #include "flang/Runtime/allocatable.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -439,6 +440,14 @@ static bool isDstGlobal(cuf::DataTransferOp op) {
   return false;
 }
 
+static mlir::Value getShapeFromDecl(mlir::Value src) {
+  if (auto declareOp = src.getDefiningOp<fir::DeclareOp>())
+    return declareOp.getShape();
+  if (auto declareOp = src.getDefiningOp<hlfir::DeclareOp>())
+    return declareOp.getShape();
+  return mlir::Value{};
+}
+
 struct CUFDataTransferOpConversion
     : public mlir::OpRewritePattern<cuf::DataTransferOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -528,22 +537,45 @@ struct CUFDataTransferOpConversion
     }
 
     // Conversion of data transfer involving at least one descriptor.
-    if (mlir::isa<fir::BaseBoxType>(srcTy) &&
-        mlir::isa<fir::BaseBoxType>(dstTy)) {
-      // Transfer between two descriptor.
+    if (mlir::isa<fir::BaseBoxType>(dstTy)) {
+      // Transfer to a descriptor.
       mlir::func::FuncOp func =
           isDstGlobal(op)
               ? fir::runtime::getRuntimeFunc<mkRTKey(
                     CUFDataTransferGlobalDescDesc)>(loc, builder)
               : fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
                     loc, builder);
+      mlir::Value dst = op.getDst();
+      mlir::Value src = op.getSrc();
+
+      if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
+        // If src is not a descriptor, create one.
+        mlir::Value addr;
+        if (fir::isa_trivial(srcTy) &&
+            mlir::matchPattern(op.getSrc().getDefiningOp(),
+                               mlir::m_Constant())) {
+          // Put constant in memory if it is not.
+          mlir::Value alloc = builder.createTemporary(loc, srcTy);
+          builder.create<fir::StoreOp>(loc, op.getSrc(), alloc);
+          addr = alloc;
+        } else {
+          addr = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
+        }
+        mlir::Type boxTy = fir::BoxType::get(srcTy);
+        llvm::SmallVector<mlir::Value> lenParams;
+        mlir::Value box =
+            builder.createBox(loc, boxTy, addr, getShapeFromDecl(src),
+                              /*slice=*/nullptr, lenParams,
+                              /*tdesc=*/nullptr);
+        mlir::Value memBox = builder.createTemporary(loc, box.getType());
+        builder.create<fir::StoreOp>(loc, box, memBox);
+        src = memBox;
+      }
 
       auto fTy = func.getFunctionType();
       mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
       mlir::Value sourceLine =
           fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
-      mlir::Value dst = op.getDst();
-      mlir::Value src = op.getSrc();
       llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
           builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
       builder.create<fir::CallOp>(loc, func, args);
@@ -573,9 +605,7 @@ struct CUFDataTransferOpConversion
       // Type used to compute the width.
       mlir::Type computeType = dstTy;
       auto seqTy = mlir::dyn_cast<fir::SequenceType>(dstTy);
-      bool dstIsDesc = false;
       if (mlir::isa<fir::BaseBoxType>(dstTy)) {
-        dstIsDesc = true;
         computeType = srcTy;
         seqTy = mlir::dyn_cast<fir::SequenceType>(srcTy);
       }
@@ -606,11 +636,8 @@ struct CUFDataTransferOpConversion
           rewriter.create<mlir::arith::MulIOp>(loc, nbElement, widthValue);
 
       mlir::func::FuncOp func =
-          dstIsDesc
-              ? fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescPtr)>(
-                    loc, builder)
-              : fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrDesc)>(
-                    loc, builder);
+          fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrDesc)>(
+              loc, builder);
       auto fTy = func.getFunctionType();
       mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
       mlir::Value sourceLine =
diff --git a/flang/runtime/CUDA/memory.cpp b/flang/runtime/CUDA/memory.cpp
index 0e03c618663ebd..2d499f93fbaece 100644
--- a/flang/runtime/CUDA/memory.cpp
+++ b/flang/runtime/CUDA/memory.cpp
@@ -96,13 +96,6 @@ void RTDEF(CUFDataTransferPtrPtr)(void *dst, void *src, std::size_t bytes,
   CUDA_REPORT_IF_ERROR(cudaMemcpy(dst, src, bytes, kind));
 }
 
-void RTDEF(CUFDataTransferDescPtr)(Descriptor *desc, void *addr,
-    std::size_t bytes, unsigned mode, const char *sourceFile, int sourceLine) {
-  Terminator terminator{sourceFile, sourceLine};
-  terminator.Crash(
-      "not yet implemented: CUDA data transfer from a pointer to a descriptor");
-}
-
 void RTDEF(CUFDataTransferPtrDesc)(void *addr, Descriptor *desc,
     std::size_t bytes, unsigned mode, const char *sourceFile, int sourceLine) {
   Terminator terminator{sourceFile, sourceLine};
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index a760650d143583..6a33190168024f 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -29,13 +29,16 @@ func.func @_QPsub2() {
 }
 
 // CHECK-LABEL: func.func @_QPsub2()
+// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<i32>
 // CHECK: %[[TEMP:.*]] = fir.alloca i32
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub2Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
 // CHECK: %[[C2:.*]] = arith.constant 2 : i32
 // CHECK: fir.store %[[C2]] to %[[TEMP]] : !fir.ref<i32>
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[TEMP]] : (!fir.ref<i32>) -> !fir.box<i32>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<i32>>
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[TEMP_CONV:.*]] = fir.convert %[[TEMP]] : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
-// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[TEMP_CONV]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
+// CHECK: %[[TEMP_CONV:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<i32>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[TEMP_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 
 func.func @_QPsub3() {
   %0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
@@ -48,12 +51,15 @@ func.func @_QPsub3() {
 }
 
 // CHECK-LABEL: func.func @_QPsub3()
+// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<i32>
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub3Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
 // CHECK: %[[V:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFsub3Ev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[V]]#0 : (!fir.ref<i32>) -> !fir.box<i32>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<i32>>
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[V_CONV:.*]] = fir.convert %[[V]]#0 : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
-// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[V_CONV]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
-
+// CHECK: %[[V_CONV:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<i32>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[V_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
+  
 func.func @_QPsub4() {
   %0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub4Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
   %4:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub4Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
@@ -67,15 +73,14 @@ func.func @_QPsub4() {
   return
 }
 // CHECK-LABEL: func.func @_QPsub4()
+// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<10xi32>>
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub4Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
-// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFsub4Eahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
-// CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
-// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
-// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
+// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%[[AHOST_SHAPE:.*]]) {uniq_name = "_QFsub4Eahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#0(%[[AHOST_SHAPE]]) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xi32>>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<!fir.array<10xi32>>>
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
-// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 // CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
 // CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
 // CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
@@ -110,16 +115,15 @@ func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
 }
 
 // CHECK-LABEL: func.func @_QPsub5
+// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<?x?xi32>>
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub5Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>)
 // CHECK: %[[SHAPE:.*]] = fir.shape %[[I1:.*]], %[[I2:.*]] : (index, index) -> !fir.shape<2>
 // CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%[[SHAPE]]) {uniq_name = "_QFsub5Eahost"} : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> (!fir.box<!fir.array<?x?xi32>>, !fir.ref<!fir.array<?x?xi32>>)
-// CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
-// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
-// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#1(%[[SHAPE]]) : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xi32>>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<!fir.array<?x?xi32>>>
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#1 : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.llvm_ptr<i8>
-// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<?x?xi32>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 // CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
 // CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
 // CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
@@ -248,5 +252,35 @@ func.func @_QQdesc_global() attributes {fir.bindc_name = "host_sub"} {
 // CHECK: %[[BOX_NONE:.*]] = fir.convert %[[GLOBAL_DECL:.*]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: fir.call @_FortranACUFDataTransferGlobalDescDesc(%[[BOX_NONE]],{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 
+fir.global @_QMmod2Eadev {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?xi32>>> {
+  %c0 = arith.constant 0 : index
+  %0 = fir.zero_bits !fir.heap<!fir.array<?xi32>>
+  %1 = fir.shape %c0 : (index) -> !fir.shape<1>
+  %2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
+  fir.has_value %2 : !fir.box<!fir.heap<!fir.array<?xi32>>>
+}
+func.func @_QPdesc_global_ptr() {
+  %c10 = arith.constant 10 : index
+  %0 = fir.address_of(@_QMmod2Eadev) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  %1 = fir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod2Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  %2 = fir.alloca !fir.array<10xi32> {bindc_name = "ahost", uniq_name = "_QFdesc_global_ptrEahost"}
+  %3 = fir.shape %c10 : (index) -> !fir.shape<1>
+  %4 = fir.declare %2(%3) {uniq_name = "_QFdesc_global_ptrEahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
+  cuf.data_transfer %4 to %1 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  return
+}
+
+// CHECK-LABEL: func.func @_QPdesc_global_ptr()
+// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<10xi32>>
+// CHECK: %[[ADDR_ADEV:.*]] = fir.address_of(@_QMmod2Eadev) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+// CHECK: %[[DECL_ADEV:.*]] = fir.declare %[[ADDR_ADEV]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod2Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+// CHECK: %[[AHOST:.*]] = fir.alloca !fir.array<10xi32> {bindc_name = "ahost", uniq_name = "_QFdesc_global_ptrEahost"}
+// CHECK: %[[SHAPE:.*]] = fir.shape %c10 : (index) -> !fir.shape<1>
+// CHECK: %[[DECL_AHOST:.*]] = fir.declare %[[AHOST]](%[[SHAPE]]) {uniq_name = "_QFdesc_global_ptrEahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[DECL_AHOST]](%[[SHAPE]]) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xi32>>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<!fir.array<10xi32>>>
+// CHECK: %[[ADEV_BOXNONE:.*]] = fir.convert %[[DECL_ADEV]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: %[[AHOST_BOXNONE:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferGlobalDescDesc(%[[ADEV_BOXNONE]], %[[AHOST_BOXNONE]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 
 } // end of module

@llvmbot
Copy link
Member

llvmbot commented Nov 4, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

When source is a pointer to an array or a scalar, emboss it and use the CUFDataTransferDescDesc or CUFDataTransferGlobalDescDesc entry points. The runtime is already able to deal with all the corner cases like non contiguous arrays and so on so we exploit this.

Memset might still be used for simple case where we want to initialize to 0 for example. This will come in a follow up patch.


Full diff: https://github.com/llvm/llvm-project/pull/114892.diff

4 Files Affected:

  • (modified) flang/include/flang/Runtime/CUDA/memory.h (-5)
  • (modified) flang/lib/Optimizer/Transforms/CUFOpConversion.cpp (+39-12)
  • (modified) flang/runtime/CUDA/memory.cpp (-7)
  • (modified) flang/test/Fir/CUDA/cuda-data-transfer.fir (+52-18)
diff --git a/flang/include/flang/Runtime/CUDA/memory.h b/flang/include/flang/Runtime/CUDA/memory.h
index 51d6b8d4545f09..4ac2528c1aedbc 100644
--- a/flang/include/flang/Runtime/CUDA/memory.h
+++ b/flang/include/flang/Runtime/CUDA/memory.h
@@ -35,11 +35,6 @@ void RTDECL(CUFMemsetDescriptor)(Descriptor *desc, void *value,
 void RTDECL(CUFDataTransferPtrPtr)(void *dst, void *src, std::size_t bytes,
     unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
 
-/// Data transfer from a pointer to a descriptor.
-void RTDECL(CUFDataTransferDescPtr)(Descriptor *dst, void *src,
-    std::size_t bytes, unsigned mode, const char *sourceFile = nullptr,
-    int sourceLine = 0);
-
 /// Data transfer from a descriptor to a pointer.
 void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
     std::size_t bytes, unsigned mode, const char *sourceFile = nullptr,
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index a28d0a562f2f0b..9b8bcf2f719281 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -23,6 +23,7 @@
 #include "flang/Runtime/allocatable.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -439,6 +440,14 @@ static bool isDstGlobal(cuf::DataTransferOp op) {
   return false;
 }
 
+static mlir::Value getShapeFromDecl(mlir::Value src) {
+  if (auto declareOp = src.getDefiningOp<fir::DeclareOp>())
+    return declareOp.getShape();
+  if (auto declareOp = src.getDefiningOp<hlfir::DeclareOp>())
+    return declareOp.getShape();
+  return mlir::Value{};
+}
+
 struct CUFDataTransferOpConversion
     : public mlir::OpRewritePattern<cuf::DataTransferOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -528,22 +537,45 @@ struct CUFDataTransferOpConversion
     }
 
     // Conversion of data transfer involving at least one descriptor.
-    if (mlir::isa<fir::BaseBoxType>(srcTy) &&
-        mlir::isa<fir::BaseBoxType>(dstTy)) {
-      // Transfer between two descriptor.
+    if (mlir::isa<fir::BaseBoxType>(dstTy)) {
+      // Transfer to a descriptor.
       mlir::func::FuncOp func =
           isDstGlobal(op)
               ? fir::runtime::getRuntimeFunc<mkRTKey(
                     CUFDataTransferGlobalDescDesc)>(loc, builder)
               : fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
                     loc, builder);
+      mlir::Value dst = op.getDst();
+      mlir::Value src = op.getSrc();
+
+      if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
+        // If src is not a descriptor, create one.
+        mlir::Value addr;
+        if (fir::isa_trivial(srcTy) &&
+            mlir::matchPattern(op.getSrc().getDefiningOp(),
+                               mlir::m_Constant())) {
+          // Put constant in memory if it is not.
+          mlir::Value alloc = builder.createTemporary(loc, srcTy);
+          builder.create<fir::StoreOp>(loc, op.getSrc(), alloc);
+          addr = alloc;
+        } else {
+          addr = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
+        }
+        mlir::Type boxTy = fir::BoxType::get(srcTy);
+        llvm::SmallVector<mlir::Value> lenParams;
+        mlir::Value box =
+            builder.createBox(loc, boxTy, addr, getShapeFromDecl(src),
+                              /*slice=*/nullptr, lenParams,
+                              /*tdesc=*/nullptr);
+        mlir::Value memBox = builder.createTemporary(loc, box.getType());
+        builder.create<fir::StoreOp>(loc, box, memBox);
+        src = memBox;
+      }
 
       auto fTy = func.getFunctionType();
       mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
       mlir::Value sourceLine =
           fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
-      mlir::Value dst = op.getDst();
-      mlir::Value src = op.getSrc();
       llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
           builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
       builder.create<fir::CallOp>(loc, func, args);
@@ -573,9 +605,7 @@ struct CUFDataTransferOpConversion
       // Type used to compute the width.
       mlir::Type computeType = dstTy;
       auto seqTy = mlir::dyn_cast<fir::SequenceType>(dstTy);
-      bool dstIsDesc = false;
       if (mlir::isa<fir::BaseBoxType>(dstTy)) {
-        dstIsDesc = true;
         computeType = srcTy;
         seqTy = mlir::dyn_cast<fir::SequenceType>(srcTy);
       }
@@ -606,11 +636,8 @@ struct CUFDataTransferOpConversion
           rewriter.create<mlir::arith::MulIOp>(loc, nbElement, widthValue);
 
       mlir::func::FuncOp func =
-          dstIsDesc
-              ? fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescPtr)>(
-                    loc, builder)
-              : fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrDesc)>(
-                    loc, builder);
+          fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrDesc)>(
+              loc, builder);
       auto fTy = func.getFunctionType();
       mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
       mlir::Value sourceLine =
diff --git a/flang/runtime/CUDA/memory.cpp b/flang/runtime/CUDA/memory.cpp
index 0e03c618663ebd..2d499f93fbaece 100644
--- a/flang/runtime/CUDA/memory.cpp
+++ b/flang/runtime/CUDA/memory.cpp
@@ -96,13 +96,6 @@ void RTDEF(CUFDataTransferPtrPtr)(void *dst, void *src, std::size_t bytes,
   CUDA_REPORT_IF_ERROR(cudaMemcpy(dst, src, bytes, kind));
 }
 
-void RTDEF(CUFDataTransferDescPtr)(Descriptor *desc, void *addr,
-    std::size_t bytes, unsigned mode, const char *sourceFile, int sourceLine) {
-  Terminator terminator{sourceFile, sourceLine};
-  terminator.Crash(
-      "not yet implemented: CUDA data transfer from a pointer to a descriptor");
-}
-
 void RTDEF(CUFDataTransferPtrDesc)(void *addr, Descriptor *desc,
     std::size_t bytes, unsigned mode, const char *sourceFile, int sourceLine) {
   Terminator terminator{sourceFile, sourceLine};
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index a760650d143583..6a33190168024f 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -29,13 +29,16 @@ func.func @_QPsub2() {
 }
 
 // CHECK-LABEL: func.func @_QPsub2()
+// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<i32>
 // CHECK: %[[TEMP:.*]] = fir.alloca i32
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub2Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
 // CHECK: %[[C2:.*]] = arith.constant 2 : i32
 // CHECK: fir.store %[[C2]] to %[[TEMP]] : !fir.ref<i32>
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[TEMP]] : (!fir.ref<i32>) -> !fir.box<i32>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<i32>>
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[TEMP_CONV:.*]] = fir.convert %[[TEMP]] : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
-// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[TEMP_CONV]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
+// CHECK: %[[TEMP_CONV:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<i32>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[TEMP_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 
 func.func @_QPsub3() {
   %0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
@@ -48,12 +51,15 @@ func.func @_QPsub3() {
 }
 
 // CHECK-LABEL: func.func @_QPsub3()
+// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<i32>
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub3Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
 // CHECK: %[[V:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFsub3Ev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[V]]#0 : (!fir.ref<i32>) -> !fir.box<i32>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<i32>>
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[V_CONV:.*]] = fir.convert %[[V]]#0 : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
-// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[V_CONV]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
-
+// CHECK: %[[V_CONV:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<i32>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[V_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
+  
 func.func @_QPsub4() {
   %0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub4Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
   %4:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub4Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
@@ -67,15 +73,14 @@ func.func @_QPsub4() {
   return
 }
 // CHECK-LABEL: func.func @_QPsub4()
+// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<10xi32>>
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub4Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
-// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFsub4Eahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
-// CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
-// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
-// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
+// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%[[AHOST_SHAPE:.*]]) {uniq_name = "_QFsub4Eahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#0(%[[AHOST_SHAPE]]) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xi32>>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<!fir.array<10xi32>>>
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
-// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 // CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
 // CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
 // CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
@@ -110,16 +115,15 @@ func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
 }
 
 // CHECK-LABEL: func.func @_QPsub5
+// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<?x?xi32>>
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub5Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>)
 // CHECK: %[[SHAPE:.*]] = fir.shape %[[I1:.*]], %[[I2:.*]] : (index, index) -> !fir.shape<2>
 // CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%[[SHAPE]]) {uniq_name = "_QFsub5Eahost"} : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> (!fir.box<!fir.array<?x?xi32>>, !fir.ref<!fir.array<?x?xi32>>)
-// CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
-// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
-// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#1(%[[SHAPE]]) : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xi32>>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<!fir.array<?x?xi32>>>
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#1 : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.llvm_ptr<i8>
-// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<?x?xi32>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 // CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
 // CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
 // CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
@@ -248,5 +252,35 @@ func.func @_QQdesc_global() attributes {fir.bindc_name = "host_sub"} {
 // CHECK: %[[BOX_NONE:.*]] = fir.convert %[[GLOBAL_DECL:.*]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: fir.call @_FortranACUFDataTransferGlobalDescDesc(%[[BOX_NONE]],{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 
+fir.global @_QMmod2Eadev {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?xi32>>> {
+  %c0 = arith.constant 0 : index
+  %0 = fir.zero_bits !fir.heap<!fir.array<?xi32>>
+  %1 = fir.shape %c0 : (index) -> !fir.shape<1>
+  %2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
+  fir.has_value %2 : !fir.box<!fir.heap<!fir.array<?xi32>>>
+}
+func.func @_QPdesc_global_ptr() {
+  %c10 = arith.constant 10 : index
+  %0 = fir.address_of(@_QMmod2Eadev) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  %1 = fir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod2Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  %2 = fir.alloca !fir.array<10xi32> {bindc_name = "ahost", uniq_name = "_QFdesc_global_ptrEahost"}
+  %3 = fir.shape %c10 : (index) -> !fir.shape<1>
+  %4 = fir.declare %2(%3) {uniq_name = "_QFdesc_global_ptrEahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
+  cuf.data_transfer %4 to %1 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  return
+}
+
+// CHECK-LABEL: func.func @_QPdesc_global_ptr()
+// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<10xi32>>
+// CHECK: %[[ADDR_ADEV:.*]] = fir.address_of(@_QMmod2Eadev) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+// CHECK: %[[DECL_ADEV:.*]] = fir.declare %[[ADDR_ADEV]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod2Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+// CHECK: %[[AHOST:.*]] = fir.alloca !fir.array<10xi32> {bindc_name = "ahost", uniq_name = "_QFdesc_global_ptrEahost"}
+// CHECK: %[[SHAPE:.*]] = fir.shape %c10 : (index) -> !fir.shape<1>
+// CHECK: %[[DECL_AHOST:.*]] = fir.declare %[[AHOST]](%[[SHAPE]]) {uniq_name = "_QFdesc_global_ptrEahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[DECL_AHOST]](%[[SHAPE]]) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xi32>>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<!fir.array<10xi32>>>
+// CHECK: %[[ADEV_BOXNONE:.*]] = fir.convert %[[DECL_ADEV]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: %[[AHOST_BOXNONE:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferGlobalDescDesc(%[[ADEV_BOXNONE]], %[[AHOST_BOXNONE]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 
 } // end of module

@clementval clementval merged commit 652db7e into llvm:main Nov 5, 2024
6 of 8 checks passed
@clementval clementval deleted the cuf_data_transfer_desc_ptr branch November 5, 2024 17:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:runtime flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants