aboutsummaryrefslogtreecommitdiff
path: root/lib/Transforms/InstCombine/InstCombineShifts.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineShifts.cpp')
-rw-r--r--lib/Transforms/InstCombine/InstCombineShifts.cpp151
1 files changed, 76 insertions, 75 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 0c7defa5fff8..08e16a7ee1af 100644
--- a/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -55,6 +55,51 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
return nullptr;
}
+/// Return true if we can simplify two logical (either left or right) shifts
+/// that have constant shift amounts.
+static bool canEvaluateShiftedShift(unsigned FirstShiftAmt,
+ bool IsFirstShiftLeft,
+ Instruction *SecondShift, InstCombiner &IC,
+ Instruction *CxtI) {
+ assert(SecondShift->isLogicalShift() && "Unexpected instruction type");
+
+ // We need constant shifts.
+ auto *SecondShiftConst = dyn_cast<ConstantInt>(SecondShift->getOperand(1));
+ if (!SecondShiftConst)
+ return false;
+
+ unsigned SecondShiftAmt = SecondShiftConst->getZExtValue();
+ bool IsSecondShiftLeft = SecondShift->getOpcode() == Instruction::Shl;
+
+ // We can always fold shl(c1) + shl(c2) -> shl(c1+c2).
+ // We can always fold lshr(c1) + lshr(c2) -> lshr(c1+c2).
+ if (IsFirstShiftLeft == IsSecondShiftLeft)
+ return true;
+
+ // We can always fold lshr(c) + shl(c) -> and(c2).
+ // We can always fold shl(c) + lshr(c) -> and(c2).
+ if (FirstShiftAmt == SecondShiftAmt)
+ return true;
+
+ unsigned TypeWidth = SecondShift->getType()->getScalarSizeInBits();
+
+ // If the 2nd shift is bigger than the 1st, we can fold:
+ // lshr(c1) + shl(c2) -> shl(c3) + and(c4) or
+ // shl(c1) + lshr(c2) -> lshr(c3) + and(c4),
+ // but it isn't profitable unless we know the and'd out bits are already zero.
+ // Also check that the 2nd shift is valid (less than the type width) or we'll
+ // crash trying to produce the bit mask for the 'and'.
+ if (SecondShiftAmt > FirstShiftAmt && SecondShiftAmt < TypeWidth) {
+ unsigned MaskShift = IsSecondShiftLeft ? TypeWidth - SecondShiftAmt
+ : SecondShiftAmt - FirstShiftAmt;
+ APInt Mask = APInt::getLowBitsSet(TypeWidth, FirstShiftAmt) << MaskShift;
+ if (IC.MaskedValueIsZero(SecondShift->getOperand(0), Mask, 0, CxtI))
+ return true;
+ }
+
+ return false;
+}
+
/// See if we can compute the specified value, but shifted
/// logically to the left or right by some number of bits. This should return
/// true if the expression can be computed for the same cost as the current
@@ -67,7 +112,7 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
/// where the client will ask if E can be computed shifted right by 64-bits. If
/// this succeeds, the GetShiftedValue function will be called to produce the
/// value.
-static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift,
+static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
InstCombiner &IC, Instruction *CxtI) {
// We can always evaluate constants shifted.
if (isa<Constant>(V))
@@ -81,8 +126,8 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift,
// the value which means that we don't care if the shift has multiple uses.
// TODO: Handle opposite shift by exact value.
ConstantInt *CI = nullptr;
- if ((isLeftShift && match(I, m_LShr(m_Value(), m_ConstantInt(CI)))) ||
- (!isLeftShift && match(I, m_Shl(m_Value(), m_ConstantInt(CI))))) {
+ if ((IsLeftShift && match(I, m_LShr(m_Value(), m_ConstantInt(CI)))) ||
+ (!IsLeftShift && match(I, m_Shl(m_Value(), m_ConstantInt(CI))))) {
if (CI->getZExtValue() == NumBits) {
// TODO: Check that the input bits are already zero with MaskedValueIsZero
#if 0
@@ -111,64 +156,19 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift,
case Instruction::Or:
case Instruction::Xor:
// Bitwise operators can all arbitrarily be arbitrarily evaluated shifted.
- return CanEvaluateShifted(I->getOperand(0), NumBits, isLeftShift, IC, I) &&
- CanEvaluateShifted(I->getOperand(1), NumBits, isLeftShift, IC, I);
-
- case Instruction::Shl: {
- // We can often fold the shift into shifts-by-a-constant.
- CI = dyn_cast<ConstantInt>(I->getOperand(1));
- if (!CI) return false;
-
- // We can always fold shl(c1)+shl(c2) -> shl(c1+c2).
- if (isLeftShift) return true;
+ return CanEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) &&
+ CanEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I);
- // We can always turn shl(c)+shr(c) -> and(c2).
- if (CI->getValue() == NumBits) return true;
+ case Instruction::Shl:
+ case Instruction::LShr:
+ return canEvaluateShiftedShift(NumBits, IsLeftShift, I, IC, CxtI);
- unsigned TypeWidth = I->getType()->getScalarSizeInBits();
-
- // We can turn shl(c1)+shr(c2) -> shl(c3)+and(c4), but it isn't
- // profitable unless we know the and'd out bits are already zero.
- if (CI->getZExtValue() > NumBits) {
- unsigned LowBits = TypeWidth - CI->getZExtValue();
- if (IC.MaskedValueIsZero(I->getOperand(0),
- APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits,
- 0, CxtI))
- return true;
- }
-
- return false;
- }
- case Instruction::LShr: {
- // We can often fold the shift into shifts-by-a-constant.
- CI = dyn_cast<ConstantInt>(I->getOperand(1));
- if (!CI) return false;
-
- // We can always fold lshr(c1)+lshr(c2) -> lshr(c1+c2).
- if (!isLeftShift) return true;
-
- // We can always turn lshr(c)+shl(c) -> and(c2).
- if (CI->getValue() == NumBits) return true;
-
- unsigned TypeWidth = I->getType()->getScalarSizeInBits();
-
- // We can always turn lshr(c1)+shl(c2) -> lshr(c3)+and(c4), but it isn't
- // profitable unless we know the and'd out bits are already zero.
- if (CI->getValue().ult(TypeWidth) && CI->getZExtValue() > NumBits) {
- unsigned LowBits = CI->getZExtValue() - NumBits;
- if (IC.MaskedValueIsZero(I->getOperand(0),
- APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits,
- 0, CxtI))
- return true;
- }
-
- return false;
- }
case Instruction::Select: {
SelectInst *SI = cast<SelectInst>(I);
- return CanEvaluateShifted(SI->getTrueValue(), NumBits, isLeftShift,
- IC, SI) &&
- CanEvaluateShifted(SI->getFalseValue(), NumBits, isLeftShift, IC, SI);
+ Value *TrueVal = SI->getTrueValue();
+ Value *FalseVal = SI->getFalseValue();
+ return CanEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) &&
+ CanEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI);
}
case Instruction::PHI: {
// We can change a phi if we can change all operands. Note that we never
@@ -176,8 +176,7 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift,
// instructions with a single use.
PHINode *PN = cast<PHINode>(I);
for (Value *IncValue : PN->incoming_values())
- if (!CanEvaluateShifted(IncValue, NumBits, isLeftShift,
- IC, PN))
+ if (!CanEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN))
return false;
return true;
}
@@ -257,6 +256,8 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
BO->setHasNoSignedWrap(false);
return BO;
}
+ // FIXME: This is almost identical to the SHL case. Refactor both cases into
+ // a helper function.
case Instruction::LShr: {
BinaryOperator *BO = cast<BinaryOperator>(I);
unsigned TypeWidth = BO->getType()->getScalarSizeInBits();
@@ -340,7 +341,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression"
" to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n");
- return ReplaceInstUsesWith(
+ return replaceInstUsesWith(
I, GetShiftedValue(Op0, COp1->getZExtValue(), isLeftShift, *this, DL));
}
@@ -356,7 +357,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
if (BO->getOpcode() == Instruction::Mul && isLeftShift)
if (Constant *BOOp = dyn_cast<Constant>(BO->getOperand(1)))
return BinaryOperator::CreateMul(BO->getOperand(0),
- ConstantExpr::getShl(BOOp, Op1));
+ ConstantExpr::getShl(BOOp, Op1));
// Try to fold constant and into select arguments.
if (SelectInst *SI = dyn_cast<SelectInst>(Op0))
@@ -573,7 +574,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
// saturates.
if (AmtSum >= TypeBits) {
if (I.getOpcode() != Instruction::AShr)
- return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType()));
+ return replaceInstUsesWith(I, Constant::getNullValue(I.getType()));
AmtSum = TypeBits-1; // Saturate to 31 for i32 ashr.
}
@@ -694,12 +695,12 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
Instruction *InstCombiner::visitShl(BinaryOperator &I) {
if (Value *V = SimplifyVectorOp(I))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
if (Value *V =
SimplifyShlInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(),
I.hasNoUnsignedWrap(), DL, TLI, DT, AC))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
if (Instruction *V = commonShiftTransforms(I))
return V;
@@ -710,11 +711,11 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
// If the shifted-out value is known-zero, then this is a NUW shift.
if (!I.hasNoUnsignedWrap() &&
MaskedValueIsZero(I.getOperand(0),
- APInt::getHighBitsSet(Op1C->getBitWidth(), ShAmt),
- 0, &I)) {
- I.setHasNoUnsignedWrap();
- return &I;
- }
+ APInt::getHighBitsSet(Op1C->getBitWidth(), ShAmt), 0,
+ &I)) {
+ I.setHasNoUnsignedWrap();
+ return &I;
+ }
// If the shifted out value is all signbits, this is a NSW shift.
if (!I.hasNoSignedWrap() &&
@@ -736,11 +737,11 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
if (Value *V = SimplifyVectorOp(I))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
DL, TLI, DT, AC))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
if (Instruction *R = commonShiftTransforms(I))
return R;
@@ -780,11 +781,11 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
if (Value *V = SimplifyVectorOp(I))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
DL, TLI, DT, AC))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
if (Instruction *R = commonShiftTransforms(I))
return R;
@@ -813,8 +814,8 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
// If the shifted-out value is known-zero, then this is an exact shift.
if (!I.isExact() &&
- MaskedValueIsZero(Op0,APInt::getLowBitsSet(Op1C->getBitWidth(),ShAmt),
- 0, &I)){
+ MaskedValueIsZero(Op0, APInt::getLowBitsSet(Op1C->getBitWidth(), ShAmt),
+ 0, &I)) {
I.setIsExact();
return &I;
}