[AArch64] Extend custom lowering for SVE types in @llvm.experimental.vector.compress #105515

Open · wants to merge 6 commits into main

Changes from 4 commits

87 changes: 70 additions & 17 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1781,16 +1781,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
MVT::v2f32, MVT::v4f32, MVT::v2f64})
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);

// We can lower types that have <vscale x {2|4}> elements to compact.
// We can lower all legal (or smaller) SVE types to `compact`.
for (auto VT :
{MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32})
MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32,
MVT::nxv8i8, MVT::nxv8i16, MVT::nxv16i8})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);

// If we have SVE, we can use SVE logic for legal (or smaller than legal)
// NEON vectors in the lowest bits of the SVE register.
for (auto VT : {MVT::v2i8, MVT::v2i16, MVT::v2i32, MVT::v2i64, MVT::v2f32,
MVT::v2f64, MVT::v4i8, MVT::v4i16, MVT::v4i32, MVT::v4f32})
MVT::v2f64, MVT::v4i8, MVT::v4i16, MVT::v4i32, MVT::v4f32,
MVT::v8i8, MVT::v8i16, MVT::v16i8})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);

// Histcnt is SVE2 only
@@ -6648,6 +6650,7 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
EVT ElmtVT = VecVT.getVectorElementType();
const bool IsFixedLength = VecVT.isFixedLengthVector();
const bool HasPassthru = !Passthru.isUndef();
bool CompressedViaStack = false;
unsigned MinElmts = VecVT.getVectorElementCount().getKnownMinValue();
EVT FixedVecVT = MVT::getVectorVT(ElmtVT.getSimpleVT(), MinElmts);

@@ -6659,10 +6662,6 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
if (IsFixedLength && VecVT.getSizeInBits().getFixedValue() > 128)
return SDValue();

// Only <vscale x {4|2} x {i32|i64}> supported for compact.
if (MinElmts != 2 && MinElmts != 4)
return SDValue();

// We can use the SVE register containing the NEON vector in its lowest bits.
if (IsFixedLength) {
EVT ScalableVecVT =
@@ -6690,19 +6689,73 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
EVT ContainerVT = getSVEContainerType(VecVT);
EVT CastVT = VecVT.changeVectorElementTypeToInteger();

// Convert to i32 or i64 for smaller types, as these are the only supported
// sizes for compact.
if (ContainerVT != VecVT) {
Vec = DAG.getBitcast(CastVT, Vec);
Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
}
// These vector types aren't supported by the `compact` instruction, so
// we split and compact them as <vscale x 4 x i32>, store them on the stack,
// and then merge them again. In the other cases, emit compact directly.
SDValue Compressed;
if (VecVT == MVT::nxv8i16 || VecVT == MVT::nxv8i8 || VecVT == MVT::nxv16i8) {
SDValue Chain = DAG.getEntryNode();
SDValue StackPtr = DAG.CreateStackTemporary(
VecVT.getStoreSize(), DAG.getReducedAlign(VecVT, /*UseABI=*/false));
MachineFunction &MF = DAG.getMachineFunction();

EVT PartialVecVT =
EVT::getVectorVT(*DAG.getContext(), ElmtVT, 4, /*isScalable*/ true);
EVT OffsetVT = getVectorIdxTy(DAG.getDataLayout());
SDValue Offset = DAG.getConstant(0, DL, OffsetVT);

for (unsigned I = 0; I < MinElmts; I += 4) {
SDValue VectorIdx = DAG.getVectorIdxConstant(I, DL);
SDValue PartialVec =
DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, PartialVecVT, Vec, VectorIdx);
PartialVec = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::nxv4i32, PartialVec);

SDValue PartialMask =
DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::nxv4i1, Mask, VectorIdx);

SDValue PartialCompressed = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv4i32,
DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64),
PartialMask, PartialVec);
PartialCompressed =
DAG.getNode(ISD::TRUNCATE, DL, PartialVecVT, PartialCompressed);

SDValue OutPtr = DAG.getNode(
ISD::ADD, DL, StackPtr.getValueType(), StackPtr,
DAG.getNode(
ISD::MUL, DL, OffsetVT, Offset,
DAG.getConstant(ElmtVT.getScalarSizeInBits() / 8, DL, OffsetVT)));
Chain = DAG.getStore(Chain, DL, PartialCompressed, OutPtr,
MachinePointerInfo::getUnknownStack(MF));

SDValue PartialOffset = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, OffsetVT,
DAG.getConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64),
PartialMask, PartialMask);
Offset = DAG.getNode(ISD::ADD, DL, OffsetVT, Offset, PartialOffset);
}

MachinePointerInfo PtrInfo = MachinePointerInfo::getFixedStack(
MF, cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex());
Compressed = DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
CompressedViaStack = true;
} else {
// Convert to i32 or i64 for smaller types, as these are the only supported
// sizes for compact.
if (ContainerVT != VecVT) {
Vec = DAG.getBitcast(CastVT, Vec);
Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
}

SDValue Compressed = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask, Vec);
Compressed = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask,
Vec);
}

// compact fills with 0s, so if our passthru is all 0s, do nothing here.
if (HasPassthru && !ISD::isConstantSplatVectorAllZeros(Passthru.getNode())) {
if (HasPassthru && (!ISD::isConstantSplatVectorAllZeros(Passthru.getNode()) ||
CompressedViaStack)) {
SDValue Offset = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64,
DAG.getConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64), Mask, Mask);
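
To make the new stack-based path easier to follow, here is a rough ACLE-level C++ sketch of the same strategy for <vscale x 8 x i16>: widen each half to 32-bit lanes (one of the two element sizes `compact` accepts), compact each half, store the halves back to back on the stack, and reload the merged result. This is illustrative only, not part of the patch; the function name compress_u16 and the std::vector buffer are assumptions of the sketch.

#include <arm_sve.h>
#include <cstdint>
#include <vector>

svuint16_t compress_u16(svuint16_t vec, svbool_t mask) {
  // Widen each half to 32-bit lanes, since `compact` only handles .s and .d.
  svuint32_t lo = svunpklo_u32(vec);
  svuint32_t hi = svunpkhi_u32(vec);
  svbool_t mlo = svunpklo_b(mask);
  svbool_t mhi = svunpkhi_b(mask);

  // Compact each half independently; inactive tail lanes become zero.
  svuint32_t clo = svcompact_u32(mlo, lo);
  svuint32_t chi = svcompact_u32(mhi, hi);

  // Store the halves back to back: the second (truncating) store starts at
  // the number of active lanes in the first half, overwriting its zero tail.
  std::vector<uint16_t> buf(svcnth());
  uint64_t off = svcntp_b32(mlo, mlo);
  svst1h_u32(svptrue_b32(), buf.data(), clo);
  svst1h_u32(svptrue_b32(), buf.data() + off, chi);

  // Reload the merged result. std::vector value-initializes its storage, so
  // the unwritten tail reads back as zero here; the real lowering uses a raw
  // stack slot and cleans the tail up separately when a passthru is supplied.
  return svld1_u16(svptrue_b16(), buf.data());
}

This mirrors the cntp/compact/st1h/ld1h sequence visible in the CHECK lines of the new tests below.
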
160 changes: 160 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-vector-compress.ll
@@ -91,6 +91,132 @@ define <vscale x 4 x float> @test_compress_nxv4f32(<vscale x 4 x float> %vec, <v
ret <vscale x 4 x float> %out
}

define <vscale x 8 x i8> @test_compress_nxv8i8(<vscale x 8 x i8> %vec, <vscale x 8 x i1> %mask) {
; CHECK-LABEL: test_compress_nxv8i8:
; CHECK: // %bb.0:
; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: addvl sp, sp, #-1
; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
; CHECK-NEXT: .cfi_offset w29, -16
; CHECK-NEXT: uunpklo z1.s, z0.h
; CHECK-NEXT: uunpkhi z0.s, z0.h
; CHECK-NEXT: addpl x9, sp, #4
; CHECK-NEXT: punpklo p1.h, p0.b
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: cntp x8, p1, p1.s
; CHECK-NEXT: compact z1.s, p1, z1.s
; CHECK-NEXT: compact z0.s, p0, z0.s
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: ptrue p1.h
; CHECK-NEXT: st1b { z1.s }, p0, [sp, #2, mul vl]
; CHECK-NEXT: st1b { z0.s }, p0, [x9, x8]
; CHECK-NEXT: ld1b { z0.h }, p1/z, [sp, #1, mul vl]
; CHECK-NEXT: addvl sp, sp, #1
; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
%out = call <vscale x 8 x i8> @llvm.experimental.vector.compress(<vscale x 8 x i8> %vec, <vscale x 8 x i1> %mask, <vscale x 8 x i8> undef)
ret <vscale x 8 x i8> %out
}

define <vscale x 8 x i16> @test_compress_nxv8i16(<vscale x 8 x i16> %vec, <vscale x 8 x i1> %mask) {
; CHECK-LABEL: test_compress_nxv8i16:
; CHECK: // %bb.0:
; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: addvl sp, sp, #-1
; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
; CHECK-NEXT: .cfi_offset w29, -16
; CHECK-NEXT: uunpklo z1.s, z0.h
; CHECK-NEXT: uunpkhi z0.s, z0.h
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: punpklo p1.h, p0.b
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: cntp x8, p1, p1.s
; CHECK-NEXT: compact z1.s, p1, z1.s
; CHECK-NEXT: compact z0.s, p0, z0.s
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: ptrue p1.h
; CHECK-NEXT: st1h { z1.s }, p0, [sp]
; CHECK-NEXT: st1h { z0.s }, p0, [x9, x8, lsl #1]
; CHECK-NEXT: ld1h { z0.h }, p1/z, [sp]
; CHECK-NEXT: addvl sp, sp, #1
; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
%out = call <vscale x 8 x i16> @llvm.experimental.vector.compress(<vscale x 8 x i16> %vec, <vscale x 8 x i1> %mask, <vscale x 8 x i16> undef)
ret <vscale x 8 x i16> %out
}

define <vscale x 16 x i8> @test_compress_nxv16i8(<vscale x 16 x i8> %vec, <vscale x 16 x i1> %mask) {
; CHECK-LABEL: test_compress_nxv16i8:
; CHECK: // %bb.0:
; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: addvl sp, sp, #-1
; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
; CHECK-NEXT: .cfi_offset w29, -16
; CHECK-NEXT: uunpklo z1.h, z0.b
; CHECK-NEXT: punpklo p2.h, p0.b
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: uunpkhi z0.h, z0.b
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: punpklo p3.h, p2.b
; CHECK-NEXT: punpkhi p2.h, p2.b
; CHECK-NEXT: uunpklo z2.s, z1.h
; CHECK-NEXT: uunpkhi z1.s, z1.h
; CHECK-NEXT: cntp x8, p3, p3.s
; CHECK-NEXT: uunpklo z3.s, z0.h
; CHECK-NEXT: ptrue p1.s
; CHECK-NEXT: uunpkhi z0.s, z0.h
; CHECK-NEXT: compact z2.s, p3, z2.s
; CHECK-NEXT: compact z1.s, p2, z1.s
; CHECK-NEXT: punpklo p3.h, p0.b
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: compact z0.s, p0, z0.s
; CHECK-NEXT: ptrue p0.b
; CHECK-NEXT: st1b { z2.s }, p1, [sp]
; CHECK-NEXT: st1b { z1.s }, p1, [x9, x8]
; CHECK-NEXT: compact z1.s, p3, z3.s
; CHECK-NEXT: incp x8, p2.s
; CHECK-NEXT: st1b { z1.s }, p1, [x9, x8]
; CHECK-NEXT: incp x8, p3.s
; CHECK-NEXT: st1b { z0.s }, p1, [x9, x8]
; CHECK-NEXT: ld1b { z0.b }, p0/z, [sp]
; CHECK-NEXT: addvl sp, sp, #1
; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
%out = call <vscale x 16 x i8> @llvm.experimental.vector.compress(<vscale x 16 x i8> %vec, <vscale x 16 x i1> %mask, <vscale x 16 x i8> undef)
ret <vscale x 16 x i8> %out
}

define <vscale x 8 x i16> @test_compress_nxv8i16_with_0_passthru(<vscale x 8 x i16> %vec, <vscale x 8 x i1> %mask) {
; CHECK-LABEL: test_compress_nxv8i16_with_0_passthru:
; CHECK: // %bb.0:
; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: addvl sp, sp, #-1
; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
; CHECK-NEXT: .cfi_offset w29, -16
; CHECK-NEXT: uunpklo z1.s, z0.h
; CHECK-NEXT: uunpkhi z0.s, z0.h
; CHECK-NEXT: mov x10, sp
; CHECK-NEXT: punpklo p1.h, p0.b
; CHECK-NEXT: punpkhi p2.h, p0.b
; CHECK-NEXT: cntp x8, p1, p1.s
; CHECK-NEXT: compact z1.s, p1, z1.s
; CHECK-NEXT: compact z0.s, p2, z0.s
; CHECK-NEXT: ptrue p1.s
; CHECK-NEXT: cntp x9, p0, p0.h
; CHECK-NEXT: ptrue p0.h
; CHECK-NEXT: st1h { z1.s }, p1, [sp]
; CHECK-NEXT: mov z1.h, #0 // =0x0
; CHECK-NEXT: st1h { z0.s }, p1, [x10, x8, lsl #1]
; CHECK-NEXT: ld1h { z0.h }, p0/z, [sp]
; CHECK-NEXT: whilelo p0.h, xzr, x9
; CHECK-NEXT: sel z0.h, p0, z0.h, z1.h
; CHECK-NEXT: addvl sp, sp, #1
; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
%out = call <vscale x 8 x i16> @llvm.experimental.vector.compress(<vscale x 8 x i16> %vec, <vscale x 8 x i1> %mask, <vscale x 8 x i16> splat (i16 0))
ret <vscale x 8 x i16> %out
}

define <vscale x 4 x i4> @test_compress_illegal_element_type(<vscale x 4 x i4> %vec, <vscale x 4 x i1> %mask) {
; CHECK-LABEL: test_compress_illegal_element_type:
; CHECK: // %bb.0:
@@ -240,6 +366,40 @@ define <2 x i16> @test_compress_v2i16_with_sve(<2 x i16> %vec, <2 x i1> %mask) {
ret <2 x i16> %out
}

define <8 x i16> @test_compress_v8i16_with_sve(<8 x i16> %vec, <8 x i1> %mask) {
; CHECK-LABEL: test_compress_v8i16_with_sve:
; CHECK: // %bb.0:
; CHECK-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: addvl sp, sp, #-1
; CHECK-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
; CHECK-NEXT: .cfi_offset w29, -16
; CHECK-NEXT: ushll v1.8h, v1.8b, #0
; CHECK-NEXT: ptrue p0.h
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
; CHECK-NEXT: mov x9, sp
; CHECK-NEXT: shl v1.8h, v1.8h, #15
; CHECK-NEXT: cmlt v1.8h, v1.8h, #0
; CHECK-NEXT: and z1.h, z1.h, #0x1
; CHECK-NEXT: cmpne p1.h, p0/z, z1.h, #0
; CHECK-NEXT: uunpklo z1.s, z0.h
; CHECK-NEXT: uunpkhi z0.s, z0.h
; CHECK-NEXT: punpklo p2.h, p1.b
; CHECK-NEXT: punpkhi p1.h, p1.b
; CHECK-NEXT: compact z1.s, p2, z1.s
; CHECK-NEXT: cntp x8, p2, p2.s
; CHECK-NEXT: compact z0.s, p1, z0.s
; CHECK-NEXT: ptrue p1.s
; CHECK-NEXT: st1h { z1.s }, p1, [sp]
; CHECK-NEXT: st1h { z0.s }, p1, [x9, x8, lsl #1]
; CHECK-NEXT: ld1h { z0.h }, p0/z, [sp]
; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
; CHECK-NEXT: addvl sp, sp, #1
; CHECK-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
%out = call <8 x i16> @llvm.experimental.vector.compress(<8 x i16> %vec, <8 x i1> %mask, <8 x i16> undef)
ret <8 x i16> %out
}


define <vscale x 4 x i32> @test_compress_nxv4i32_with_passthru(<vscale x 4 x i32> %vec, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru) {
; CHECK-LABEL: test_compress_nxv4i32_with_passthru:
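
One detail worth calling out from the *_with_0_passthru test above: after the stack-based merge, lanes past the total number of active mask elements may contain stale stack data, so the lowering now emits the cntp/whilelo/sel tail even when the passthru is all zeros (the CompressedViaStack condition in the lowering change). A hedged ACLE-level sketch of that tail, with assumed names, purely for illustration:

#include <arm_sve.h>
#include <cstdint>

// `compressed` is the vector reloaded from the stack; every lane past the
// number of active mask lanes is replaced by the passthru (e.g. zero).
svuint16_t apply_passthru_u16(svuint16_t compressed, svuint16_t passthru,
                              svbool_t mask) {
  uint64_t active = svcntp_b16(svptrue_b16(), mask);  // cntp
  svbool_t keep = svwhilelt_b16_u64(0, active);       // emits whilelo for unsigned operands
  return svsel_u16(keep, compressed, passthru);       // sel
}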