diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/CodeGen/ExpandVectorPredication.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/CodeGen/ExpandVectorPredication.cpp | 87 |
1 files changed, 85 insertions, 2 deletions
diff --git a/contrib/llvm-project/llvm/lib/CodeGen/ExpandVectorPredication.cpp b/contrib/llvm-project/llvm/lib/CodeGen/ExpandVectorPredication.cpp index 59932a542bbc..db4d42bf3ca4 100644 --- a/contrib/llvm-project/llvm/lib/CodeGen/ExpandVectorPredication.cpp +++ b/contrib/llvm-project/llvm/lib/CodeGen/ExpandVectorPredication.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/CodeGen/Passes.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" @@ -82,8 +83,11 @@ STATISTIC(NumLoweredVPOps, "Number of folded vector predication operations"); /// \returns Whether the vector mask \p MaskVal has all lane bits set. static bool isAllTrueMask(Value *MaskVal) { - auto *ConstVec = dyn_cast<ConstantVector>(MaskVal); - return ConstVec && ConstVec->isAllOnesValue(); + if (Value *SplattedVal = getSplatValue(MaskVal)) + if (auto *ConstValue = dyn_cast<Constant>(SplattedVal)) + return ConstValue->isAllOnesValue(); + + return false; } /// \returns A non-excepting divisor constant for this type. @@ -171,6 +175,10 @@ struct CachingVPExpander { Value *expandPredicationInReduction(IRBuilder<> &Builder, VPReductionIntrinsic &PI); + /// \brief Lower this VP memory operation to a non-VP intrinsic. + Value *expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder, + VPIntrinsic &VPI); + /// \brief Query TTI and expand the vector predication in \p P accordingly. Value *expandPredication(VPIntrinsic &PI); @@ -389,6 +397,71 @@ CachingVPExpander::expandPredicationInReduction(IRBuilder<> &Builder, return Reduction; } +Value * +CachingVPExpander::expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder, + VPIntrinsic &VPI) { + assert(VPI.canIgnoreVectorLengthParam()); + + const auto &DL = F.getParent()->getDataLayout(); + + Value *MaskParam = VPI.getMaskParam(); + Value *PtrParam = VPI.getMemoryPointerParam(); + Value *DataParam = VPI.getMemoryDataParam(); + bool IsUnmasked = isAllTrueMask(MaskParam); + + MaybeAlign AlignOpt = VPI.getPointerAlignment(); + + Value *NewMemoryInst = nullptr; + switch (VPI.getIntrinsicID()) { + default: + llvm_unreachable("Not a VP memory intrinsic"); + case Intrinsic::vp_store: + if (IsUnmasked) { + StoreInst *NewStore = + Builder.CreateStore(DataParam, PtrParam, /*IsVolatile*/ false); + if (AlignOpt.has_value()) + NewStore->setAlignment(AlignOpt.value()); + NewMemoryInst = NewStore; + } else + NewMemoryInst = Builder.CreateMaskedStore( + DataParam, PtrParam, AlignOpt.valueOrOne(), MaskParam); + + break; + case Intrinsic::vp_load: + if (IsUnmasked) { + LoadInst *NewLoad = + Builder.CreateLoad(VPI.getType(), PtrParam, /*IsVolatile*/ false); + if (AlignOpt.has_value()) + NewLoad->setAlignment(AlignOpt.value()); + NewMemoryInst = NewLoad; + } else + NewMemoryInst = Builder.CreateMaskedLoad( + VPI.getType(), PtrParam, AlignOpt.valueOrOne(), MaskParam); + + break; + case Intrinsic::vp_scatter: { + auto *ElementType = + cast<VectorType>(DataParam->getType())->getElementType(); + NewMemoryInst = Builder.CreateMaskedScatter( + DataParam, PtrParam, + AlignOpt.value_or(DL.getPrefTypeAlign(ElementType)), MaskParam); + break; + } + case Intrinsic::vp_gather: { + auto *ElementType = cast<VectorType>(VPI.getType())->getElementType(); + NewMemoryInst = Builder.CreateMaskedGather( + VPI.getType(), PtrParam, + AlignOpt.value_or(DL.getPrefTypeAlign(ElementType)), MaskParam, nullptr, + VPI.getName()); + break; + } + } + + assert(NewMemoryInst); + replaceOperation(*NewMemoryInst, VPI); + return NewMemoryInst; +} + void CachingVPExpander::discardEVLParameter(VPIntrinsic &VPI) { LLVM_DEBUG(dbgs() << "Discard EVL parameter in " << VPI << "\n"); @@ -465,6 +538,16 @@ Value *CachingVPExpander::expandPredication(VPIntrinsic &VPI) { if (auto *VPRI = dyn_cast<VPReductionIntrinsic>(&VPI)) return expandPredicationInReduction(Builder, *VPRI); + switch (VPI.getIntrinsicID()) { + default: + break; + case Intrinsic::vp_load: + case Intrinsic::vp_store: + case Intrinsic::vp_gather: + case Intrinsic::vp_scatter: + return expandPredicationInMemoryIntrinsic(Builder, VPI); + } + return &VPI; } |