aboutsummaryrefslogtreecommitdiff
path: root/llvm/include/llvm/IR/PatternMatch.h
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/include/llvm/IR/PatternMatch.h')
-rw-r--r--llvm/include/llvm/IR/PatternMatch.h129
1 files changed, 103 insertions, 26 deletions
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 166ad23de969..cbd429f84ee4 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -88,8 +88,52 @@ inline class_match<BinaryOperator> m_BinOp() {
/// Matches any compare instruction and ignore it.
inline class_match<CmpInst> m_Cmp() { return class_match<CmpInst>(); }
-/// Match an arbitrary undef constant.
-inline class_match<UndefValue> m_Undef() { return class_match<UndefValue>(); }
+struct undef_match {
+ static bool check(const Value *V) {
+ if (isa<UndefValue>(V))
+ return true;
+
+ const auto *CA = dyn_cast<ConstantAggregate>(V);
+ if (!CA)
+ return false;
+
+ SmallPtrSet<const ConstantAggregate *, 8> Seen;
+ SmallVector<const ConstantAggregate *, 8> Worklist;
+
+ // Either UndefValue, PoisonValue, or an aggregate that only contains
+ // these is accepted by matcher.
+ // CheckValue returns false if CA cannot satisfy this constraint.
+ auto CheckValue = [&](const ConstantAggregate *CA) {
+ for (const Value *Op : CA->operand_values()) {
+ if (isa<UndefValue>(Op))
+ continue;
+
+ const auto *CA = dyn_cast<ConstantAggregate>(Op);
+ if (!CA)
+ return false;
+ if (Seen.insert(CA).second)
+ Worklist.emplace_back(CA);
+ }
+
+ return true;
+ };
+
+ if (!CheckValue(CA))
+ return false;
+
+ while (!Worklist.empty()) {
+ if (!CheckValue(Worklist.pop_back_val()))
+ return false;
+ }
+ return true;
+ }
+ template <typename ITy> bool match(ITy *V) { return check(V); }
+};
+
+/// Match an arbitrary undef constant. This matches poison as well.
+/// If this is an aggregate and contains a non-aggregate element that is
+/// neither undef nor poison, the aggregate is not matched.
+inline auto m_Undef() { return undef_match(); }
/// Match an arbitrary poison constant.
inline class_match<PoisonValue> m_Poison() { return class_match<PoisonValue>(); }
@@ -708,6 +752,10 @@ inline bind_ty<UnaryOperator> m_UnOp(UnaryOperator *&I) { return I; }
inline bind_ty<BinaryOperator> m_BinOp(BinaryOperator *&I) { return I; }
/// Match a with overflow intrinsic, capturing it if we match.
inline bind_ty<WithOverflowInst> m_WithOverflowInst(WithOverflowInst *&I) { return I; }
+inline bind_ty<const WithOverflowInst>
+m_WithOverflowInst(const WithOverflowInst *&I) {
+ return I;
+}
/// Match a Constant, capturing the value if we match.
inline bind_ty<Constant> m_Constant(Constant *&C) { return C; }
@@ -763,7 +811,12 @@ template <typename Class> struct deferredval_ty {
template <typename ITy> bool match(ITy *const V) { return V == Val; }
};
-/// A commutative-friendly version of m_Specific().
+/// Like m_Specific(), but works if the specific value to match is determined
+/// as part of the same match() expression. For example:
+/// m_Add(m_Value(X), m_Specific(X)) is incorrect, because m_Specific() will
+/// bind X before the pattern match starts.
+/// m_Add(m_Value(X), m_Deferred(X)) is correct, and will check against
+/// whichever value m_Value(X) populated.
inline deferredval_ty<Value> m_Deferred(Value *const &V) { return V; }
inline deferredval_ty<const Value> m_Deferred(const Value *const &V) {
return V;
@@ -1115,10 +1168,10 @@ struct OverflowingBinaryOp_match {
if (auto *Op = dyn_cast<OverflowingBinaryOperator>(V)) {
if (Op->getOpcode() != Opcode)
return false;
- if (WrapFlags & OverflowingBinaryOperator::NoUnsignedWrap &&
+ if ((WrapFlags & OverflowingBinaryOperator::NoUnsignedWrap) &&
!Op->hasNoUnsignedWrap())
return false;
- if (WrapFlags & OverflowingBinaryOperator::NoSignedWrap &&
+ if ((WrapFlags & OverflowingBinaryOperator::NoSignedWrap) &&
!Op->hasNoSignedWrap())
return false;
return L.match(Op->getOperand(0)) && R.match(Op->getOperand(1));
@@ -1703,6 +1756,7 @@ m_Br(const Cond_t &C, const TrueBlock_t &T, const FalseBlock_t &F) {
template <typename CmpInst_t, typename LHS_t, typename RHS_t, typename Pred_t,
bool Commutable = false>
struct MaxMin_match {
+ using PredType = Pred_t;
LHS_t L;
RHS_t R;
@@ -1731,10 +1785,10 @@ struct MaxMin_match {
return false;
// At this point we have a select conditioned on a comparison. Check that
// it is the values returned by the select that are being compared.
- Value *TrueVal = SI->getTrueValue();
- Value *FalseVal = SI->getFalseValue();
- Value *LHS = Cmp->getOperand(0);
- Value *RHS = Cmp->getOperand(1);
+ auto *TrueVal = SI->getTrueValue();
+ auto *FalseVal = SI->getFalseValue();
+ auto *LHS = Cmp->getOperand(0);
+ auto *RHS = Cmp->getOperand(1);
if ((TrueVal != LHS || FalseVal != RHS) &&
(TrueVal != RHS || FalseVal != LHS))
return false;
@@ -2055,6 +2109,14 @@ template <Intrinsic::ID IntrID> inline IntrinsicID_match m_Intrinsic() {
return IntrinsicID_match(IntrID);
}
+/// Matches MaskedLoad Intrinsic.
+template <typename Opnd0, typename Opnd1, typename Opnd2, typename Opnd3>
+inline typename m_Intrinsic_Ty<Opnd0, Opnd1, Opnd2, Opnd3>::Ty
+m_MaskedLoad(const Opnd0 &Op0, const Opnd1 &Op1, const Opnd2 &Op2,
+ const Opnd3 &Op3) {
+ return m_Intrinsic<Intrinsic::masked_load>(Op0, Op1, Op2, Op3);
+}
+
template <Intrinsic::ID IntrID, typename T0>
inline typename m_Intrinsic_Ty<T0>::Ty m_Intrinsic(const T0 &Op0) {
return m_CombineAnd(m_Intrinsic<IntrID>(), m_Argument<0>(Op0));
@@ -2314,9 +2376,13 @@ template <int Ind, typename Opnd_t> struct ExtractValue_match {
ExtractValue_match(const Opnd_t &V) : Val(V) {}
template <typename OpTy> bool match(OpTy *V) {
- if (auto *I = dyn_cast<ExtractValueInst>(V))
- return I->getNumIndices() == 1 && I->getIndices()[0] == Ind &&
- Val.match(I->getAggregateOperand());
+ if (auto *I = dyn_cast<ExtractValueInst>(V)) {
+ // If Ind is -1, don't inspect indices
+ if (Ind != -1 &&
+ !(I->getNumIndices() == 1 && I->getIndices()[0] == (unsigned)Ind))
+ return false;
+ return Val.match(I->getAggregateOperand());
+ }
return false;
}
};
@@ -2328,6 +2394,13 @@ inline ExtractValue_match<Ind, Val_t> m_ExtractValue(const Val_t &V) {
return ExtractValue_match<Ind, Val_t>(V);
}
+/// Match an ExtractValue instruction with any index.
+/// For example m_ExtractValue(...)
+template <typename Val_t>
+inline ExtractValue_match<-1, Val_t> m_ExtractValue(const Val_t &V) {
+ return ExtractValue_match<-1, Val_t>(V);
+}
+
/// Matcher for a single index InsertValue instruction.
template <int Ind, typename T0, typename T1> struct InsertValue_match {
T0 Op0;
@@ -2356,14 +2429,6 @@ inline InsertValue_match<Ind, Val_t, Elt_t> m_InsertValue(const Val_t &Val,
/// `ptrtoint(gep <vscale x 1 x i8>, <vscale x 1 x i8>* null, i32 1>`
/// under the right conditions determined by DataLayout.
struct VScaleVal_match {
-private:
- template <typename Base, typename Offset>
- inline BinaryOp_match<Base, Offset, Instruction::GetElementPtr>
- m_OffsetGep(const Base &B, const Offset &O) {
- return BinaryOp_match<Base, Offset, Instruction::GetElementPtr>(B, O);
- }
-
-public:
const DataLayout &DL;
VScaleVal_match(const DataLayout &DL) : DL(DL) {}
@@ -2371,12 +2436,16 @@ public:
if (m_Intrinsic<Intrinsic::vscale>().match(V))
return true;
- if (m_PtrToInt(m_OffsetGep(m_Zero(), m_SpecificInt(1))).match(V)) {
- Type *PtrTy = cast<Operator>(V)->getOperand(0)->getType();
- auto *DerefTy = PtrTy->getPointerElementType();
- if (isa<ScalableVectorType>(DerefTy) &&
- DL.getTypeAllocSizeInBits(DerefTy).getKnownMinSize() == 8)
- return true;
+ Value *Ptr;
+ if (m_PtrToInt(m_Value(Ptr)).match(V)) {
+ if (auto *GEP = dyn_cast<GEPOperator>(Ptr)) {
+ auto *DerefTy = GEP->getSourceElementType();
+ if (GEP->getNumIndices() == 1 && isa<ScalableVectorType>(DerefTy) &&
+ m_Zero().match(GEP->getPointerOperand()) &&
+ m_SpecificInt(1).match(GEP->idx_begin()->get()) &&
+ DL.getTypeAllocSizeInBits(DerefTy).getKnownMinSize() == 8)
+ return true;
+ }
}
return false;
@@ -2431,6 +2500,9 @@ m_LogicalAnd(const LHS &L, const RHS &R) {
return LogicalOp_match<LHS, RHS, Instruction::And>(L, R);
}
+/// Matches L && R where L and R are arbitrary values.
+inline auto m_LogicalAnd() { return m_LogicalAnd(m_Value(), m_Value()); }
+
/// Matches L || R either in the form of L | R or L ? true : R.
/// Note that the latter form is poison-blocking.
template <typename LHS, typename RHS>
@@ -2439,6 +2511,11 @@ m_LogicalOr(const LHS &L, const RHS &R) {
return LogicalOp_match<LHS, RHS, Instruction::Or>(L, R);
}
+/// Matches L || R where L and R are arbitrary values.
+inline auto m_LogicalOr() {
+ return m_LogicalOr(m_Value(), m_Value());
+}
+
} // end namespace PatternMatch
} // end namespace llvm