aboutsummaryrefslogtreecommitdiff
path: root/lib/Transforms/InstCombine/InstCombineSelect.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineSelect.cpp')
-rw-r--r--lib/Transforms/InstCombine/InstCombineSelect.cpp379
1 files changed, 216 insertions, 163 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 51219bcb0b7b..d7eed790e2ab 100644
--- a/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -116,25 +116,41 @@ static Constant *GetSelectFoldableConstant(Instruction *I) {
}
}
-/// Here we have (select c, TI, FI), and we know that TI and FI
-/// have the same opcode and only one use each. Try to simplify this.
+/// We have (select c, TI, FI), and we know that TI and FI have the same opcode.
Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI,
Instruction *FI) {
- if (TI->getNumOperands() == 1) {
- // If this is a non-volatile load or a cast from the same type,
- // merge.
- if (TI->isCast()) {
- Type *FIOpndTy = FI->getOperand(0)->getType();
- if (TI->getOperand(0)->getType() != FIOpndTy)
+ // If this is a cast from the same type, merge.
+ if (TI->getNumOperands() == 1 && TI->isCast()) {
+ Type *FIOpndTy = FI->getOperand(0)->getType();
+ if (TI->getOperand(0)->getType() != FIOpndTy)
+ return nullptr;
+
+ // The select condition may be a vector. We may only change the operand
+ // type if the vector width remains the same (and matches the condition).
+ Type *CondTy = SI.getCondition()->getType();
+ if (CondTy->isVectorTy()) {
+ if (!FIOpndTy->isVectorTy())
return nullptr;
- // The select condition may be a vector. We may only change the operand
- // type if the vector width remains the same (and matches the condition).
- Type *CondTy = SI.getCondition()->getType();
- if (CondTy->isVectorTy() && (!FIOpndTy->isVectorTy() ||
- CondTy->getVectorNumElements() != FIOpndTy->getVectorNumElements()))
+ if (CondTy->getVectorNumElements() != FIOpndTy->getVectorNumElements())
return nullptr;
- } else {
- return nullptr; // unknown unary op.
+
+ // TODO: If the backend knew how to deal with casts better, we could
+ // remove this limitation. For now, there's too much potential to create
+ // worse codegen by promoting the select ahead of size-altering casts
+ // (PR28160).
+ //
+ // Note that ValueTracking's matchSelectPattern() looks through casts
+ // without checking 'hasOneUse' when it matches min/max patterns, so this
+ // transform may end up happening anyway.
+ if (TI->getOpcode() != Instruction::BitCast &&
+ (!TI->hasOneUse() || !FI->hasOneUse()))
+ return nullptr;
+
+ } else if (!TI->hasOneUse() || !FI->hasOneUse()) {
+ // TODO: The one-use restrictions for a scalar select could be eased if
+ // the fold of a select in visitLoadInst() was enhanced to match a pattern
+ // that includes a cast.
+ return nullptr;
}
// Fold this by inserting a select from the input values.
@@ -144,8 +160,13 @@ Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI,
TI->getType());
}
- // Only handle binary operators here.
- if (!isa<BinaryOperator>(TI))
+ // TODO: This function ends awkwardly in unreachable - fix to be more normal.
+
+ // Only handle binary operators with one-use here. As with the cast case
+ // above, it may be possible to relax the one-use constraint, but that needs
+ // be examined carefully since it may not reduce the total number of
+ // instructions.
+ if (!isa<BinaryOperator>(TI) || !TI->hasOneUse() || !FI->hasOneUse())
return nullptr;
// Figure out if the operations have any operands in common.
@@ -231,12 +252,7 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal,
BinaryOperator *TVI_BO = cast<BinaryOperator>(TVI);
BinaryOperator *BO = BinaryOperator::Create(TVI_BO->getOpcode(),
FalseVal, NewSel);
- if (isa<PossiblyExactOperator>(BO))
- BO->setIsExact(TVI_BO->isExact());
- if (isa<OverflowingBinaryOperator>(BO)) {
- BO->setHasNoUnsignedWrap(TVI_BO->hasNoUnsignedWrap());
- BO->setHasNoSignedWrap(TVI_BO->hasNoSignedWrap());
- }
+ BO->copyIRFlags(TVI_BO);
return BO;
}
}
@@ -266,12 +282,7 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal,
BinaryOperator *FVI_BO = cast<BinaryOperator>(FVI);
BinaryOperator *BO = BinaryOperator::Create(FVI_BO->getOpcode(),
TrueVal, NewSel);
- if (isa<PossiblyExactOperator>(BO))
- BO->setIsExact(FVI_BO->isExact());
- if (isa<OverflowingBinaryOperator>(BO)) {
- BO->setHasNoUnsignedWrap(FVI_BO->hasNoUnsignedWrap());
- BO->setHasNoSignedWrap(FVI_BO->hasNoSignedWrap());
- }
+ BO->copyIRFlags(FVI_BO);
return BO;
}
}
@@ -353,7 +364,7 @@ static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal,
/// %1 = icmp ne i32 %x, 0
/// %2 = select i1 %1, i32 %0, i32 32
/// \code
-///
+///
/// into:
/// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 false)
static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal,
@@ -519,10 +530,10 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
// Check if we can express the operation with a single or.
if (C2->isAllOnesValue())
- return ReplaceInstUsesWith(SI, Builder->CreateOr(AShr, C1));
+ return replaceInstUsesWith(SI, Builder->CreateOr(AShr, C1));
Value *And = Builder->CreateAnd(AShr, C2->getValue()-C1->getValue());
- return ReplaceInstUsesWith(SI, Builder->CreateAdd(And, C1));
+ return replaceInstUsesWith(SI, Builder->CreateAdd(And, C1));
}
}
}
@@ -585,15 +596,15 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
V = Builder->CreateOr(X, *Y);
if (V)
- return ReplaceInstUsesWith(SI, V);
+ return replaceInstUsesWith(SI, V);
}
}
if (Value *V = foldSelectICmpAndOr(SI, TrueVal, FalseVal, Builder))
- return ReplaceInstUsesWith(SI, V);
+ return replaceInstUsesWith(SI, V);
if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder))
- return ReplaceInstUsesWith(SI, V);
+ return replaceInstUsesWith(SI, V);
return Changed ? &SI : nullptr;
}
@@ -642,11 +653,14 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner,
Value *A, Value *B,
Instruction &Outer,
SelectPatternFlavor SPF2, Value *C) {
+ if (Outer.getType() != Inner->getType())
+ return nullptr;
+
if (C == A || C == B) {
// MAX(MAX(A, B), B) -> MAX(A, B)
// MIN(MIN(a, b), a) -> MIN(a, b)
if (SPF1 == SPF2)
- return ReplaceInstUsesWith(Outer, Inner);
+ return replaceInstUsesWith(Outer, Inner);
// MAX(MIN(a, b), a) -> a
// MIN(MAX(a, b), a) -> a
@@ -654,14 +668,14 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner,
(SPF1 == SPF_SMAX && SPF2 == SPF_SMIN) ||
(SPF1 == SPF_UMIN && SPF2 == SPF_UMAX) ||
(SPF1 == SPF_UMAX && SPF2 == SPF_UMIN))
- return ReplaceInstUsesWith(Outer, C);
+ return replaceInstUsesWith(Outer, C);
}
if (SPF1 == SPF2) {
if (ConstantInt *CB = dyn_cast<ConstantInt>(B)) {
if (ConstantInt *CC = dyn_cast<ConstantInt>(C)) {
- APInt ACB = CB->getValue();
- APInt ACC = CC->getValue();
+ const APInt &ACB = CB->getValue();
+ const APInt &ACC = CC->getValue();
// MIN(MIN(A, 23), 97) -> MIN(A, 23)
// MAX(MAX(A, 97), 23) -> MAX(A, 97)
@@ -669,7 +683,7 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner,
(SPF1 == SPF_SMIN && ACB.sle(ACC)) ||
(SPF1 == SPF_UMAX && ACB.uge(ACC)) ||
(SPF1 == SPF_SMAX && ACB.sge(ACC)))
- return ReplaceInstUsesWith(Outer, Inner);
+ return replaceInstUsesWith(Outer, Inner);
// MIN(MIN(A, 97), 23) -> MIN(A, 23)
// MAX(MAX(A, 23), 97) -> MAX(A, 97)
@@ -687,7 +701,7 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner,
// ABS(ABS(X)) -> ABS(X)
// NABS(NABS(X)) -> NABS(X)
if (SPF1 == SPF2 && (SPF1 == SPF_ABS || SPF1 == SPF_NABS)) {
- return ReplaceInstUsesWith(Outer, Inner);
+ return replaceInstUsesWith(Outer, Inner);
}
// ABS(NABS(X)) -> ABS(X)
@@ -697,7 +711,7 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner,
SelectInst *SI = cast<SelectInst>(Inner);
Value *NewSI = Builder->CreateSelect(
SI->getCondition(), SI->getFalseValue(), SI->getTrueValue());
- return ReplaceInstUsesWith(Outer, NewSI);
+ return replaceInstUsesWith(Outer, NewSI);
}
auto IsFreeOrProfitableToInvert =
@@ -742,7 +756,7 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner,
Builder, getInverseMinMaxSelectPattern(SPF1), NotA, NotB);
Value *NewOuter = Builder->CreateNot(generateMinMaxSelectPattern(
Builder, getInverseMinMaxSelectPattern(SPF2), NewInner, NotC));
- return ReplaceInstUsesWith(Outer, NewOuter);
+ return replaceInstUsesWith(Outer, NewOuter);
}
return nullptr;
@@ -823,76 +837,156 @@ static Value *foldSelectICmpAnd(const SelectInst &SI, ConstantInt *TrueVal,
return V;
}
+/// Turn select C, (X + Y), (X - Y) --> (X + (select C, Y, (-Y))).
+/// This is even legal for FP.
+static Instruction *foldAddSubSelect(SelectInst &SI,
+ InstCombiner::BuilderTy &Builder) {
+ Value *CondVal = SI.getCondition();
+ Value *TrueVal = SI.getTrueValue();
+ Value *FalseVal = SI.getFalseValue();
+ auto *TI = dyn_cast<Instruction>(TrueVal);
+ auto *FI = dyn_cast<Instruction>(FalseVal);
+ if (!TI || !FI || !TI->hasOneUse() || !FI->hasOneUse())
+ return nullptr;
+
+ Instruction *AddOp = nullptr, *SubOp = nullptr;
+ if ((TI->getOpcode() == Instruction::Sub &&
+ FI->getOpcode() == Instruction::Add) ||
+ (TI->getOpcode() == Instruction::FSub &&
+ FI->getOpcode() == Instruction::FAdd)) {
+ AddOp = FI;
+ SubOp = TI;
+ } else if ((FI->getOpcode() == Instruction::Sub &&
+ TI->getOpcode() == Instruction::Add) ||
+ (FI->getOpcode() == Instruction::FSub &&
+ TI->getOpcode() == Instruction::FAdd)) {
+ AddOp = TI;
+ SubOp = FI;
+ }
+
+ if (AddOp) {
+ Value *OtherAddOp = nullptr;
+ if (SubOp->getOperand(0) == AddOp->getOperand(0)) {
+ OtherAddOp = AddOp->getOperand(1);
+ } else if (SubOp->getOperand(0) == AddOp->getOperand(1)) {
+ OtherAddOp = AddOp->getOperand(0);
+ }
+
+ if (OtherAddOp) {
+ // So at this point we know we have (Y -> OtherAddOp):
+ // select C, (add X, Y), (sub X, Z)
+ Value *NegVal; // Compute -Z
+ if (SI.getType()->isFPOrFPVectorTy()) {
+ NegVal = Builder.CreateFNeg(SubOp->getOperand(1));
+ if (Instruction *NegInst = dyn_cast<Instruction>(NegVal)) {
+ FastMathFlags Flags = AddOp->getFastMathFlags();
+ Flags &= SubOp->getFastMathFlags();
+ NegInst->setFastMathFlags(Flags);
+ }
+ } else {
+ NegVal = Builder.CreateNeg(SubOp->getOperand(1));
+ }
+
+ Value *NewTrueOp = OtherAddOp;
+ Value *NewFalseOp = NegVal;
+ if (AddOp != TI)
+ std::swap(NewTrueOp, NewFalseOp);
+ Value *NewSel = Builder.CreateSelect(CondVal, NewTrueOp, NewFalseOp,
+ SI.getName() + ".p");
+
+ if (SI.getType()->isFPOrFPVectorTy()) {
+ Instruction *RI =
+ BinaryOperator::CreateFAdd(SubOp->getOperand(0), NewSel);
+
+ FastMathFlags Flags = AddOp->getFastMathFlags();
+ Flags &= SubOp->getFastMathFlags();
+ RI->setFastMathFlags(Flags);
+ return RI;
+ } else
+ return BinaryOperator::CreateAdd(SubOp->getOperand(0), NewSel);
+ }
+ }
+ return nullptr;
+}
+
Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
Value *CondVal = SI.getCondition();
Value *TrueVal = SI.getTrueValue();
Value *FalseVal = SI.getFalseValue();
+ Type *SelType = SI.getType();
if (Value *V =
SimplifySelectInst(CondVal, TrueVal, FalseVal, DL, TLI, DT, AC))
- return ReplaceInstUsesWith(SI, V);
+ return replaceInstUsesWith(SI, V);
- if (SI.getType()->isIntegerTy(1)) {
- if (ConstantInt *C = dyn_cast<ConstantInt>(TrueVal)) {
- if (C->getZExtValue()) {
- // Change: A = select B, true, C --> A = or B, C
- return BinaryOperator::CreateOr(CondVal, FalseVal);
- }
+ if (SelType->getScalarType()->isIntegerTy(1) &&
+ TrueVal->getType() == CondVal->getType()) {
+ if (match(TrueVal, m_One())) {
+ // Change: A = select B, true, C --> A = or B, C
+ return BinaryOperator::CreateOr(CondVal, FalseVal);
+ }
+ if (match(TrueVal, m_Zero())) {
// Change: A = select B, false, C --> A = and !B, C
- Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName());
+ Value *NotCond = Builder->CreateNot(CondVal, "not." + CondVal->getName());
return BinaryOperator::CreateAnd(NotCond, FalseVal);
}
- if (ConstantInt *C = dyn_cast<ConstantInt>(FalseVal)) {
- if (!C->getZExtValue()) {
- // Change: A = select B, C, false --> A = and B, C
- return BinaryOperator::CreateAnd(CondVal, TrueVal);
- }
+ if (match(FalseVal, m_Zero())) {
+ // Change: A = select B, C, false --> A = and B, C
+ return BinaryOperator::CreateAnd(CondVal, TrueVal);
+ }
+ if (match(FalseVal, m_One())) {
// Change: A = select B, C, true --> A = or !B, C
- Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName());
+ Value *NotCond = Builder->CreateNot(CondVal, "not." + CondVal->getName());
return BinaryOperator::CreateOr(NotCond, TrueVal);
}
- // select a, b, a -> a&b
- // select a, a, b -> a|b
+ // select a, a, b -> a | b
+ // select a, b, a -> a & b
if (CondVal == TrueVal)
return BinaryOperator::CreateOr(CondVal, FalseVal);
if (CondVal == FalseVal)
return BinaryOperator::CreateAnd(CondVal, TrueVal);
- // select a, ~a, b -> (~a)&b
- // select a, b, ~a -> (~a)|b
+ // select a, ~a, b -> (~a) & b
+ // select a, b, ~a -> (~a) | b
if (match(TrueVal, m_Not(m_Specific(CondVal))))
return BinaryOperator::CreateAnd(TrueVal, FalseVal);
if (match(FalseVal, m_Not(m_Specific(CondVal))))
return BinaryOperator::CreateOr(TrueVal, FalseVal);
}
- // Selecting between two integer constants?
- if (ConstantInt *TrueValC = dyn_cast<ConstantInt>(TrueVal))
- if (ConstantInt *FalseValC = dyn_cast<ConstantInt>(FalseVal)) {
- // select C, 1, 0 -> zext C to int
- if (FalseValC->isZero() && TrueValC->getValue() == 1)
- return new ZExtInst(CondVal, SI.getType());
-
- // select C, -1, 0 -> sext C to int
- if (FalseValC->isZero() && TrueValC->isAllOnesValue())
- return new SExtInst(CondVal, SI.getType());
-
- // select C, 0, 1 -> zext !C to int
- if (TrueValC->isZero() && FalseValC->getValue() == 1) {
- Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName());
- return new ZExtInst(NotCond, SI.getType());
- }
+ // Selecting between two integer or vector splat integer constants?
+ //
+ // Note that we don't handle a scalar select of vectors:
+ // select i1 %c, <2 x i8> <1, 1>, <2 x i8> <0, 0>
+ // because that may need 3 instructions to splat the condition value:
+ // extend, insertelement, shufflevector.
+ if (CondVal->getType()->isVectorTy() == SelType->isVectorTy()) {
+ // select C, 1, 0 -> zext C to int
+ if (match(TrueVal, m_One()) && match(FalseVal, m_Zero()))
+ return new ZExtInst(CondVal, SelType);
+
+ // select C, -1, 0 -> sext C to int
+ if (match(TrueVal, m_AllOnes()) && match(FalseVal, m_Zero()))
+ return new SExtInst(CondVal, SelType);
+
+ // select C, 0, 1 -> zext !C to int
+ if (match(TrueVal, m_Zero()) && match(FalseVal, m_One())) {
+ Value *NotCond = Builder->CreateNot(CondVal, "not." + CondVal->getName());
+ return new ZExtInst(NotCond, SelType);
+ }
- // select C, 0, -1 -> sext !C to int
- if (TrueValC->isZero() && FalseValC->isAllOnesValue()) {
- Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName());
- return new SExtInst(NotCond, SI.getType());
- }
+ // select C, 0, -1 -> sext !C to int
+ if (match(TrueVal, m_Zero()) && match(FalseVal, m_AllOnes())) {
+ Value *NotCond = Builder->CreateNot(CondVal, "not." + CondVal->getName());
+ return new SExtInst(NotCond, SelType);
+ }
+ }
+ if (ConstantInt *TrueValC = dyn_cast<ConstantInt>(TrueVal))
+ if (ConstantInt *FalseValC = dyn_cast<ConstantInt>(FalseVal))
if (Value *V = foldSelectICmpAnd(SI, TrueValC, FalseValC, Builder))
- return ReplaceInstUsesWith(SI, V);
- }
+ return replaceInstUsesWith(SI, V);
// See if we are selecting two values based on a comparison of the two values.
if (FCmpInst *FCI = dyn_cast<FCmpInst>(CondVal)) {
@@ -907,7 +1001,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
!CFPt->getValueAPF().isZero()) ||
((CFPf = dyn_cast<ConstantFP>(FalseVal)) &&
!CFPf->getValueAPF().isZero()))
- return ReplaceInstUsesWith(SI, FalseVal);
+ return replaceInstUsesWith(SI, FalseVal);
}
// Transform (X une Y) ? X : Y -> X
if (FCI->getPredicate() == FCmpInst::FCMP_UNE) {
@@ -919,7 +1013,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
!CFPt->getValueAPF().isZero()) ||
((CFPf = dyn_cast<ConstantFP>(FalseVal)) &&
!CFPf->getValueAPF().isZero()))
- return ReplaceInstUsesWith(SI, TrueVal);
+ return replaceInstUsesWith(SI, TrueVal);
}
// Canonicalize to use ordered comparisons by swapping the select
@@ -950,7 +1044,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
!CFPt->getValueAPF().isZero()) ||
((CFPf = dyn_cast<ConstantFP>(FalseVal)) &&
!CFPf->getValueAPF().isZero()))
- return ReplaceInstUsesWith(SI, FalseVal);
+ return replaceInstUsesWith(SI, FalseVal);
}
// Transform (X une Y) ? Y : X -> Y
if (FCI->getPredicate() == FCmpInst::FCMP_UNE) {
@@ -962,7 +1056,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
!CFPt->getValueAPF().isZero()) ||
((CFPf = dyn_cast<ConstantFP>(FalseVal)) &&
!CFPf->getValueAPF().isZero()))
- return ReplaceInstUsesWith(SI, TrueVal);
+ return replaceInstUsesWith(SI, TrueVal);
}
// Canonicalize to use ordered comparisons by swapping the select
@@ -991,77 +1085,18 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
if (Instruction *Result = visitSelectInstWithICmp(SI, ICI))
return Result;
- if (Instruction *TI = dyn_cast<Instruction>(TrueVal))
- if (Instruction *FI = dyn_cast<Instruction>(FalseVal))
- if (TI->hasOneUse() && FI->hasOneUse()) {
- Instruction *AddOp = nullptr, *SubOp = nullptr;
-
- // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z))
- if (TI->getOpcode() == FI->getOpcode())
- if (Instruction *IV = FoldSelectOpOp(SI, TI, FI))
- return IV;
-
- // Turn select C, (X+Y), (X-Y) --> (X+(select C, Y, (-Y))). This is
- // even legal for FP.
- if ((TI->getOpcode() == Instruction::Sub &&
- FI->getOpcode() == Instruction::Add) ||
- (TI->getOpcode() == Instruction::FSub &&
- FI->getOpcode() == Instruction::FAdd)) {
- AddOp = FI; SubOp = TI;
- } else if ((FI->getOpcode() == Instruction::Sub &&
- TI->getOpcode() == Instruction::Add) ||
- (FI->getOpcode() == Instruction::FSub &&
- TI->getOpcode() == Instruction::FAdd)) {
- AddOp = TI; SubOp = FI;
- }
+ if (Instruction *Add = foldAddSubSelect(SI, *Builder))
+ return Add;
- if (AddOp) {
- Value *OtherAddOp = nullptr;
- if (SubOp->getOperand(0) == AddOp->getOperand(0)) {
- OtherAddOp = AddOp->getOperand(1);
- } else if (SubOp->getOperand(0) == AddOp->getOperand(1)) {
- OtherAddOp = AddOp->getOperand(0);
- }
-
- if (OtherAddOp) {
- // So at this point we know we have (Y -> OtherAddOp):
- // select C, (add X, Y), (sub X, Z)
- Value *NegVal; // Compute -Z
- if (SI.getType()->isFPOrFPVectorTy()) {
- NegVal = Builder->CreateFNeg(SubOp->getOperand(1));
- if (Instruction *NegInst = dyn_cast<Instruction>(NegVal)) {
- FastMathFlags Flags = AddOp->getFastMathFlags();
- Flags &= SubOp->getFastMathFlags();
- NegInst->setFastMathFlags(Flags);
- }
- } else {
- NegVal = Builder->CreateNeg(SubOp->getOperand(1));
- }
-
- Value *NewTrueOp = OtherAddOp;
- Value *NewFalseOp = NegVal;
- if (AddOp != TI)
- std::swap(NewTrueOp, NewFalseOp);
- Value *NewSel =
- Builder->CreateSelect(CondVal, NewTrueOp,
- NewFalseOp, SI.getName() + ".p");
-
- if (SI.getType()->isFPOrFPVectorTy()) {
- Instruction *RI =
- BinaryOperator::CreateFAdd(SubOp->getOperand(0), NewSel);
-
- FastMathFlags Flags = AddOp->getFastMathFlags();
- Flags &= SubOp->getFastMathFlags();
- RI->setFastMathFlags(Flags);
- return RI;
- } else
- return BinaryOperator::CreateAdd(SubOp->getOperand(0), NewSel);
- }
- }
- }
+ // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z))
+ auto *TI = dyn_cast<Instruction>(TrueVal);
+ auto *FI = dyn_cast<Instruction>(FalseVal);
+ if (TI && FI && TI->getOpcode() == FI->getOpcode())
+ if (Instruction *IV = FoldSelectOpOp(SI, TI, FI))
+ return IV;
// See if we can fold the select into one of our operands.
- if (SI.getType()->isIntOrIntVectorTy() || SI.getType()->isFPOrFPVectorTy()) {
+ if (SelType->isIntOrIntVectorTy() || SelType->isFPOrFPVectorTy()) {
if (Instruction *FoldI = FoldSelectIntoOp(SI, TrueVal, FalseVal))
return FoldI;
@@ -1073,7 +1108,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
if (SelectPatternResult::isMinOrMax(SPF)) {
// Canonicalize so that type casts are outside select patterns.
if (LHS->getType()->getPrimitiveSizeInBits() !=
- SI.getType()->getPrimitiveSizeInBits()) {
+ SelType->getPrimitiveSizeInBits()) {
CmpInst::Predicate Pred = getCmpPredicateForMinMax(SPF, SPR.Ordered);
Value *Cmp;
@@ -1088,8 +1123,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
Value *NewSI = Builder->CreateCast(CastOp,
Builder->CreateSelect(Cmp, LHS, RHS),
- SI.getType());
- return ReplaceInstUsesWith(SI, NewSI);
+ SelType);
+ return replaceInstUsesWith(SI, NewSI);
}
}
@@ -1132,7 +1167,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
: Builder->CreateICmpULT(NewLHS, NewRHS);
Value *NewSI =
Builder->CreateNot(Builder->CreateSelect(NewCmp, NewLHS, NewRHS));
- return ReplaceInstUsesWith(SI, NewSI);
+ return replaceInstUsesWith(SI, NewSI);
}
}
}
@@ -1195,18 +1230,36 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
return &SI;
}
- if (VectorType* VecTy = dyn_cast<VectorType>(SI.getType())) {
+ if (VectorType* VecTy = dyn_cast<VectorType>(SelType)) {
unsigned VWidth = VecTy->getNumElements();
APInt UndefElts(VWidth, 0);
APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth));
if (Value *V = SimplifyDemandedVectorElts(&SI, AllOnesEltMask, UndefElts)) {
if (V != &SI)
- return ReplaceInstUsesWith(SI, V);
+ return replaceInstUsesWith(SI, V);
return &SI;
}
if (isa<ConstantAggregateZero>(CondVal)) {
- return ReplaceInstUsesWith(SI, FalseVal);
+ return replaceInstUsesWith(SI, FalseVal);
+ }
+ }
+
+ // See if we can determine the result of this select based on a dominating
+ // condition.
+ BasicBlock *Parent = SI.getParent();
+ if (BasicBlock *Dom = Parent->getSinglePredecessor()) {
+ auto *PBI = dyn_cast_or_null<BranchInst>(Dom->getTerminator());
+ if (PBI && PBI->isConditional() &&
+ PBI->getSuccessor(0) != PBI->getSuccessor(1) &&
+ (PBI->getSuccessor(0) == Parent || PBI->getSuccessor(1) == Parent)) {
+ bool CondIsFalse = PBI->getSuccessor(1) == Parent;
+ Optional<bool> Implication = isImpliedCondition(
+ PBI->getCondition(), SI.getCondition(), DL, CondIsFalse);
+ if (Implication) {
+ Value *V = *Implication ? TrueVal : FalseVal;
+ return replaceInstUsesWith(SI, V);
+ }
}
}