[SLP]Initial support for non-power-of-2 (but whole reg) vectorization for stores #111194
Conversation
Created using spr 1.3.5
@llvm/pr-subscribers-llvm-transforms

Author: Alexey Bataev (alexey-bataev)

Changes

Allows non-power-of-2 vectorization for stores, but still requires that the vectorized number of elements forms full vector registers.

Full diff: https://github.com/llvm/llvm-project/pull/111194.diff

2 Files Affected:
- llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
- llvm/test/Transforms/SLPVectorizer/X86/long-full-reg-stores.ll
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index dc9ad5335f8a52..228837872506bf 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -278,6 +278,22 @@ static unsigned getFullVectorNumberOfElements(const TargetTransformInfo &TTI,
return bit_ceil(divideCeil(Sz, NumParts)) * NumParts;
}
+/// Returns the number of elements of the given type \p Ty, not greater than \p
+/// Sz, which forms type, which splits by \p TTI into whole vector types during
+/// legalization.
+static unsigned
+getFloorFullVectorNumberOfElements(const TargetTransformInfo &TTI, Type *Ty,
+ unsigned Sz) {
+ if (!isValidElementType(Ty))
+ return bit_floor(Sz);
+ // Find the number of elements, which forms full vectors.
+ unsigned NumParts = TTI.getNumberOfParts(getWidenedType(Ty, Sz));
+ if (NumParts == 0 || NumParts >= Sz)
+ return bit_floor(Sz);
+ unsigned RegVF = bit_ceil(divideCeil(Sz, NumParts));
+ return (Sz / RegVF) * RegVF;
+}
+
static void transformScalarShuffleIndiciesToVector(unsigned VecTyNumElements,
SmallVectorImpl<int> &Mask) {
// The ShuffleBuilder implementation use shufflevector to splat an "element".
@@ -7651,7 +7667,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
}
size_t NumUniqueScalarValues = UniqueValues.size();
bool IsFullVectors = hasFullVectorsOrPowerOf2(
- *TTI, UniqueValues.front()->getType(), NumUniqueScalarValues);
+ *TTI, getValueType(UniqueValues.front()), NumUniqueScalarValues);
if (NumUniqueScalarValues == VL.size() &&
(VectorizeNonPowerOf2 || IsFullVectors)) {
ReuseShuffleIndices.clear();
@@ -17385,7 +17401,11 @@ SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R,
const unsigned Sz = R.getVectorElementSize(Chain[0]);
unsigned VF = Chain.size();
- if (!has_single_bit(Sz) || !has_single_bit(VF) || VF < 2 || VF < MinVF) {
+ if (!has_single_bit(Sz) ||
+ !hasFullVectorsOrPowerOf2(
+ *TTI, cast<StoreInst>(Chain.front())->getValueOperand()->getType(),
+ VF) ||
+ VF < 2 || VF < MinVF) {
// Check if vectorizing with a non-power-of-2 VF should be considered. At
// the moment, only consider cases where VF + 1 is a power-of-2, i.e. almost
// all vector lanes are used.
@@ -17403,10 +17423,12 @@ SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R,
InstructionsState S = getSameOpcode(ValOps.getArrayRef(), *TLI);
if (all_of(ValOps, IsaPred<Instruction>) && ValOps.size() > 1) {
DenseSet<Value *> Stores(Chain.begin(), Chain.end());
- bool IsPowerOf2 =
- has_single_bit(ValOps.size()) ||
+ bool IsAllowedSize =
+ hasFullVectorsOrPowerOf2(*TTI, ValOps.front()->getType(),
+ ValOps.size()) ||
(VectorizeNonPowerOf2 && has_single_bit(ValOps.size() + 1));
- if ((!IsPowerOf2 && S.getOpcode() && S.getOpcode() != Instruction::Load &&
+ if ((!IsAllowedSize && S.getOpcode() &&
+ S.getOpcode() != Instruction::Load &&
(!S.MainOp->isSafeToRemove() ||
any_of(ValOps.getArrayRef(),
[&](Value *V) {
@@ -17417,7 +17439,7 @@ SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R,
}));
}))) ||
(ValOps.size() > Chain.size() / 2 && !S.getOpcode())) {
- Size = (!IsPowerOf2 && S.getOpcode()) ? 1 : 2;
+ Size = (!IsAllowedSize && S.getOpcode()) ? 1 : 2;
return false;
}
}
@@ -17545,15 +17567,11 @@ bool SLPVectorizerPass::vectorizeStores(
unsigned MaxVF =
std::min(R.getMaximumVF(EltSize, Instruction::Store), MaxElts);
- unsigned MaxRegVF = MaxVF;
auto *Store = cast<StoreInst>(Operands[0]);
Type *StoreTy = Store->getValueOperand()->getType();
Type *ValueTy = StoreTy;
if (auto *Trunc = dyn_cast<TruncInst>(Store->getValueOperand()))
ValueTy = Trunc->getSrcTy();
- if (ValueTy == StoreTy &&
- R.getVectorElementSize(Store->getValueOperand()) <= EltSize)
- MaxVF = std::min<unsigned>(MaxVF, bit_floor(Operands.size()));
unsigned MinVF = std::max<unsigned>(
2, PowerOf2Ceil(TTI->getStoreMinimumVF(
R.getMinVF(DL->getTypeStoreSizeInBits(StoreTy)), StoreTy,
@@ -17571,10 +17589,21 @@ bool SLPVectorizerPass::vectorizeStores(
// First try vectorizing with a non-power-of-2 VF. At the moment, only
// consider cases where VF + 1 is a power-of-2, i.e. almost all vector
// lanes are used.
- unsigned CandVF =
- std::clamp<unsigned>(Operands.size(), MaxVF, MaxRegVF);
- if (has_single_bit(CandVF + 1))
+ unsigned CandVF = std::clamp<unsigned>(Operands.size(), MinVF, MaxVF);
+ if (has_single_bit(CandVF + 1)) {
NonPowerOf2VF = CandVF;
+ assert(NonPowerOf2VF != MaxVF &&
+ "Non-power-of-2 VF should not be equal to MaxVF");
+ }
+ }
+
+ unsigned MaxRegVF = MaxVF;
+ MaxVF = std::min<unsigned>(MaxVF, bit_floor(Operands.size()));
+ if (MaxVF < MinVF) {
+ LLVM_DEBUG(dbgs() << "SLP: Vectorization infeasible as MaxVF (" << MaxVF
+ << ") < "
+ << "MinVF (" << MinVF << ")\n");
+ continue;
}
unsigned Sz = 1 + Log2_32(MaxVF) - Log2_32(MinVF);
@@ -17742,7 +17771,7 @@ bool SLPVectorizerPass::vectorizeStores(
(Repeat > 1 && (RepeatChanged || !AnyProfitableGraph)))
break;
constexpr unsigned StoresLimit = 64;
- const unsigned MaxTotalNum = bit_floor(std::min<unsigned>(
+ const unsigned MaxTotalNum = std::min<unsigned>(
Operands.size(),
static_cast<unsigned>(
End -
@@ -17750,8 +17779,13 @@ bool SLPVectorizerPass::vectorizeStores(
RangeSizes.begin(),
find_if(RangeSizes, std::bind(IsNotVectorized, true,
std::placeholders::_1))) +
- 1)));
- unsigned VF = PowerOf2Ceil(CandidateVFs.front()) * 2;
+ 1));
+ unsigned VF = bit_ceil(CandidateVFs.front()) * 2;
+ unsigned Limit =
+ getFloorFullVectorNumberOfElements(*TTI, StoreTy, MaxTotalNum);
+ CandidateVFs.clear();
+ if (bit_floor(Limit) == VF)
+ CandidateVFs.push_back(Limit);
if (VF > MaxTotalNum || VF >= StoresLimit)
break;
for_each(RangeSizes, [&](std::pair<unsigned, unsigned> &P) {
@@ -17760,7 +17794,6 @@ bool SLPVectorizerPass::vectorizeStores(
});
// Last attempt to vectorize max number of elements, if all previous
// attempts were unsuccessful because of the cost issues.
- CandidateVFs.clear();
CandidateVFs.push_back(VF);
}
}
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/long-full-reg-stores.ll b/llvm/test/Transforms/SLPVectorizer/X86/long-full-reg-stores.ll
index b5f993f986c7cc..aff66dd7c10ea7 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/long-full-reg-stores.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/long-full-reg-stores.ll
@@ -4,30 +4,16 @@
define void @test(ptr noalias %0, ptr noalias %1) {
; CHECK-LABEL: define void @test(
; CHECK-SAME: ptr noalias [[TMP0:%.*]], ptr noalias [[TMP1:%.*]]) {
-; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i64 24
-; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[TMP1]], i64 48
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr i8, ptr [[TMP1]], i64 8
-; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[TMP1]], i64 16
-; CHECK-NEXT: [[TMP7:%.*]] = getelementptr i8, ptr [[TMP0]], i64 24
-; CHECK-NEXT: [[TMP8:%.*]] = load double, ptr [[TMP7]], align 8
-; CHECK-NEXT: store double [[TMP8]], ptr [[TMP5]], align 8
; CHECK-NEXT: [[TMP9:%.*]] = getelementptr i8, ptr [[TMP0]], i64 48
-; CHECK-NEXT: [[TMP10:%.*]] = load double, ptr [[TMP9]], align 16
-; CHECK-NEXT: store double [[TMP10]], ptr [[TMP6]], align 16
; CHECK-NEXT: [[TMP11:%.*]] = getelementptr i8, ptr [[TMP0]], i64 8
-; CHECK-NEXT: [[TMP12:%.*]] = load double, ptr [[TMP11]], align 8
-; CHECK-NEXT: store double [[TMP12]], ptr [[TMP3]], align 8
-; CHECK-NEXT: [[TMP13:%.*]] = getelementptr i8, ptr [[TMP0]], i64 32
-; CHECK-NEXT: [[TMP14:%.*]] = load double, ptr [[TMP13]], align 16
-; CHECK-NEXT: [[TMP15:%.*]] = getelementptr i8, ptr [[TMP1]], i64 32
-; CHECK-NEXT: store double [[TMP14]], ptr [[TMP15]], align 16
-; CHECK-NEXT: [[TMP16:%.*]] = getelementptr i8, ptr [[TMP0]], i64 56
-; CHECK-NEXT: [[TMP17:%.*]] = load double, ptr [[TMP16]], align 8
-; CHECK-NEXT: [[TMP18:%.*]] = getelementptr i8, ptr [[TMP1]], i64 40
-; CHECK-NEXT: store double [[TMP17]], ptr [[TMP18]], align 8
-; CHECK-NEXT: [[TMP19:%.*]] = getelementptr i8, ptr [[TMP0]], i64 16
-; CHECK-NEXT: [[TMP20:%.*]] = load double, ptr [[TMP19]], align 16
-; CHECK-NEXT: store double [[TMP20]], ptr [[TMP4]], align 16
+; CHECK-NEXT: [[TMP6:%.*]] = load <2 x double>, ptr [[TMP9]], align 16
+; CHECK-NEXT: [[TMP7:%.*]] = load <4 x double>, ptr [[TMP11]], align 8
+; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <4 x double> [[TMP7]], <4 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison>
+; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <2 x double> [[TMP6]], <2 x double> poison, <6 x i32> <i32 0, i32 1, i32 poison, i32 poison, i32 poison, i32 poison>
+; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <2 x double> [[TMP6]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
+; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <4 x double> [[TMP7]], <4 x double> [[TMP10]], <6 x i32> <i32 2, i32 4, i32 0, i32 3, i32 5, i32 1>
+; CHECK-NEXT: store <6 x double> [[TMP13]], ptr [[TMP5]], align 8
; CHECK-NEXT: [[TMP21:%.*]] = getelementptr i8, ptr [[TMP0]], i64 40
; CHECK-NEXT: [[TMP22:%.*]] = load double, ptr [[TMP21]], align 8
; CHECK-NEXT: [[TMP23:%.*]] = getelementptr i8, ptr [[TMP1]], i64 56
LGTM
Allows non-power-of-2 vectorization for stores, but still requires that the
vectorized number of elements forms full vector registers.
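For illustration, here is a minimal standalone sketch of the whole-register rounding the patch introduces. It is a hypothetical helper, not the LLVM API: in the actual getFloorFullVectorNumberOfElements, the register count comes from TTI.getNumberOfParts(getWidenedType(Ty, Sz)).

```cpp
// Sketch of the "floor to whole vector registers" computation, assuming the
// number of registers (NumParts) has already been obtained from the target.
#include <bit>
#include <cstdio>

static unsigned floorToFullVectors(unsigned Sz, unsigned NumParts) {
  if (NumParts == 0 || NumParts >= Sz)
    return std::bit_floor(Sz); // fall back to a power-of-2 element count
  // Elements per register after legalization.
  unsigned RegVF = std::bit_ceil((Sz + NumParts - 1) / NumParts);
  // Largest multiple of RegVF not exceeding Sz: only whole registers remain.
  return (Sz / RegVF) * RegVF;
}

int main() {
  // E.g. 7 scalars splitting into 4 registers -> RegVF = 2 -> vectorize 6
  // elements (three whole registers) and leave the 7th as a scalar.
  std::printf("%u\n", floorToFullVectors(7, 4)); // prints 6
}
```

This is the kind of non-power-of-2 but whole-register VF (for example, the <6 x double> store in the updated test) that the store vectorizer is now allowed to pick instead of rounding the chain down to a power of two.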