diff options
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineShifts.cpp')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineShifts.cpp | 151 |
1 files changed, 76 insertions, 75 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index 0c7defa5fff8..08e16a7ee1af 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -55,6 +55,51 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { return nullptr; } +/// Return true if we can simplify two logical (either left or right) shifts +/// that have constant shift amounts. +static bool canEvaluateShiftedShift(unsigned FirstShiftAmt, + bool IsFirstShiftLeft, + Instruction *SecondShift, InstCombiner &IC, + Instruction *CxtI) { + assert(SecondShift->isLogicalShift() && "Unexpected instruction type"); + + // We need constant shifts. + auto *SecondShiftConst = dyn_cast<ConstantInt>(SecondShift->getOperand(1)); + if (!SecondShiftConst) + return false; + + unsigned SecondShiftAmt = SecondShiftConst->getZExtValue(); + bool IsSecondShiftLeft = SecondShift->getOpcode() == Instruction::Shl; + + // We can always fold shl(c1) + shl(c2) -> shl(c1+c2). + // We can always fold lshr(c1) + lshr(c2) -> lshr(c1+c2). + if (IsFirstShiftLeft == IsSecondShiftLeft) + return true; + + // We can always fold lshr(c) + shl(c) -> and(c2). + // We can always fold shl(c) + lshr(c) -> and(c2). + if (FirstShiftAmt == SecondShiftAmt) + return true; + + unsigned TypeWidth = SecondShift->getType()->getScalarSizeInBits(); + + // If the 2nd shift is bigger than the 1st, we can fold: + // lshr(c1) + shl(c2) -> shl(c3) + and(c4) or + // shl(c1) + lshr(c2) -> lshr(c3) + and(c4), + // but it isn't profitable unless we know the and'd out bits are already zero. + // Also check that the 2nd shift is valid (less than the type width) or we'll + // crash trying to produce the bit mask for the 'and'. + if (SecondShiftAmt > FirstShiftAmt && SecondShiftAmt < TypeWidth) { + unsigned MaskShift = IsSecondShiftLeft ? TypeWidth - SecondShiftAmt + : SecondShiftAmt - FirstShiftAmt; + APInt Mask = APInt::getLowBitsSet(TypeWidth, FirstShiftAmt) << MaskShift; + if (IC.MaskedValueIsZero(SecondShift->getOperand(0), Mask, 0, CxtI)) + return true; + } + + return false; +} + /// See if we can compute the specified value, but shifted /// logically to the left or right by some number of bits. This should return /// true if the expression can be computed for the same cost as the current @@ -67,7 +112,7 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { /// where the client will ask if E can be computed shifted right by 64-bits. If /// this succeeds, the GetShiftedValue function will be called to produce the /// value. -static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, +static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift, InstCombiner &IC, Instruction *CxtI) { // We can always evaluate constants shifted. if (isa<Constant>(V)) @@ -81,8 +126,8 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, // the value which means that we don't care if the shift has multiple uses. // TODO: Handle opposite shift by exact value. ConstantInt *CI = nullptr; - if ((isLeftShift && match(I, m_LShr(m_Value(), m_ConstantInt(CI)))) || - (!isLeftShift && match(I, m_Shl(m_Value(), m_ConstantInt(CI))))) { + if ((IsLeftShift && match(I, m_LShr(m_Value(), m_ConstantInt(CI)))) || + (!IsLeftShift && match(I, m_Shl(m_Value(), m_ConstantInt(CI))))) { if (CI->getZExtValue() == NumBits) { // TODO: Check that the input bits are already zero with MaskedValueIsZero #if 0 @@ -111,64 +156,19 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, case Instruction::Or: case Instruction::Xor: // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted. - return CanEvaluateShifted(I->getOperand(0), NumBits, isLeftShift, IC, I) && - CanEvaluateShifted(I->getOperand(1), NumBits, isLeftShift, IC, I); - - case Instruction::Shl: { - // We can often fold the shift into shifts-by-a-constant. - CI = dyn_cast<ConstantInt>(I->getOperand(1)); - if (!CI) return false; - - // We can always fold shl(c1)+shl(c2) -> shl(c1+c2). - if (isLeftShift) return true; + return CanEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) && + CanEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I); - // We can always turn shl(c)+shr(c) -> and(c2). - if (CI->getValue() == NumBits) return true; + case Instruction::Shl: + case Instruction::LShr: + return canEvaluateShiftedShift(NumBits, IsLeftShift, I, IC, CxtI); - unsigned TypeWidth = I->getType()->getScalarSizeInBits(); - - // We can turn shl(c1)+shr(c2) -> shl(c3)+and(c4), but it isn't - // profitable unless we know the and'd out bits are already zero. - if (CI->getZExtValue() > NumBits) { - unsigned LowBits = TypeWidth - CI->getZExtValue(); - if (IC.MaskedValueIsZero(I->getOperand(0), - APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits, - 0, CxtI)) - return true; - } - - return false; - } - case Instruction::LShr: { - // We can often fold the shift into shifts-by-a-constant. - CI = dyn_cast<ConstantInt>(I->getOperand(1)); - if (!CI) return false; - - // We can always fold lshr(c1)+lshr(c2) -> lshr(c1+c2). - if (!isLeftShift) return true; - - // We can always turn lshr(c)+shl(c) -> and(c2). - if (CI->getValue() == NumBits) return true; - - unsigned TypeWidth = I->getType()->getScalarSizeInBits(); - - // We can always turn lshr(c1)+shl(c2) -> lshr(c3)+and(c4), but it isn't - // profitable unless we know the and'd out bits are already zero. - if (CI->getValue().ult(TypeWidth) && CI->getZExtValue() > NumBits) { - unsigned LowBits = CI->getZExtValue() - NumBits; - if (IC.MaskedValueIsZero(I->getOperand(0), - APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits, - 0, CxtI)) - return true; - } - - return false; - } case Instruction::Select: { SelectInst *SI = cast<SelectInst>(I); - return CanEvaluateShifted(SI->getTrueValue(), NumBits, isLeftShift, - IC, SI) && - CanEvaluateShifted(SI->getFalseValue(), NumBits, isLeftShift, IC, SI); + Value *TrueVal = SI->getTrueValue(); + Value *FalseVal = SI->getFalseValue(); + return CanEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) && + CanEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI); } case Instruction::PHI: { // We can change a phi if we can change all operands. Note that we never @@ -176,8 +176,7 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, // instructions with a single use. PHINode *PN = cast<PHINode>(I); for (Value *IncValue : PN->incoming_values()) - if (!CanEvaluateShifted(IncValue, NumBits, isLeftShift, - IC, PN)) + if (!CanEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN)) return false; return true; } @@ -257,6 +256,8 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, BO->setHasNoSignedWrap(false); return BO; } + // FIXME: This is almost identical to the SHL case. Refactor both cases into + // a helper function. case Instruction::LShr: { BinaryOperator *BO = cast<BinaryOperator>(I); unsigned TypeWidth = BO->getType()->getScalarSizeInBits(); @@ -340,7 +341,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression" " to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n"); - return ReplaceInstUsesWith( + return replaceInstUsesWith( I, GetShiftedValue(Op0, COp1->getZExtValue(), isLeftShift, *this, DL)); } @@ -356,7 +357,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, if (BO->getOpcode() == Instruction::Mul && isLeftShift) if (Constant *BOOp = dyn_cast<Constant>(BO->getOperand(1))) return BinaryOperator::CreateMul(BO->getOperand(0), - ConstantExpr::getShl(BOOp, Op1)); + ConstantExpr::getShl(BOOp, Op1)); // Try to fold constant and into select arguments. if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) @@ -573,7 +574,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, // saturates. if (AmtSum >= TypeBits) { if (I.getOpcode() != Instruction::AShr) - return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + return replaceInstUsesWith(I, Constant::getNullValue(I.getType())); AmtSum = TypeBits-1; // Saturate to 31 for i32 ashr. } @@ -694,12 +695,12 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, Instruction *InstCombiner::visitShl(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Instruction *V = commonShiftTransforms(I)) return V; @@ -710,11 +711,11 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { // If the shifted-out value is known-zero, then this is a NUW shift. if (!I.hasNoUnsignedWrap() && MaskedValueIsZero(I.getOperand(0), - APInt::getHighBitsSet(Op1C->getBitWidth(), ShAmt), - 0, &I)) { - I.setHasNoUnsignedWrap(); - return &I; - } + APInt::getHighBitsSet(Op1C->getBitWidth(), ShAmt), 0, + &I)) { + I.setHasNoUnsignedWrap(); + return &I; + } // If the shifted out value is all signbits, this is a NSW shift. if (!I.hasNoSignedWrap() && @@ -736,11 +737,11 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { Instruction *InstCombiner::visitLShr(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Instruction *R = commonShiftTransforms(I)) return R; @@ -780,11 +781,11 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { Instruction *InstCombiner::visitAShr(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Instruction *R = commonShiftTransforms(I)) return R; @@ -813,8 +814,8 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { // If the shifted-out value is known-zero, then this is an exact shift. if (!I.isExact() && - MaskedValueIsZero(Op0,APInt::getLowBitsSet(Op1C->getBitWidth(),ShAmt), - 0, &I)){ + MaskedValueIsZero(Op0, APInt::getLowBitsSet(Op1C->getBitWidth(), ShAmt), + 0, &I)) { I.setIsExact(); return &I; } |