Diffstat (limited to 'contrib/llvm-project/llvm/lib/CodeGen/ExpandVectorPredication.cpp')
-rw-r--r--  contrib/llvm-project/llvm/lib/CodeGen/ExpandVectorPredication.cpp  87
1 file changed, 85 insertions(+), 2 deletions(-)
diff --git a/contrib/llvm-project/llvm/lib/CodeGen/ExpandVectorPredication.cpp b/contrib/llvm-project/llvm/lib/CodeGen/ExpandVectorPredication.cpp
index 59932a542bbc..db4d42bf3ca4 100644
--- a/contrib/llvm-project/llvm/lib/CodeGen/ExpandVectorPredication.cpp
+++ b/contrib/llvm-project/llvm/lib/CodeGen/ExpandVectorPredication.cpp
@@ -15,6 +15,7 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
@@ -82,8 +83,11 @@ STATISTIC(NumLoweredVPOps, "Number of folded vector predication operations");
/// \returns Whether the vector mask \p MaskVal has all lane bits set.
static bool isAllTrueMask(Value *MaskVal) {
- auto *ConstVec = dyn_cast<ConstantVector>(MaskVal);
- return ConstVec && ConstVec->isAllOnesValue();
+ if (Value *SplattedVal = getSplatValue(MaskVal))
+ if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
+ return ConstValue->isAllOnesValue();
+
+ return false;
}
/// \returns A non-excepting divisor constant for this type.
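Context for the hunk above, not part of the diff: a scalable-vector all-true mask can never be a ConstantVector (that class is fixed-width only), and fixed-width splats may also appear as insertelement/shufflevector idioms, so the old dyn_cast<ConstantVector> check missed both. getSplatValue, declared in the newly included llvm/Analysis/VectorUtils.h, recovers the splatted scalar in either case. A minimal sketch of the broadened check against the LLVM headers; the demo helper is hypothetical:

// Sketch only; mirrors the new isAllTrueMask. The demo helper below is
// hypothetical and not part of the patch.
#include "llvm/Analysis/VectorUtils.h" // getSplatValue
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"

using namespace llvm;

static bool isSplatOfAllOnes(Value *MaskVal) {
  // getSplatValue sees through ConstantVector/ConstantDataVector splats as
  // well as the insertelement+shufflevector splat idiom used for scalable
  // vectors, which ConstantVector (fixed-width only) can never represent.
  if (Value *Splat = getSplatValue(MaskVal))
    if (auto *C = dyn_cast<Constant>(Splat))
      return C->isAllOnesValue(); // i1 true is the all-ones value
  return false;
}

// Hypothetical demo: a <vscale x 4 x i1> all-true mask built with IRBuilder is
// accepted by the splat-based check even though it is not a ConstantVector.
static bool demoScalableMask(IRBuilder<> &Builder) {
  Value *Mask = Builder.CreateVectorSplat(ElementCount::getScalable(4),
                                          Builder.getTrue());
  return isSplatOfAllOnes(Mask);
}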
@@ -171,6 +175,10 @@ struct CachingVPExpander {
Value *expandPredicationInReduction(IRBuilder<> &Builder,
VPReductionIntrinsic &PI);
+ /// \brief Lower this VP memory operation to a non-VP intrinsic.
+ Value *expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,
+ VPIntrinsic &VPI);
+
/// \brief Query TTI and expand the vector predication in \p P accordingly.
Value *expandPredication(VPIntrinsic &PI);
@@ -389,6 +397,71 @@ CachingVPExpander::expandPredicationInReduction(IRBuilder<> &Builder,
return Reduction;
}
+Value *
+CachingVPExpander::expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,
+ VPIntrinsic &VPI) {
+ assert(VPI.canIgnoreVectorLengthParam());
+
+ const auto &DL = F.getParent()->getDataLayout();
+
+ Value *MaskParam = VPI.getMaskParam();
+ Value *PtrParam = VPI.getMemoryPointerParam();
+ Value *DataParam = VPI.getMemoryDataParam();
+ bool IsUnmasked = isAllTrueMask(MaskParam);
+
+ MaybeAlign AlignOpt = VPI.getPointerAlignment();
+
+ Value *NewMemoryInst = nullptr;
+ switch (VPI.getIntrinsicID()) {
+ default:
+ llvm_unreachable("Not a VP memory intrinsic");
+ case Intrinsic::vp_store:
+ if (IsUnmasked) {
+ StoreInst *NewStore =
+ Builder.CreateStore(DataParam, PtrParam, /*IsVolatile*/ false);
+ if (AlignOpt.has_value())
+ NewStore->setAlignment(AlignOpt.value());
+ NewMemoryInst = NewStore;
+ } else
+ NewMemoryInst = Builder.CreateMaskedStore(
+ DataParam, PtrParam, AlignOpt.valueOrOne(), MaskParam);
+
+ break;
+ case Intrinsic::vp_load:
+ if (IsUnmasked) {
+ LoadInst *NewLoad =
+ Builder.CreateLoad(VPI.getType(), PtrParam, /*IsVolatile*/ false);
+ if (AlignOpt.has_value())
+ NewLoad->setAlignment(AlignOpt.value());
+ NewMemoryInst = NewLoad;
+ } else
+ NewMemoryInst = Builder.CreateMaskedLoad(
+ VPI.getType(), PtrParam, AlignOpt.valueOrOne(), MaskParam);
+
+ break;
+ case Intrinsic::vp_scatter: {
+ auto *ElementType =
+ cast<VectorType>(DataParam->getType())->getElementType();
+ NewMemoryInst = Builder.CreateMaskedScatter(
+ DataParam, PtrParam,
+ AlignOpt.value_or(DL.getPrefTypeAlign(ElementType)), MaskParam);
+ break;
+ }
+ case Intrinsic::vp_gather: {
+ auto *ElementType = cast<VectorType>(VPI.getType())->getElementType();
+ NewMemoryInst = Builder.CreateMaskedGather(
+ VPI.getType(), PtrParam,
+ AlignOpt.value_or(DL.getPrefTypeAlign(ElementType)), MaskParam, nullptr,
+ VPI.getName());
+ break;
+ }
+ }
+
+ assert(NewMemoryInst);
+ replaceOperation(*NewMemoryInst, VPI);
+ return NewMemoryInst;
+}
+
void CachingVPExpander::discardEVLParameter(VPIntrinsic &VPI) {
LLVM_DEBUG(dbgs() << "Discard EVL parameter in " << VPI << "\n");
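A note on the alignment handling in expandPredicationInMemoryIntrinsic above, not part of the diff: the unmasked load/store path only sets an explicit alignment when the vp.* call carried an align attribute, the masked load/store path falls back to align 1 via MaybeAlign::valueOrOne(), and gather/scatter default to the DataLayout's preferred alignment for one element. A minimal sketch of that last default against the LLVM headers; the helper name is hypothetical:

// Sketch only: how the vp.gather/vp.scatter cases above choose an alignment
// when the call carries no align attribute. The helper name is hypothetical.
#include "llvm/IR/DataLayout.h"
#include "llvm/Support/Alignment.h"

using namespace llvm;

static Align gatherScatterAlign(MaybeAlign AlignOpt, const DataLayout &DL,
                                Type *ElementTy) {
  // Keep an explicit align when present; otherwise use the target's preferred
  // alignment for a single vector element, as the new code does.
  return AlignOpt.value_or(DL.getPrefTypeAlign(ElementTy));
}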
@@ -465,6 +538,16 @@ Value *CachingVPExpander::expandPredication(VPIntrinsic &VPI) {
if (auto *VPRI = dyn_cast<VPReductionIntrinsic>(&VPI))
return expandPredicationInReduction(Builder, *VPRI);
+ switch (VPI.getIntrinsicID()) {
+ default:
+ break;
+ case Intrinsic::vp_load:
+ case Intrinsic::vp_store:
+ case Intrinsic::vp_gather:
+ case Intrinsic::vp_scatter:
+ return expandPredicationInMemoryIntrinsic(Builder, VPI);
+ }
+
return &VPI;
}
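Overall effect of the new dispatch, not part of the diff: the four VP memory intrinsics are now expanded here instead of falling through to the default return &VPI;, and only vp.load/vp.store with an all-true mask (and an ignorable EVL, which the lowering asserts) become plain loads and stores; everything else becomes a masked.* intrinsic. A minimal sketch of that decision as a standalone predicate, using the VPIntrinsic accessors from the code above; the helper name is hypothetical:

// Sketch only, not part of the patch: when the new lowering emits a plain
// (unmasked) load or store rather than a masked.* intrinsic. The helper name
// is hypothetical; the VPIntrinsic accessors are those used above.
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IntrinsicInst.h"

using namespace llvm;

static bool lowersToPlainLoadStore(VPIntrinsic &VPI) {
  // Only vp.load/vp.store have an unmasked fast path; vp.gather/vp.scatter
  // always map to llvm.masked.gather/llvm.masked.scatter.
  Intrinsic::ID ID = VPI.getIntrinsicID();
  if (ID != Intrinsic::vp_load && ID != Intrinsic::vp_store)
    return false;
  // The expansion asserts this precondition: the EVL either is missing or
  // covers the whole vector, so it can be ignored.
  if (!VPI.canIgnoreVectorLengthParam())
    return false;
  // All-true splat mask, as recognized by the updated isAllTrueMask.
  Value *Mask = VPI.getMaskParam();
  if (!Mask)
    return false;
  if (Value *Splat = getSplatValue(Mask))
    if (auto *C = dyn_cast<Constant>(Splat))
      return C->isAllOnesValue();
  return false;
}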