// Patch summary: nvvm.cp.async.shared.global is translated to an LLVM
// intrinsic call both with and without the optional cpSize operand.
//
// Files touched (5):
//  * NVVMDialect.h   - adds #include "mlir/Target/LLVMIR/ModuleTranslation.h".
//  * NVVMOps.td      - NVVM_CpAsyncOp changes base from NVVM_PTXBuilder_Op to
//                      NVVM_Op; hasIntrinsic(), getAsmValues() and getPtx()
//                      are removed; a static getIntrinsicIDAndArgs() is
//                      declared; llvmBuilder now calls it and forwards the
//                      returned id and collected operands to
//                      createIntrinsicCall.
//  * NVVMDialect.cpp - adds the CP_ASYNC_ID_IMPL / GET_CP_ASYNC_ID macros and
//                      CpAsyncOp::getIntrinsicIDAndArgs():
//                        - size 4 and size 8 select the `ca` intrinsic;
//                        - size 16 selects `cg` when the modifier equals
//                          LoadCacheModifierKind::CG and `ca` otherwise;
//                        - any other size reaches llvm_unreachable;
//                        - when cpSize is present the `_s`-suffixed intrinsic
//                          is selected instead of the unsuffixed one;
//                        - args receives the translated dst, then src, then
//                          cpSize (only when present).
//  * nvvm-to-llvm.mlir - @async_cp_zfill CHECK lines now expect the
//                      nvvm.cp.async.shared.global op itself instead of
//                      llvm.inline_asm.
//  * nvvmir.mlir     - re-indents the CHECK lines of @cp_async and adds
//                      @async_cp_zfill, which expects calls to
//                      llvm.nvvm.cp.async.{ca.4,ca.8,ca.16,cg.16}.s variants
//                      taking an extra i32 operand.
//
// NOTE(review): GET_CP_ASYNC_ID expands to a bare `a ? b : c` with no
//   surrounding parentheses and is itself used as an operand of another ?: in
//   the size-16 case. It parses as intended there, but wrapping the expansion
//   (and `has_cpsize`) in parentheses would make the macro safe in other
//   contexts. Neither macro is #undef'd in the visible hunk.
// NOTE(review): for sizes 4 and 8 the modifier is not consulted, and for size
//   16 every modifier other than CG falls back to `ca`; the removed llvmBuilder
//   used llvm_unreachable("unsupported cache modifier") in that case.
//   Presumably the op verifier (hasVerifier = 1) rejects those combinations -
//   TODO confirm.
// NOTE(review): the dialect header now includes a Target/LLVMIR header;
//   confirm this dependency direction is acceptable, or whether a forward
//   declaration of LLVM::ModuleTranslation would do.
// NOTE(review): in this copy of the patch the original line breaks are
//   collapsed and template argument lists appear to be missing (`cast(op)`,
//   `llvm::SmallVector &args`, `llvm::SmallVectorImpl> &asmValues`,
//   `EnumAttr;`, `Optional:$cpSize`). Verify against the original patch before
//   applying.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h index 50d1a39126ea3..d474ba8485d5d 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h @@ -21,6 +21,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "llvm/IR/IntrinsicsNVPTX.h" #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.h.inc" diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 797a006708131..8c8e44a054a62 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -849,55 +849,24 @@ def LoadCacheModifierKind : I32EnumAttr<"LoadCacheModifierKind", def LoadCacheModifierAttr : EnumAttr; -def NVVM_CpAsyncOp : NVVM_PTXBuilder_Op<"cp.async.shared.global">, +def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">, Arguments<(ins LLVM_PointerShared:$dst, LLVM_PointerGlobal:$src, I32Attr:$size, LoadCacheModifierAttr:$modifier, Optional:$cpSize)> { - string llvmBuilder = [{ - llvm::Intrinsic::ID id; - switch ($size) { - case 4: - id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_4; - break; - case 8: - id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_8; - break; - case 16: - if($modifier == NVVM::LoadCacheModifierKind::CG) - id = llvm::Intrinsic::nvvm_cp_async_cg_shared_global_16; - else if($modifier == NVVM::LoadCacheModifierKind::CA) - id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_16; - else - llvm_unreachable("unsupported cache modifier"); - break; - default: - llvm_unreachable("unsupported async copy size"); - } - createIntrinsicCall(builder, id, {$dst, $src}); - }]; let assemblyFormat = "$dst `,` $src `,` $size `,` `cache` `=` $modifier (`,` $cpSize^)? 
attr-dict `:` type(operands)"; let hasVerifier = 1; let extraClassDeclaration = [{ - bool hasIntrinsic() { if(getCpSize()) return false; return true; } - - void getAsmValues(RewriterBase &rewriter, - llvm::SmallVectorImpl> &asmValues) { - asmValues.push_back({getDst(), PTXRegisterMod::Read}); - asmValues.push_back({getSrc(), PTXRegisterMod::Read}); - asmValues.push_back({makeConstantI32(rewriter, getSize()), PTXRegisterMod::Read}); - asmValues.push_back({getCpSize(), PTXRegisterMod::Read}); - } + static llvm::Intrinsic::ID + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::SmallVector &args); }]; - let extraClassDefinition = [{ - std::string $cppClass::getPtx() { - if(getModifier() == NVVM::LoadCacheModifierKind::CG) - return std::string("cp.async.cg.shared.global [%0], [%1], %2, %3;\n"); - if(getModifier() == NVVM::LoadCacheModifierKind::CA) - return std::string("cp.async.ca.shared.global [%0], [%1], %2, %3;\n"); - llvm_unreachable("unsupported cache modifier"); - } + string llvmBuilder = [{ + llvm::SmallVector translatedOperands; + auto id = NVVM::CpAsyncOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, translatedOperands); + createIntrinsicCall(builder, id, translatedOperands); }]; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index ccb5ad05f0bf7..dc7e724379ed0 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -1110,6 +1110,44 @@ LogicalResult NVVM::BarrierOp::verify() { return success(); } +#define CP_ASYNC_ID_IMPL(mod, size, suffix) \ + llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix + +#define GET_CP_ASYNC_ID(mod, size, has_cpsize) \ + has_cpsize ? 
CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, ) + +llvm::Intrinsic::ID +CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::SmallVector &args) { + llvm::Intrinsic::ID id; + + auto cpAsyncOp = cast(op); + bool hasCpSize = cpAsyncOp.getCpSize() ? true : false; + switch (cpAsyncOp.getSize()) { + case 4: + id = GET_CP_ASYNC_ID(ca, 4, hasCpSize); + break; + case 8: + id = GET_CP_ASYNC_ID(ca, 8, hasCpSize); + break; + case 16: + id = (cpAsyncOp.getModifier() == NVVM::LoadCacheModifierKind::CG) + ? GET_CP_ASYNC_ID(cg, 16, hasCpSize) + : GET_CP_ASYNC_ID(ca, 16, hasCpSize); + break; + default: + llvm_unreachable("Invalid copy size in CpAsyncOp."); + } + + // Fill the Intrinsic Args + args.push_back(mt.lookupValue(cpAsyncOp.getDst())); + args.push_back(mt.lookupValue(cpAsyncOp.getSrc())); + if (hasCpSize) + args.push_back(mt.lookupValue(cpAsyncOp.getCpSize())); + + return id; +} + llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims, bool isIm2Col) { switch (tensorDims) { diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir index 84ea55ceb5acc..c7a6eca158276 100644 --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -74,13 +74,9 @@ func.func @async_cp(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>) { // CHECK-LABEL: @async_cp_zfill func.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32) { - // CHECK: llvm.inline_asm has_side_effects asm_dialect = att - // CHECK-SAME: "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", - // CHECK-SAME: "r,l,n,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> () + // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = cg, %{{.*}} : !llvm.ptr<3>, !llvm.ptr<1>, i32 nvvm.cp.async.shared.global %dst, %src, 16, cache = cg, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32 - // CHECK: llvm.inline_asm 
has_side_effects asm_dialect = att - // CHECK-SAME: "cp.async.ca.shared.global [$0], [$1], $2, $3;\0A", - // CHECK-SAME: "r,l,n,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> () + // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 4, cache = ca, %{{.*}} : !llvm.ptr<3>, !llvm.ptr<1>, i32 nvvm.cp.async.shared.global %dst, %src, 4, cache = ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32 return } diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index 09e98765413f0..7dad9a403def0 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -488,21 +488,35 @@ llvm.func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 : // CHECK-LABEL: @cp_async llvm.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) { -// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}) + // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}) nvvm.cp.async.shared.global %arg0, %arg1, 4, cache = ca : !llvm.ptr<3>, !llvm.ptr<1> -// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.8(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}) + // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.8(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}) nvvm.cp.async.shared.global %arg0, %arg1, 8, cache = ca : !llvm.ptr<3>, !llvm.ptr<1> -// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.16(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}) + // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.16(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}) nvvm.cp.async.shared.global %arg0, %arg1, 16, cache = ca : !llvm.ptr<3>, !llvm.ptr<1> -// CHECK: call void @llvm.nvvm.cp.async.cg.shared.global.16(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}) + // CHECK: call void @llvm.nvvm.cp.async.cg.shared.global.16(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}) 
nvvm.cp.async.shared.global %arg0, %arg1, 16, cache = cg : !llvm.ptr<3>, !llvm.ptr<1> -// CHECK: call void @llvm.nvvm.cp.async.commit.group() + + // CHECK: call void @llvm.nvvm.cp.async.commit.group() nvvm.cp.async.commit.group -// CHECK: call void @llvm.nvvm.cp.async.wait.group(i32 0) + // CHECK: call void @llvm.nvvm.cp.async.wait.group(i32 0) nvvm.cp.async.wait.group 0 llvm.return } +// CHECK-LABEL: @async_cp_zfill +llvm.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32) { + // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4.s(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}, i32 %{{.*}}) + nvvm.cp.async.shared.global %dst, %src, 4, cache = ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32 + // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.8.s(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}, i32 %{{.*}}) + nvvm.cp.async.shared.global %dst, %src, 8, cache = ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32 + // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.16.s(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}, i32 %{{.*}}) + nvvm.cp.async.shared.global %dst, %src, 16, cache = ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32 + // CHECK: call void @llvm.nvvm.cp.async.cg.shared.global.16.s(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}, i32 %{{.*}}) + nvvm.cp.async.shared.global %dst, %src, 16, cache = cg, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32 + llvm.return +} + // CHECK-LABEL: @cp_async_mbarrier_arrive llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.ptr) { // CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive(ptr %{{.*}})