diff options
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineSelect.cpp')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineSelect.cpp | 379 |
1 files changed, 216 insertions, 163 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 51219bcb0b7b..d7eed790e2ab 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -116,25 +116,41 @@ static Constant *GetSelectFoldableConstant(Instruction *I) { } } -/// Here we have (select c, TI, FI), and we know that TI and FI -/// have the same opcode and only one use each. Try to simplify this. +/// We have (select c, TI, FI), and we know that TI and FI have the same opcode. Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI) { - if (TI->getNumOperands() == 1) { - // If this is a non-volatile load or a cast from the same type, - // merge. - if (TI->isCast()) { - Type *FIOpndTy = FI->getOperand(0)->getType(); - if (TI->getOperand(0)->getType() != FIOpndTy) + // If this is a cast from the same type, merge. + if (TI->getNumOperands() == 1 && TI->isCast()) { + Type *FIOpndTy = FI->getOperand(0)->getType(); + if (TI->getOperand(0)->getType() != FIOpndTy) + return nullptr; + + // The select condition may be a vector. We may only change the operand + // type if the vector width remains the same (and matches the condition). + Type *CondTy = SI.getCondition()->getType(); + if (CondTy->isVectorTy()) { + if (!FIOpndTy->isVectorTy()) return nullptr; - // The select condition may be a vector. We may only change the operand - // type if the vector width remains the same (and matches the condition). - Type *CondTy = SI.getCondition()->getType(); - if (CondTy->isVectorTy() && (!FIOpndTy->isVectorTy() || - CondTy->getVectorNumElements() != FIOpndTy->getVectorNumElements())) + if (CondTy->getVectorNumElements() != FIOpndTy->getVectorNumElements()) return nullptr; - } else { - return nullptr; // unknown unary op. + + // TODO: If the backend knew how to deal with casts better, we could + // remove this limitation. For now, there's too much potential to create + // worse codegen by promoting the select ahead of size-altering casts + // (PR28160). + // + // Note that ValueTracking's matchSelectPattern() looks through casts + // without checking 'hasOneUse' when it matches min/max patterns, so this + // transform may end up happening anyway. + if (TI->getOpcode() != Instruction::BitCast && + (!TI->hasOneUse() || !FI->hasOneUse())) + return nullptr; + + } else if (!TI->hasOneUse() || !FI->hasOneUse()) { + // TODO: The one-use restrictions for a scalar select could be eased if + // the fold of a select in visitLoadInst() was enhanced to match a pattern + // that includes a cast. + return nullptr; } // Fold this by inserting a select from the input values. @@ -144,8 +160,13 @@ Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI, TI->getType()); } - // Only handle binary operators here. - if (!isa<BinaryOperator>(TI)) + // TODO: This function ends awkwardly in unreachable - fix to be more normal. + + // Only handle binary operators with one-use here. As with the cast case + // above, it may be possible to relax the one-use constraint, but that needs + // be examined carefully since it may not reduce the total number of + // instructions. + if (!isa<BinaryOperator>(TI) || !TI->hasOneUse() || !FI->hasOneUse()) return nullptr; // Figure out if the operations have any operands in common. @@ -231,12 +252,7 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal, BinaryOperator *TVI_BO = cast<BinaryOperator>(TVI); BinaryOperator *BO = BinaryOperator::Create(TVI_BO->getOpcode(), FalseVal, NewSel); - if (isa<PossiblyExactOperator>(BO)) - BO->setIsExact(TVI_BO->isExact()); - if (isa<OverflowingBinaryOperator>(BO)) { - BO->setHasNoUnsignedWrap(TVI_BO->hasNoUnsignedWrap()); - BO->setHasNoSignedWrap(TVI_BO->hasNoSignedWrap()); - } + BO->copyIRFlags(TVI_BO); return BO; } } @@ -266,12 +282,7 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal, BinaryOperator *FVI_BO = cast<BinaryOperator>(FVI); BinaryOperator *BO = BinaryOperator::Create(FVI_BO->getOpcode(), TrueVal, NewSel); - if (isa<PossiblyExactOperator>(BO)) - BO->setIsExact(FVI_BO->isExact()); - if (isa<OverflowingBinaryOperator>(BO)) { - BO->setHasNoUnsignedWrap(FVI_BO->hasNoUnsignedWrap()); - BO->setHasNoSignedWrap(FVI_BO->hasNoSignedWrap()); - } + BO->copyIRFlags(FVI_BO); return BO; } } @@ -353,7 +364,7 @@ static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal, /// %1 = icmp ne i32 %x, 0 /// %2 = select i1 %1, i32 %0, i32 32 /// \code -/// +/// /// into: /// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 false) static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, @@ -519,10 +530,10 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, // Check if we can express the operation with a single or. if (C2->isAllOnesValue()) - return ReplaceInstUsesWith(SI, Builder->CreateOr(AShr, C1)); + return replaceInstUsesWith(SI, Builder->CreateOr(AShr, C1)); Value *And = Builder->CreateAnd(AShr, C2->getValue()-C1->getValue()); - return ReplaceInstUsesWith(SI, Builder->CreateAdd(And, C1)); + return replaceInstUsesWith(SI, Builder->CreateAdd(And, C1)); } } } @@ -585,15 +596,15 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, V = Builder->CreateOr(X, *Y); if (V) - return ReplaceInstUsesWith(SI, V); + return replaceInstUsesWith(SI, V); } } if (Value *V = foldSelectICmpAndOr(SI, TrueVal, FalseVal, Builder)) - return ReplaceInstUsesWith(SI, V); + return replaceInstUsesWith(SI, V); if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder)) - return ReplaceInstUsesWith(SI, V); + return replaceInstUsesWith(SI, V); return Changed ? &SI : nullptr; } @@ -642,11 +653,14 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, Value *A, Value *B, Instruction &Outer, SelectPatternFlavor SPF2, Value *C) { + if (Outer.getType() != Inner->getType()) + return nullptr; + if (C == A || C == B) { // MAX(MAX(A, B), B) -> MAX(A, B) // MIN(MIN(a, b), a) -> MIN(a, b) if (SPF1 == SPF2) - return ReplaceInstUsesWith(Outer, Inner); + return replaceInstUsesWith(Outer, Inner); // MAX(MIN(a, b), a) -> a // MIN(MAX(a, b), a) -> a @@ -654,14 +668,14 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, (SPF1 == SPF_SMAX && SPF2 == SPF_SMIN) || (SPF1 == SPF_UMIN && SPF2 == SPF_UMAX) || (SPF1 == SPF_UMAX && SPF2 == SPF_UMIN)) - return ReplaceInstUsesWith(Outer, C); + return replaceInstUsesWith(Outer, C); } if (SPF1 == SPF2) { if (ConstantInt *CB = dyn_cast<ConstantInt>(B)) { if (ConstantInt *CC = dyn_cast<ConstantInt>(C)) { - APInt ACB = CB->getValue(); - APInt ACC = CC->getValue(); + const APInt &ACB = CB->getValue(); + const APInt &ACC = CC->getValue(); // MIN(MIN(A, 23), 97) -> MIN(A, 23) // MAX(MAX(A, 97), 23) -> MAX(A, 97) @@ -669,7 +683,7 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, (SPF1 == SPF_SMIN && ACB.sle(ACC)) || (SPF1 == SPF_UMAX && ACB.uge(ACC)) || (SPF1 == SPF_SMAX && ACB.sge(ACC))) - return ReplaceInstUsesWith(Outer, Inner); + return replaceInstUsesWith(Outer, Inner); // MIN(MIN(A, 97), 23) -> MIN(A, 23) // MAX(MAX(A, 23), 97) -> MAX(A, 97) @@ -687,7 +701,7 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, // ABS(ABS(X)) -> ABS(X) // NABS(NABS(X)) -> NABS(X) if (SPF1 == SPF2 && (SPF1 == SPF_ABS || SPF1 == SPF_NABS)) { - return ReplaceInstUsesWith(Outer, Inner); + return replaceInstUsesWith(Outer, Inner); } // ABS(NABS(X)) -> ABS(X) @@ -697,7 +711,7 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, SelectInst *SI = cast<SelectInst>(Inner); Value *NewSI = Builder->CreateSelect( SI->getCondition(), SI->getFalseValue(), SI->getTrueValue()); - return ReplaceInstUsesWith(Outer, NewSI); + return replaceInstUsesWith(Outer, NewSI); } auto IsFreeOrProfitableToInvert = @@ -742,7 +756,7 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, Builder, getInverseMinMaxSelectPattern(SPF1), NotA, NotB); Value *NewOuter = Builder->CreateNot(generateMinMaxSelectPattern( Builder, getInverseMinMaxSelectPattern(SPF2), NewInner, NotC)); - return ReplaceInstUsesWith(Outer, NewOuter); + return replaceInstUsesWith(Outer, NewOuter); } return nullptr; @@ -823,76 +837,156 @@ static Value *foldSelectICmpAnd(const SelectInst &SI, ConstantInt *TrueVal, return V; } +/// Turn select C, (X + Y), (X - Y) --> (X + (select C, Y, (-Y))). +/// This is even legal for FP. +static Instruction *foldAddSubSelect(SelectInst &SI, + InstCombiner::BuilderTy &Builder) { + Value *CondVal = SI.getCondition(); + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); + auto *TI = dyn_cast<Instruction>(TrueVal); + auto *FI = dyn_cast<Instruction>(FalseVal); + if (!TI || !FI || !TI->hasOneUse() || !FI->hasOneUse()) + return nullptr; + + Instruction *AddOp = nullptr, *SubOp = nullptr; + if ((TI->getOpcode() == Instruction::Sub && + FI->getOpcode() == Instruction::Add) || + (TI->getOpcode() == Instruction::FSub && + FI->getOpcode() == Instruction::FAdd)) { + AddOp = FI; + SubOp = TI; + } else if ((FI->getOpcode() == Instruction::Sub && + TI->getOpcode() == Instruction::Add) || + (FI->getOpcode() == Instruction::FSub && + TI->getOpcode() == Instruction::FAdd)) { + AddOp = TI; + SubOp = FI; + } + + if (AddOp) { + Value *OtherAddOp = nullptr; + if (SubOp->getOperand(0) == AddOp->getOperand(0)) { + OtherAddOp = AddOp->getOperand(1); + } else if (SubOp->getOperand(0) == AddOp->getOperand(1)) { + OtherAddOp = AddOp->getOperand(0); + } + + if (OtherAddOp) { + // So at this point we know we have (Y -> OtherAddOp): + // select C, (add X, Y), (sub X, Z) + Value *NegVal; // Compute -Z + if (SI.getType()->isFPOrFPVectorTy()) { + NegVal = Builder.CreateFNeg(SubOp->getOperand(1)); + if (Instruction *NegInst = dyn_cast<Instruction>(NegVal)) { + FastMathFlags Flags = AddOp->getFastMathFlags(); + Flags &= SubOp->getFastMathFlags(); + NegInst->setFastMathFlags(Flags); + } + } else { + NegVal = Builder.CreateNeg(SubOp->getOperand(1)); + } + + Value *NewTrueOp = OtherAddOp; + Value *NewFalseOp = NegVal; + if (AddOp != TI) + std::swap(NewTrueOp, NewFalseOp); + Value *NewSel = Builder.CreateSelect(CondVal, NewTrueOp, NewFalseOp, + SI.getName() + ".p"); + + if (SI.getType()->isFPOrFPVectorTy()) { + Instruction *RI = + BinaryOperator::CreateFAdd(SubOp->getOperand(0), NewSel); + + FastMathFlags Flags = AddOp->getFastMathFlags(); + Flags &= SubOp->getFastMathFlags(); + RI->setFastMathFlags(Flags); + return RI; + } else + return BinaryOperator::CreateAdd(SubOp->getOperand(0), NewSel); + } + } + return nullptr; +} + Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); Value *FalseVal = SI.getFalseValue(); + Type *SelType = SI.getType(); if (Value *V = SimplifySelectInst(CondVal, TrueVal, FalseVal, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(SI, V); + return replaceInstUsesWith(SI, V); - if (SI.getType()->isIntegerTy(1)) { - if (ConstantInt *C = dyn_cast<ConstantInt>(TrueVal)) { - if (C->getZExtValue()) { - // Change: A = select B, true, C --> A = or B, C - return BinaryOperator::CreateOr(CondVal, FalseVal); - } + if (SelType->getScalarType()->isIntegerTy(1) && + TrueVal->getType() == CondVal->getType()) { + if (match(TrueVal, m_One())) { + // Change: A = select B, true, C --> A = or B, C + return BinaryOperator::CreateOr(CondVal, FalseVal); + } + if (match(TrueVal, m_Zero())) { // Change: A = select B, false, C --> A = and !B, C - Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName()); + Value *NotCond = Builder->CreateNot(CondVal, "not." + CondVal->getName()); return BinaryOperator::CreateAnd(NotCond, FalseVal); } - if (ConstantInt *C = dyn_cast<ConstantInt>(FalseVal)) { - if (!C->getZExtValue()) { - // Change: A = select B, C, false --> A = and B, C - return BinaryOperator::CreateAnd(CondVal, TrueVal); - } + if (match(FalseVal, m_Zero())) { + // Change: A = select B, C, false --> A = and B, C + return BinaryOperator::CreateAnd(CondVal, TrueVal); + } + if (match(FalseVal, m_One())) { // Change: A = select B, C, true --> A = or !B, C - Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName()); + Value *NotCond = Builder->CreateNot(CondVal, "not." + CondVal->getName()); return BinaryOperator::CreateOr(NotCond, TrueVal); } - // select a, b, a -> a&b - // select a, a, b -> a|b + // select a, a, b -> a | b + // select a, b, a -> a & b if (CondVal == TrueVal) return BinaryOperator::CreateOr(CondVal, FalseVal); if (CondVal == FalseVal) return BinaryOperator::CreateAnd(CondVal, TrueVal); - // select a, ~a, b -> (~a)&b - // select a, b, ~a -> (~a)|b + // select a, ~a, b -> (~a) & b + // select a, b, ~a -> (~a) | b if (match(TrueVal, m_Not(m_Specific(CondVal)))) return BinaryOperator::CreateAnd(TrueVal, FalseVal); if (match(FalseVal, m_Not(m_Specific(CondVal)))) return BinaryOperator::CreateOr(TrueVal, FalseVal); } - // Selecting between two integer constants? - if (ConstantInt *TrueValC = dyn_cast<ConstantInt>(TrueVal)) - if (ConstantInt *FalseValC = dyn_cast<ConstantInt>(FalseVal)) { - // select C, 1, 0 -> zext C to int - if (FalseValC->isZero() && TrueValC->getValue() == 1) - return new ZExtInst(CondVal, SI.getType()); - - // select C, -1, 0 -> sext C to int - if (FalseValC->isZero() && TrueValC->isAllOnesValue()) - return new SExtInst(CondVal, SI.getType()); - - // select C, 0, 1 -> zext !C to int - if (TrueValC->isZero() && FalseValC->getValue() == 1) { - Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName()); - return new ZExtInst(NotCond, SI.getType()); - } + // Selecting between two integer or vector splat integer constants? + // + // Note that we don't handle a scalar select of vectors: + // select i1 %c, <2 x i8> <1, 1>, <2 x i8> <0, 0> + // because that may need 3 instructions to splat the condition value: + // extend, insertelement, shufflevector. + if (CondVal->getType()->isVectorTy() == SelType->isVectorTy()) { + // select C, 1, 0 -> zext C to int + if (match(TrueVal, m_One()) && match(FalseVal, m_Zero())) + return new ZExtInst(CondVal, SelType); + + // select C, -1, 0 -> sext C to int + if (match(TrueVal, m_AllOnes()) && match(FalseVal, m_Zero())) + return new SExtInst(CondVal, SelType); + + // select C, 0, 1 -> zext !C to int + if (match(TrueVal, m_Zero()) && match(FalseVal, m_One())) { + Value *NotCond = Builder->CreateNot(CondVal, "not." + CondVal->getName()); + return new ZExtInst(NotCond, SelType); + } - // select C, 0, -1 -> sext !C to int - if (TrueValC->isZero() && FalseValC->isAllOnesValue()) { - Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName()); - return new SExtInst(NotCond, SI.getType()); - } + // select C, 0, -1 -> sext !C to int + if (match(TrueVal, m_Zero()) && match(FalseVal, m_AllOnes())) { + Value *NotCond = Builder->CreateNot(CondVal, "not." + CondVal->getName()); + return new SExtInst(NotCond, SelType); + } + } + if (ConstantInt *TrueValC = dyn_cast<ConstantInt>(TrueVal)) + if (ConstantInt *FalseValC = dyn_cast<ConstantInt>(FalseVal)) if (Value *V = foldSelectICmpAnd(SI, TrueValC, FalseValC, Builder)) - return ReplaceInstUsesWith(SI, V); - } + return replaceInstUsesWith(SI, V); // See if we are selecting two values based on a comparison of the two values. if (FCmpInst *FCI = dyn_cast<FCmpInst>(CondVal)) { @@ -907,7 +1001,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { !CFPt->getValueAPF().isZero()) || ((CFPf = dyn_cast<ConstantFP>(FalseVal)) && !CFPf->getValueAPF().isZero())) - return ReplaceInstUsesWith(SI, FalseVal); + return replaceInstUsesWith(SI, FalseVal); } // Transform (X une Y) ? X : Y -> X if (FCI->getPredicate() == FCmpInst::FCMP_UNE) { @@ -919,7 +1013,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { !CFPt->getValueAPF().isZero()) || ((CFPf = dyn_cast<ConstantFP>(FalseVal)) && !CFPf->getValueAPF().isZero())) - return ReplaceInstUsesWith(SI, TrueVal); + return replaceInstUsesWith(SI, TrueVal); } // Canonicalize to use ordered comparisons by swapping the select @@ -950,7 +1044,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { !CFPt->getValueAPF().isZero()) || ((CFPf = dyn_cast<ConstantFP>(FalseVal)) && !CFPf->getValueAPF().isZero())) - return ReplaceInstUsesWith(SI, FalseVal); + return replaceInstUsesWith(SI, FalseVal); } // Transform (X une Y) ? Y : X -> Y if (FCI->getPredicate() == FCmpInst::FCMP_UNE) { @@ -962,7 +1056,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { !CFPt->getValueAPF().isZero()) || ((CFPf = dyn_cast<ConstantFP>(FalseVal)) && !CFPf->getValueAPF().isZero())) - return ReplaceInstUsesWith(SI, TrueVal); + return replaceInstUsesWith(SI, TrueVal); } // Canonicalize to use ordered comparisons by swapping the select @@ -991,77 +1085,18 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *Result = visitSelectInstWithICmp(SI, ICI)) return Result; - if (Instruction *TI = dyn_cast<Instruction>(TrueVal)) - if (Instruction *FI = dyn_cast<Instruction>(FalseVal)) - if (TI->hasOneUse() && FI->hasOneUse()) { - Instruction *AddOp = nullptr, *SubOp = nullptr; - - // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z)) - if (TI->getOpcode() == FI->getOpcode()) - if (Instruction *IV = FoldSelectOpOp(SI, TI, FI)) - return IV; - - // Turn select C, (X+Y), (X-Y) --> (X+(select C, Y, (-Y))). This is - // even legal for FP. - if ((TI->getOpcode() == Instruction::Sub && - FI->getOpcode() == Instruction::Add) || - (TI->getOpcode() == Instruction::FSub && - FI->getOpcode() == Instruction::FAdd)) { - AddOp = FI; SubOp = TI; - } else if ((FI->getOpcode() == Instruction::Sub && - TI->getOpcode() == Instruction::Add) || - (FI->getOpcode() == Instruction::FSub && - TI->getOpcode() == Instruction::FAdd)) { - AddOp = TI; SubOp = FI; - } + if (Instruction *Add = foldAddSubSelect(SI, *Builder)) + return Add; - if (AddOp) { - Value *OtherAddOp = nullptr; - if (SubOp->getOperand(0) == AddOp->getOperand(0)) { - OtherAddOp = AddOp->getOperand(1); - } else if (SubOp->getOperand(0) == AddOp->getOperand(1)) { - OtherAddOp = AddOp->getOperand(0); - } - - if (OtherAddOp) { - // So at this point we know we have (Y -> OtherAddOp): - // select C, (add X, Y), (sub X, Z) - Value *NegVal; // Compute -Z - if (SI.getType()->isFPOrFPVectorTy()) { - NegVal = Builder->CreateFNeg(SubOp->getOperand(1)); - if (Instruction *NegInst = dyn_cast<Instruction>(NegVal)) { - FastMathFlags Flags = AddOp->getFastMathFlags(); - Flags &= SubOp->getFastMathFlags(); - NegInst->setFastMathFlags(Flags); - } - } else { - NegVal = Builder->CreateNeg(SubOp->getOperand(1)); - } - - Value *NewTrueOp = OtherAddOp; - Value *NewFalseOp = NegVal; - if (AddOp != TI) - std::swap(NewTrueOp, NewFalseOp); - Value *NewSel = - Builder->CreateSelect(CondVal, NewTrueOp, - NewFalseOp, SI.getName() + ".p"); - - if (SI.getType()->isFPOrFPVectorTy()) { - Instruction *RI = - BinaryOperator::CreateFAdd(SubOp->getOperand(0), NewSel); - - FastMathFlags Flags = AddOp->getFastMathFlags(); - Flags &= SubOp->getFastMathFlags(); - RI->setFastMathFlags(Flags); - return RI; - } else - return BinaryOperator::CreateAdd(SubOp->getOperand(0), NewSel); - } - } - } + // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z)) + auto *TI = dyn_cast<Instruction>(TrueVal); + auto *FI = dyn_cast<Instruction>(FalseVal); + if (TI && FI && TI->getOpcode() == FI->getOpcode()) + if (Instruction *IV = FoldSelectOpOp(SI, TI, FI)) + return IV; // See if we can fold the select into one of our operands. - if (SI.getType()->isIntOrIntVectorTy() || SI.getType()->isFPOrFPVectorTy()) { + if (SelType->isIntOrIntVectorTy() || SelType->isFPOrFPVectorTy()) { if (Instruction *FoldI = FoldSelectIntoOp(SI, TrueVal, FalseVal)) return FoldI; @@ -1073,7 +1108,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (SelectPatternResult::isMinOrMax(SPF)) { // Canonicalize so that type casts are outside select patterns. if (LHS->getType()->getPrimitiveSizeInBits() != - SI.getType()->getPrimitiveSizeInBits()) { + SelType->getPrimitiveSizeInBits()) { CmpInst::Predicate Pred = getCmpPredicateForMinMax(SPF, SPR.Ordered); Value *Cmp; @@ -1088,8 +1123,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *NewSI = Builder->CreateCast(CastOp, Builder->CreateSelect(Cmp, LHS, RHS), - SI.getType()); - return ReplaceInstUsesWith(SI, NewSI); + SelType); + return replaceInstUsesWith(SI, NewSI); } } @@ -1132,7 +1167,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { : Builder->CreateICmpULT(NewLHS, NewRHS); Value *NewSI = Builder->CreateNot(Builder->CreateSelect(NewCmp, NewLHS, NewRHS)); - return ReplaceInstUsesWith(SI, NewSI); + return replaceInstUsesWith(SI, NewSI); } } } @@ -1195,18 +1230,36 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { return &SI; } - if (VectorType* VecTy = dyn_cast<VectorType>(SI.getType())) { + if (VectorType* VecTy = dyn_cast<VectorType>(SelType)) { unsigned VWidth = VecTy->getNumElements(); APInt UndefElts(VWidth, 0); APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); if (Value *V = SimplifyDemandedVectorElts(&SI, AllOnesEltMask, UndefElts)) { if (V != &SI) - return ReplaceInstUsesWith(SI, V); + return replaceInstUsesWith(SI, V); return &SI; } if (isa<ConstantAggregateZero>(CondVal)) { - return ReplaceInstUsesWith(SI, FalseVal); + return replaceInstUsesWith(SI, FalseVal); + } + } + + // See if we can determine the result of this select based on a dominating + // condition. + BasicBlock *Parent = SI.getParent(); + if (BasicBlock *Dom = Parent->getSinglePredecessor()) { + auto *PBI = dyn_cast_or_null<BranchInst>(Dom->getTerminator()); + if (PBI && PBI->isConditional() && + PBI->getSuccessor(0) != PBI->getSuccessor(1) && + (PBI->getSuccessor(0) == Parent || PBI->getSuccessor(1) == Parent)) { + bool CondIsFalse = PBI->getSuccessor(1) == Parent; + Optional<bool> Implication = isImpliedCondition( + PBI->getCondition(), SI.getCondition(), DL, CondIsFalse); + if (Implication) { + Value *V = *Implication ? TrueVal : FalseVal; + return replaceInstUsesWith(SI, V); + } } } |