diff options
Diffstat (limited to 'llvm/lib/IR/ConstantFold.cpp')
-rw-r--r-- | llvm/lib/IR/ConstantFold.cpp | 27 |
1 files changed, 21 insertions, 6 deletions
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp index 41b4f2919221..98adff107cec 100644 --- a/llvm/lib/IR/ConstantFold.cpp +++ b/llvm/lib/IR/ConstantFold.cpp @@ -1218,9 +1218,13 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1, if (Instruction::isIntDivRem(Opcode) && C2Splat->isNullValue()) return PoisonValue::get(VTy); if (Constant *C1Splat = C1->getSplatValue()) { - return ConstantVector::getSplat( - VTy->getElementCount(), - ConstantExpr::get(Opcode, C1Splat, C2Splat)); + Constant *Res = + ConstantExpr::isDesirableBinOp(Opcode) + ? ConstantExpr::get(Opcode, C1Splat, C2Splat) + : ConstantFoldBinaryInstruction(Opcode, C1Splat, C2Splat); + if (!Res) + return nullptr; + return ConstantVector::getSplat(VTy->getElementCount(), Res); } } @@ -1237,7 +1241,12 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1, if (Instruction::isIntDivRem(Opcode) && RHS->isNullValue()) return PoisonValue::get(VTy); - Result.push_back(ConstantExpr::get(Opcode, LHS, RHS)); + Constant *Res = ConstantExpr::isDesirableBinOp(Opcode) + ? ConstantExpr::get(Opcode, LHS, RHS) + : ConstantFoldBinaryInstruction(Opcode, LHS, RHS); + if (!Res) + return nullptr; + Result.push_back(Res); } return ConstantVector::get(Result); @@ -2218,9 +2227,15 @@ Constant *llvm::ConstantFoldGetElementPtr(Type *PointeeTy, Constant *C, : cast<FixedVectorType>(CurrIdx->getType())->getNumElements(), Factor); - NewIdxs[i] = ConstantExpr::getSRem(CurrIdx, Factor); + NewIdxs[i] = + ConstantFoldBinaryInstruction(Instruction::SRem, CurrIdx, Factor); + + Constant *Div = + ConstantFoldBinaryInstruction(Instruction::SDiv, CurrIdx, Factor); - Constant *Div = ConstantExpr::getSDiv(CurrIdx, Factor); + // We're working on either ConstantInt or vectors of ConstantInt, + // so these should always fold. + assert(NewIdxs[i] != nullptr && Div != nullptr && "Should have folded"); unsigned CommonExtendedWidth = std::max(PrevIdx->getType()->getScalarSizeInBits(), |