diff options
Diffstat (limited to 'llvm/include/llvm/IR/PatternMatch.h')
-rw-r--r-- | llvm/include/llvm/IR/PatternMatch.h | 129 |
1 files changed, 103 insertions, 26 deletions
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h index 166ad23de969..cbd429f84ee4 100644 --- a/llvm/include/llvm/IR/PatternMatch.h +++ b/llvm/include/llvm/IR/PatternMatch.h @@ -88,8 +88,52 @@ inline class_match<BinaryOperator> m_BinOp() { /// Matches any compare instruction and ignore it. inline class_match<CmpInst> m_Cmp() { return class_match<CmpInst>(); } -/// Match an arbitrary undef constant. -inline class_match<UndefValue> m_Undef() { return class_match<UndefValue>(); } +struct undef_match { + static bool check(const Value *V) { + if (isa<UndefValue>(V)) + return true; + + const auto *CA = dyn_cast<ConstantAggregate>(V); + if (!CA) + return false; + + SmallPtrSet<const ConstantAggregate *, 8> Seen; + SmallVector<const ConstantAggregate *, 8> Worklist; + + // Either UndefValue, PoisonValue, or an aggregate that only contains + // these is accepted by matcher. + // CheckValue returns false if CA cannot satisfy this constraint. + auto CheckValue = [&](const ConstantAggregate *CA) { + for (const Value *Op : CA->operand_values()) { + if (isa<UndefValue>(Op)) + continue; + + const auto *CA = dyn_cast<ConstantAggregate>(Op); + if (!CA) + return false; + if (Seen.insert(CA).second) + Worklist.emplace_back(CA); + } + + return true; + }; + + if (!CheckValue(CA)) + return false; + + while (!Worklist.empty()) { + if (!CheckValue(Worklist.pop_back_val())) + return false; + } + return true; + } + template <typename ITy> bool match(ITy *V) { return check(V); } +}; + +/// Match an arbitrary undef constant. This matches poison as well. +/// If this is an aggregate and contains a non-aggregate element that is +/// neither undef nor poison, the aggregate is not matched. +inline auto m_Undef() { return undef_match(); } /// Match an arbitrary poison constant. inline class_match<PoisonValue> m_Poison() { return class_match<PoisonValue>(); } @@ -708,6 +752,10 @@ inline bind_ty<UnaryOperator> m_UnOp(UnaryOperator *&I) { return I; } inline bind_ty<BinaryOperator> m_BinOp(BinaryOperator *&I) { return I; } /// Match a with overflow intrinsic, capturing it if we match. inline bind_ty<WithOverflowInst> m_WithOverflowInst(WithOverflowInst *&I) { return I; } +inline bind_ty<const WithOverflowInst> +m_WithOverflowInst(const WithOverflowInst *&I) { + return I; +} /// Match a Constant, capturing the value if we match. inline bind_ty<Constant> m_Constant(Constant *&C) { return C; } @@ -763,7 +811,12 @@ template <typename Class> struct deferredval_ty { template <typename ITy> bool match(ITy *const V) { return V == Val; } }; -/// A commutative-friendly version of m_Specific(). +/// Like m_Specific(), but works if the specific value to match is determined +/// as part of the same match() expression. For example: +/// m_Add(m_Value(X), m_Specific(X)) is incorrect, because m_Specific() will +/// bind X before the pattern match starts. +/// m_Add(m_Value(X), m_Deferred(X)) is correct, and will check against +/// whichever value m_Value(X) populated. inline deferredval_ty<Value> m_Deferred(Value *const &V) { return V; } inline deferredval_ty<const Value> m_Deferred(const Value *const &V) { return V; @@ -1115,10 +1168,10 @@ struct OverflowingBinaryOp_match { if (auto *Op = dyn_cast<OverflowingBinaryOperator>(V)) { if (Op->getOpcode() != Opcode) return false; - if (WrapFlags & OverflowingBinaryOperator::NoUnsignedWrap && + if ((WrapFlags & OverflowingBinaryOperator::NoUnsignedWrap) && !Op->hasNoUnsignedWrap()) return false; - if (WrapFlags & OverflowingBinaryOperator::NoSignedWrap && + if ((WrapFlags & OverflowingBinaryOperator::NoSignedWrap) && !Op->hasNoSignedWrap()) return false; return L.match(Op->getOperand(0)) && R.match(Op->getOperand(1)); @@ -1703,6 +1756,7 @@ m_Br(const Cond_t &C, const TrueBlock_t &T, const FalseBlock_t &F) { template <typename CmpInst_t, typename LHS_t, typename RHS_t, typename Pred_t, bool Commutable = false> struct MaxMin_match { + using PredType = Pred_t; LHS_t L; RHS_t R; @@ -1731,10 +1785,10 @@ struct MaxMin_match { return false; // At this point we have a select conditioned on a comparison. Check that // it is the values returned by the select that are being compared. - Value *TrueVal = SI->getTrueValue(); - Value *FalseVal = SI->getFalseValue(); - Value *LHS = Cmp->getOperand(0); - Value *RHS = Cmp->getOperand(1); + auto *TrueVal = SI->getTrueValue(); + auto *FalseVal = SI->getFalseValue(); + auto *LHS = Cmp->getOperand(0); + auto *RHS = Cmp->getOperand(1); if ((TrueVal != LHS || FalseVal != RHS) && (TrueVal != RHS || FalseVal != LHS)) return false; @@ -2055,6 +2109,14 @@ template <Intrinsic::ID IntrID> inline IntrinsicID_match m_Intrinsic() { return IntrinsicID_match(IntrID); } +/// Matches MaskedLoad Intrinsic. +template <typename Opnd0, typename Opnd1, typename Opnd2, typename Opnd3> +inline typename m_Intrinsic_Ty<Opnd0, Opnd1, Opnd2, Opnd3>::Ty +m_MaskedLoad(const Opnd0 &Op0, const Opnd1 &Op1, const Opnd2 &Op2, + const Opnd3 &Op3) { + return m_Intrinsic<Intrinsic::masked_load>(Op0, Op1, Op2, Op3); +} + template <Intrinsic::ID IntrID, typename T0> inline typename m_Intrinsic_Ty<T0>::Ty m_Intrinsic(const T0 &Op0) { return m_CombineAnd(m_Intrinsic<IntrID>(), m_Argument<0>(Op0)); @@ -2314,9 +2376,13 @@ template <int Ind, typename Opnd_t> struct ExtractValue_match { ExtractValue_match(const Opnd_t &V) : Val(V) {} template <typename OpTy> bool match(OpTy *V) { - if (auto *I = dyn_cast<ExtractValueInst>(V)) - return I->getNumIndices() == 1 && I->getIndices()[0] == Ind && - Val.match(I->getAggregateOperand()); + if (auto *I = dyn_cast<ExtractValueInst>(V)) { + // If Ind is -1, don't inspect indices + if (Ind != -1 && + !(I->getNumIndices() == 1 && I->getIndices()[0] == (unsigned)Ind)) + return false; + return Val.match(I->getAggregateOperand()); + } return false; } }; @@ -2328,6 +2394,13 @@ inline ExtractValue_match<Ind, Val_t> m_ExtractValue(const Val_t &V) { return ExtractValue_match<Ind, Val_t>(V); } +/// Match an ExtractValue instruction with any index. +/// For example m_ExtractValue(...) +template <typename Val_t> +inline ExtractValue_match<-1, Val_t> m_ExtractValue(const Val_t &V) { + return ExtractValue_match<-1, Val_t>(V); +} + /// Matcher for a single index InsertValue instruction. template <int Ind, typename T0, typename T1> struct InsertValue_match { T0 Op0; @@ -2356,14 +2429,6 @@ inline InsertValue_match<Ind, Val_t, Elt_t> m_InsertValue(const Val_t &Val, /// `ptrtoint(gep <vscale x 1 x i8>, <vscale x 1 x i8>* null, i32 1>` /// under the right conditions determined by DataLayout. struct VScaleVal_match { -private: - template <typename Base, typename Offset> - inline BinaryOp_match<Base, Offset, Instruction::GetElementPtr> - m_OffsetGep(const Base &B, const Offset &O) { - return BinaryOp_match<Base, Offset, Instruction::GetElementPtr>(B, O); - } - -public: const DataLayout &DL; VScaleVal_match(const DataLayout &DL) : DL(DL) {} @@ -2371,12 +2436,16 @@ public: if (m_Intrinsic<Intrinsic::vscale>().match(V)) return true; - if (m_PtrToInt(m_OffsetGep(m_Zero(), m_SpecificInt(1))).match(V)) { - Type *PtrTy = cast<Operator>(V)->getOperand(0)->getType(); - auto *DerefTy = PtrTy->getPointerElementType(); - if (isa<ScalableVectorType>(DerefTy) && - DL.getTypeAllocSizeInBits(DerefTy).getKnownMinSize() == 8) - return true; + Value *Ptr; + if (m_PtrToInt(m_Value(Ptr)).match(V)) { + if (auto *GEP = dyn_cast<GEPOperator>(Ptr)) { + auto *DerefTy = GEP->getSourceElementType(); + if (GEP->getNumIndices() == 1 && isa<ScalableVectorType>(DerefTy) && + m_Zero().match(GEP->getPointerOperand()) && + m_SpecificInt(1).match(GEP->idx_begin()->get()) && + DL.getTypeAllocSizeInBits(DerefTy).getKnownMinSize() == 8) + return true; + } } return false; @@ -2431,6 +2500,9 @@ m_LogicalAnd(const LHS &L, const RHS &R) { return LogicalOp_match<LHS, RHS, Instruction::And>(L, R); } +/// Matches L && R where L and R are arbitrary values. +inline auto m_LogicalAnd() { return m_LogicalAnd(m_Value(), m_Value()); } + /// Matches L || R either in the form of L | R or L ? true : R. /// Note that the latter form is poison-blocking. template <typename LHS, typename RHS> @@ -2439,6 +2511,11 @@ m_LogicalOr(const LHS &L, const RHS &R) { return LogicalOp_match<LHS, RHS, Instruction::Or>(L, R); } +/// Matches L || R where L and R are arbitrary values. +inline auto m_LogicalOr() { + return m_LogicalOr(m_Value(), m_Value()); +} + } // end namespace PatternMatch } // end namespace llvm |