Skip to content

[SLP]Initial support for non-power-of-2 (but whole reg) vectorization for stores #111194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

alexey-bataev
Copy link
Member

Allows non-power-of-2 vectorization for stores, but still requires, that
vectorized number of elements forms full vector registers.

Created using spr 1.3.5
@llvmbot
Copy link
Member

llvmbot commented Oct 4, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Alexey Bataev (alexey-bataev)

Changes

Allows non-power-of-2 vectorization for stores, but still requires, that
vectorized number of elements forms full vector registers.


Full diff: https://github.com/llvm/llvm-project/pull/111194.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+50-17)
  • (modified) llvm/test/Transforms/SLPVectorizer/X86/long-full-reg-stores.ll (+7-21)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index dc9ad5335f8a52..228837872506bf 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -278,6 +278,22 @@ static unsigned getFullVectorNumberOfElements(const TargetTransformInfo &TTI,
   return bit_ceil(divideCeil(Sz, NumParts)) * NumParts;
 }
 
+/// Returns the number of elements of the given type \p Ty, not greater than \p
+/// Sz, which forms type, which splits by \p TTI into whole vector types during
+/// legalization.
+static unsigned
+getFloorFullVectorNumberOfElements(const TargetTransformInfo &TTI, Type *Ty,
+                                   unsigned Sz) {
+  if (!isValidElementType(Ty))
+    return bit_floor(Sz);
+  // Find the number of elements, which forms full vectors.
+  unsigned NumParts = TTI.getNumberOfParts(getWidenedType(Ty, Sz));
+  if (NumParts == 0 || NumParts >= Sz)
+    return bit_floor(Sz);
+  unsigned RegVF = bit_ceil(divideCeil(Sz, NumParts));
+  return (Sz / RegVF) * RegVF;
+}
+
 static void transformScalarShuffleIndiciesToVector(unsigned VecTyNumElements,
                                                    SmallVectorImpl<int> &Mask) {
   // The ShuffleBuilder implementation use shufflevector to splat an "element".
@@ -7651,7 +7667,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
     }
     size_t NumUniqueScalarValues = UniqueValues.size();
     bool IsFullVectors = hasFullVectorsOrPowerOf2(
-        *TTI, UniqueValues.front()->getType(), NumUniqueScalarValues);
+        *TTI, getValueType(UniqueValues.front()), NumUniqueScalarValues);
     if (NumUniqueScalarValues == VL.size() &&
         (VectorizeNonPowerOf2 || IsFullVectors)) {
       ReuseShuffleIndices.clear();
@@ -17385,7 +17401,11 @@ SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R,
   const unsigned Sz = R.getVectorElementSize(Chain[0]);
   unsigned VF = Chain.size();
 
-  if (!has_single_bit(Sz) || !has_single_bit(VF) || VF < 2 || VF < MinVF) {
+  if (!has_single_bit(Sz) ||
+      !hasFullVectorsOrPowerOf2(
+          *TTI, cast<StoreInst>(Chain.front())->getValueOperand()->getType(),
+          VF) ||
+      VF < 2 || VF < MinVF) {
     // Check if vectorizing with a non-power-of-2 VF should be considered. At
     // the moment, only consider cases where VF + 1 is a power-of-2, i.e. almost
     // all vector lanes are used.
@@ -17403,10 +17423,12 @@ SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R,
   InstructionsState S = getSameOpcode(ValOps.getArrayRef(), *TLI);
   if (all_of(ValOps, IsaPred<Instruction>) && ValOps.size() > 1) {
     DenseSet<Value *> Stores(Chain.begin(), Chain.end());
-    bool IsPowerOf2 =
-        has_single_bit(ValOps.size()) ||
+    bool IsAllowedSize =
+        hasFullVectorsOrPowerOf2(*TTI, ValOps.front()->getType(),
+                                 ValOps.size()) ||
         (VectorizeNonPowerOf2 && has_single_bit(ValOps.size() + 1));
-    if ((!IsPowerOf2 && S.getOpcode() && S.getOpcode() != Instruction::Load &&
+    if ((!IsAllowedSize && S.getOpcode() &&
+         S.getOpcode() != Instruction::Load &&
          (!S.MainOp->isSafeToRemove() ||
           any_of(ValOps.getArrayRef(),
                  [&](Value *V) {
@@ -17417,7 +17439,7 @@ SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R,
                            }));
                  }))) ||
         (ValOps.size() > Chain.size() / 2 && !S.getOpcode())) {
-      Size = (!IsPowerOf2 && S.getOpcode()) ? 1 : 2;
+      Size = (!IsAllowedSize && S.getOpcode()) ? 1 : 2;
       return false;
     }
   }
@@ -17545,15 +17567,11 @@ bool SLPVectorizerPass::vectorizeStores(
 
       unsigned MaxVF =
           std::min(R.getMaximumVF(EltSize, Instruction::Store), MaxElts);
-      unsigned MaxRegVF = MaxVF;
       auto *Store = cast<StoreInst>(Operands[0]);
       Type *StoreTy = Store->getValueOperand()->getType();
       Type *ValueTy = StoreTy;
       if (auto *Trunc = dyn_cast<TruncInst>(Store->getValueOperand()))
         ValueTy = Trunc->getSrcTy();
-      if (ValueTy == StoreTy &&
-          R.getVectorElementSize(Store->getValueOperand()) <= EltSize)
-        MaxVF = std::min<unsigned>(MaxVF, bit_floor(Operands.size()));
       unsigned MinVF = std::max<unsigned>(
           2, PowerOf2Ceil(TTI->getStoreMinimumVF(
                  R.getMinVF(DL->getTypeStoreSizeInBits(StoreTy)), StoreTy,
@@ -17571,10 +17589,21 @@ bool SLPVectorizerPass::vectorizeStores(
         // First try vectorizing with a non-power-of-2 VF. At the moment, only
         // consider cases where VF + 1 is a power-of-2, i.e. almost all vector
         // lanes are used.
-        unsigned CandVF =
-            std::clamp<unsigned>(Operands.size(), MaxVF, MaxRegVF);
-        if (has_single_bit(CandVF + 1))
+        unsigned CandVF = std::clamp<unsigned>(Operands.size(), MinVF, MaxVF);
+        if (has_single_bit(CandVF + 1)) {
           NonPowerOf2VF = CandVF;
+          assert(NonPowerOf2VF != MaxVF &&
+                 "Non-power-of-2 VF should not be equal to MaxVF");
+        }
+      }
+
+      unsigned MaxRegVF = MaxVF;
+      MaxVF = std::min<unsigned>(MaxVF, bit_floor(Operands.size()));
+      if (MaxVF < MinVF) {
+        LLVM_DEBUG(dbgs() << "SLP: Vectorization infeasible as MaxVF (" << MaxVF
+                          << ") < "
+                          << "MinVF (" << MinVF << ")\n");
+        continue;
       }
 
       unsigned Sz = 1 + Log2_32(MaxVF) - Log2_32(MinVF);
@@ -17742,7 +17771,7 @@ bool SLPVectorizerPass::vectorizeStores(
             (Repeat > 1 && (RepeatChanged || !AnyProfitableGraph)))
           break;
         constexpr unsigned StoresLimit = 64;
-        const unsigned MaxTotalNum = bit_floor(std::min<unsigned>(
+        const unsigned MaxTotalNum = std::min<unsigned>(
             Operands.size(),
             static_cast<unsigned>(
                 End -
@@ -17750,8 +17779,13 @@ bool SLPVectorizerPass::vectorizeStores(
                     RangeSizes.begin(),
                     find_if(RangeSizes, std::bind(IsNotVectorized, true,
                                                   std::placeholders::_1))) +
-                1)));
-        unsigned VF = PowerOf2Ceil(CandidateVFs.front()) * 2;
+                1));
+        unsigned VF = bit_ceil(CandidateVFs.front()) * 2;
+        unsigned Limit =
+            getFloorFullVectorNumberOfElements(*TTI, StoreTy, MaxTotalNum);
+        CandidateVFs.clear();
+        if (bit_floor(Limit) == VF)
+          CandidateVFs.push_back(Limit);
         if (VF > MaxTotalNum || VF >= StoresLimit)
           break;
         for_each(RangeSizes, [&](std::pair<unsigned, unsigned> &P) {
@@ -17760,7 +17794,6 @@ bool SLPVectorizerPass::vectorizeStores(
         });
         // Last attempt to vectorize max number of elements, if all previous
         // attempts were unsuccessful because of the cost issues.
-        CandidateVFs.clear();
         CandidateVFs.push_back(VF);
       }
     }
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/long-full-reg-stores.ll b/llvm/test/Transforms/SLPVectorizer/X86/long-full-reg-stores.ll
index b5f993f986c7cc..aff66dd7c10ea7 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/long-full-reg-stores.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/long-full-reg-stores.ll
@@ -4,30 +4,16 @@
 define void @test(ptr noalias %0, ptr noalias %1) {
 ; CHECK-LABEL: define void @test(
 ; CHECK-SAME: ptr noalias [[TMP0:%.*]], ptr noalias [[TMP1:%.*]]) {
-; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i64 24
-; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr i8, ptr [[TMP1]], i64 48
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr i8, ptr [[TMP1]], i64 8
-; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[TMP1]], i64 16
-; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP0]], i64 24
-; CHECK-NEXT:    [[TMP8:%.*]] = load double, ptr [[TMP7]], align 8
-; CHECK-NEXT:    store double [[TMP8]], ptr [[TMP5]], align 8
 ; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr i8, ptr [[TMP0]], i64 48
-; CHECK-NEXT:    [[TMP10:%.*]] = load double, ptr [[TMP9]], align 16
-; CHECK-NEXT:    store double [[TMP10]], ptr [[TMP6]], align 16
 ; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr i8, ptr [[TMP0]], i64 8
-; CHECK-NEXT:    [[TMP12:%.*]] = load double, ptr [[TMP11]], align 8
-; CHECK-NEXT:    store double [[TMP12]], ptr [[TMP3]], align 8
-; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr i8, ptr [[TMP0]], i64 32
-; CHECK-NEXT:    [[TMP14:%.*]] = load double, ptr [[TMP13]], align 16
-; CHECK-NEXT:    [[TMP15:%.*]] = getelementptr i8, ptr [[TMP1]], i64 32
-; CHECK-NEXT:    store double [[TMP14]], ptr [[TMP15]], align 16
-; CHECK-NEXT:    [[TMP16:%.*]] = getelementptr i8, ptr [[TMP0]], i64 56
-; CHECK-NEXT:    [[TMP17:%.*]] = load double, ptr [[TMP16]], align 8
-; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr i8, ptr [[TMP1]], i64 40
-; CHECK-NEXT:    store double [[TMP17]], ptr [[TMP18]], align 8
-; CHECK-NEXT:    [[TMP19:%.*]] = getelementptr i8, ptr [[TMP0]], i64 16
-; CHECK-NEXT:    [[TMP20:%.*]] = load double, ptr [[TMP19]], align 16
-; CHECK-NEXT:    store double [[TMP20]], ptr [[TMP4]], align 16
+; CHECK-NEXT:    [[TMP6:%.*]] = load <2 x double>, ptr [[TMP9]], align 16
+; CHECK-NEXT:    [[TMP7:%.*]] = load <4 x double>, ptr [[TMP11]], align 8
+; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <4 x double> [[TMP7]], <4 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP12:%.*]] = shufflevector <2 x double> [[TMP6]], <2 x double> poison, <6 x i32> <i32 0, i32 1, i32 poison, i32 poison, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <2 x double> [[TMP6]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP13:%.*]] = shufflevector <4 x double> [[TMP7]], <4 x double> [[TMP10]], <6 x i32> <i32 2, i32 4, i32 0, i32 3, i32 5, i32 1>
+; CHECK-NEXT:    store <6 x double> [[TMP13]], ptr [[TMP5]], align 8
 ; CHECK-NEXT:    [[TMP21:%.*]] = getelementptr i8, ptr [[TMP0]], i64 40
 ; CHECK-NEXT:    [[TMP22:%.*]] = load double, ptr [[TMP21]], align 8
 ; CHECK-NEXT:    [[TMP23:%.*]] = getelementptr i8, ptr [[TMP1]], i64 56

@alexey-bataev alexey-bataev requested a review from RKSimon October 4, 2024 18:41
Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Created using spr 1.3.5
@alexey-bataev alexey-bataev merged commit f020bf1 into main Oct 9, 2024
5 of 6 checks passed
@alexey-bataev alexey-bataev deleted the users/alexey-bataev/spr/slpinitial-support-for-non-power-of-2-but-whole-reg-vectorization-for-stores branch October 9, 2024 19:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants