-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[NVPTX] pull in v2i32 build_vector through v2f32 bitcast #153478
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[NVPTX] pull in v2i32 build_vector through v2f32 bitcast #153478
Conversation
@llvm/pr-subscribers-backend-nvptx Author: Princeton Ferro (Prince781) ChangesTransform:
Since v2f32 is legal but v2i32 is not, v2i32 build_vector would be legalized as bitwise ops on i64, when we want each 32-bit element to be in its own register. Fixes #153109 Full diff: https://github.com/llvm/llvm-project/pull/153478.diff 2 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 3daf25d551520..fcabe49e09c6c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -892,10 +892,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);
// We have some custom DAG combine patterns for these nodes
- setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
- ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
- ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
+ setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::BITCAST,
+ ISD::EXTRACT_VECTOR_ELT, ISD::FADD, ISD::MUL, ISD::SHL,
+ ISD::SREM, ISD::UREM, ISD::VSELECT, ISD::BUILD_VECTOR,
+ ISD::ADDRSPACECAST, ISD::LOAD, ISD::STORE,
+ ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
// setcc for f16x2 and bf16x2 needs special handling to prevent
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5334,6 +5335,26 @@ static SDValue PerformANDCombine(SDNode *N,
return SDValue();
}
+static SDValue combineBitcast(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+ const SDValue &Input = N->getOperand(0);
+ const MVT FromVT = Input.getSimpleValueType();
+ const MVT ToVT = N->getSimpleValueType(0);
+ const SDLoc DL(N);
+
+ if (Input.getOpcode() != ISD::BUILD_VECTOR || ToVT != MVT::v2f32 ||
+ !(FromVT.isVector() &&
+ FromVT.getVectorNumElements() == ToVT.getVectorNumElements()))
+ return SDValue();
+
+ const MVT ToEltVT = ToVT.getVectorElementType();
+
+ // pull in build_vector through bitcast
+ return DCI.DAG.getBuildVector(
+ ToVT, DL,
+ {DCI.DAG.getBitcast(ToEltVT, Input.getOperand(0)),
+ DCI.DAG.getBitcast(ToEltVT, Input.getOperand(1))});
+}
+
static SDValue PerformREMCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
CodeGenOptLevel OptLevel) {
@@ -6007,6 +6028,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return combineADDRSPACECAST(N, DCI);
case ISD::AND:
return PerformANDCombine(N, DCI);
+ case ISD::BITCAST:
+ return combineBitcast(N, DCI);
case ISD::SIGN_EXTEND:
case ISD::ZERO_EXTEND:
return combineMulWide(N, DCI, OptLevel);
diff --git a/llvm/test/CodeGen/NVPTX/f32x2-inlineasm.ll b/llvm/test/CodeGen/NVPTX/f32x2-inlineasm.ll
new file mode 100644
index 0000000000000..4d80ee68faac6
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/f32x2-inlineasm.ll
@@ -0,0 +1,64 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mcpu=sm_90a -O0 -disable-post-ra -frame-pointer=all \
+; RUN: -verify-machineinstrs | FileCheck --check-prefixes=CHECK %s
+; RUN: %if ptxas-12.7 %{ \
+; RUN: llc < %s -mcpu=sm_90a -O0 -disable-post-ra -frame-pointer=all \
+; RUN: -verify-machineinstrs | %ptxas-verify -arch=sm_90a \
+; RUN: %}
+
+target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
+target triple = "nvptx64-nvidia-cuda"
+
+define ptx_kernel void @kernel1() {
+; CHECK-LABEL: kernel1(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<11>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %bb
+; CHECK-NEXT: mov.b32 %r3, 0;
+; CHECK-NEXT: mov.b32 %r4, %r3;
+; CHECK-NEXT: mov.b32 %r1, %r3;
+; CHECK-NEXT: mov.b32 %r2, %r4;
+; CHECK-NEXT: // begin inline asm
+; CHECK-NEXT: { .reg .pred p; setp.ne.b32 p, 66, 0; wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {%r1,%r2}, 64, 65, p, 67, 68, 69, 70; }
+; CHECK-EMPTY:
+; CHECK-NEXT: // end inline asm
+; CHECK-NEXT: mov.b32 %r5, %r3;
+; CHECK-NEXT: // begin inline asm
+; CHECK-NEXT: mbarrier.arrive.release.cta.shared::cta.b64 %rd1, [%r5], 1; // XXSTART
+; CHECK-NEXT: // end inline asm
+; CHECK-NEXT: wgmma.wait_group.sync.aligned 0;
+; CHECK-NEXT: mov.b32 %r6, %r3;
+; CHECK-NEXT: // begin inline asm
+; CHECK-NEXT: mbarrier.arrive.release.cta.shared::cta.b64 %rd2, [%r6], 1; // XXEND
+; CHECK-NEXT: // end inline asm
+; CHECK-NEXT: mul.rn.f32 %r7, %r1, 0f00000000;
+; CHECK-NEXT: mul.rn.f32 %r8, %r2, 0f00000000;
+; CHECK-NEXT: add.rn.f32 %r9, %r8, %r7;
+; CHECK-NEXT: shfl.sync.bfly.b32 %r10, %r9, 0, 0, 0;
+; CHECK-NEXT: ret;
+bb:
+ %i = call { <1 x float>, <1 x float> } asm sideeffect "{ .reg .pred p; setp.ne.b32 p, 66, 0; wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 {$0,$1}, 64, 65, p, 67, 68, 69, 70; }\0A", "=f,=f,0,1"(<1 x float> zeroinitializer, <1 x float> zeroinitializer)
+ %i1 = extractvalue { <1 x float>, <1 x float> } %i, 0
+ %i2 = extractvalue { <1 x float>, <1 x float> } %i, 1
+ %i3 = call i64 asm sideeffect " mbarrier.arrive.release.cta.shared::cta.b64 $0, [$1], 1; // XXSTART ", "=l,r"(ptr addrspace(3) null)
+ call void @llvm.nvvm.wgmma.wait_group.sync.aligned(i64 0)
+ %i4 = shufflevector <1 x float> %i1, <1 x float> %i2, <2 x i32> <i32 0, i32 1>
+ %i5 = call i64 asm sideeffect " mbarrier.arrive.release.cta.shared::cta.b64 $0, [$1], 1; // XXEND ", "=l,r"(ptr addrspace(3) null)
+ %i6 = fmul <2 x float> %i4, zeroinitializer
+ %i7 = extractelement <2 x float> %i6, i64 0
+ %i8 = extractelement <2 x float> %i6, i64 1
+ %i9 = fadd float %i8, %i7
+ %i10 = bitcast float %i9 to <1 x i32>
+ %i11 = extractelement <1 x i32> %i10, i64 0
+ %i12 = call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 0, i32 %i11, i32 0, i32 0)
+ ret void
+}
+
+declare void @llvm.nvvm.wgmma.wait_group.sync.aligned(i64 immarg) #0
+
+declare i32 @llvm.nvvm.shfl.sync.bfly.i32(i32, i32, i32, i32) #1
+
+attributes #0 = { convergent nounwind }
+attributes #1 = { convergent nocallback nounwind memory(inaccessiblemem: readwrite) }
|
1e355f1
to
ce79a0d
Compare
e13ac9f
to
99a54b3
Compare
Updated the code and simplified test cases. |
99a54b3
to
8db1673
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks reasonable to me, although I have not had a chance to dig into the issue this is trying to address very deeply. Please wait for @Artem-B to take a look as well.
P.S. I wonder if we should consider just making v2i32 a legal type.
I thought about this as well. I'm hesitant to introduce another legal type and plumb it through NVPTX (although it should be far less effort than v2f32) when there aren't any instructions that support v2i32. If we see more bug reports related to this issue then I'll reconsider it. |
6b7a81a
to
ddab65c
Compare
Looks like CI is red because of a broken Arm/Thumb2 test, but this shouldn't have anything to do with this change. |
@Artem-B ping for review. |
Hi @Artem-B, What are your thoughts on still having this checked in? This doesn't fix your bug but it still improves PTX codegen even in sm1xx. We don't want to be emitting these logical ops under any circumstances. |
It sounds reasonable, but it looks like we're fixing one particular case while we potentially may want a more general fix for construction of small vectors that fit into b32/b64 regardless of their element type. LGTM, but add a TODO that we may need a more general solution. |
I think we shouldn't need to worry about |
ddab65c
to
3155ed2
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
b554bea
to
0449a2d
Compare
Transform: v2f32 (bitcast (v2i32 build_vector i32:A, i32:B)) ---> v2f32 (build_vector (f32 (bitcast i32:A)), (f32 (bitcast i32:B))) Since v2f32 is legal but v2i32 is not, v2i32 build_vector would be legalized as bitwise ops on i64, when we want each 32-bit element to be in its own register.
0449a2d
to
d75d057
Compare
Transform:
Since v2f32 is legal but v2i32 is not, v2i32 build_vector would be legalized as bitwise ops on i64, when we want each 32-bit element to be in its own register.