Diffstat (limited to 'lib/Transforms/Scalar/NaryReassociate.cpp')
-rw-r--r--   lib/Transforms/Scalar/NaryReassociate.cpp | 199
1 file changed, 115 insertions(+), 84 deletions(-)
diff --git a/lib/Transforms/Scalar/NaryReassociate.cpp b/lib/Transforms/Scalar/NaryReassociate.cpp
index f42f8306fccc..c8f885e7eec5 100644
--- a/lib/Transforms/Scalar/NaryReassociate.cpp
+++ b/lib/Transforms/Scalar/NaryReassociate.cpp
@@ -71,8 +71,8 @@
 //
 // Limitations and TODO items:
 //
-// 1) We only considers n-ary adds for now. This should be extended and
-// generalized.
+// 1) We only consider n-ary adds and muls for now. This should be extended
+// and generalized.
 //
 //===----------------------------------------------------------------------===//
@@ -110,11 +110,11 @@ public:
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addPreserved<DominatorTreeWrapperPass>();
-    AU.addPreserved<ScalarEvolution>();
+    AU.addPreserved<ScalarEvolutionWrapperPass>();
     AU.addPreserved<TargetLibraryInfoWrapperPass>();
     AU.addRequired<AssumptionCacheTracker>();
     AU.addRequired<DominatorTreeWrapperPass>();
-    AU.addRequired<ScalarEvolution>();
+    AU.addRequired<ScalarEvolutionWrapperPass>();
     AU.addRequired<TargetLibraryInfoWrapperPass>();
     AU.addRequired<TargetTransformInfoWrapperPass>();
     AU.setPreservesCFG();
   }
@@ -145,12 +145,23 @@ private:
                                             unsigned I, Value *LHS, Value *RHS,
                                             Type *IndexedType);
 
-  // Reassociate Add for better CSE.
-  Instruction *tryReassociateAdd(BinaryOperator *I);
-  // A helper function for tryReassociateAdd. LHS and RHS are explicitly passed.
-  Instruction *tryReassociateAdd(Value *LHS, Value *RHS, Instruction *I);
-  // Rewrites I to LHS + RHS if LHS is computed already.
-  Instruction *tryReassociatedAdd(const SCEV *LHS, Value *RHS, Instruction *I);
+  // Reassociate binary operators for better CSE.
+  Instruction *tryReassociateBinaryOp(BinaryOperator *I);
+
+  // A helper function for tryReassociateBinaryOp. LHS and RHS are explicitly
+  // passed.
+  Instruction *tryReassociateBinaryOp(Value *LHS, Value *RHS,
+                                      BinaryOperator *I);
+  // Rewrites I to (LHS op RHS) if LHS is computed already.
+  Instruction *tryReassociatedBinaryOp(const SCEV *LHS, Value *RHS,
+                                       BinaryOperator *I);
+
+  // Tries to match V as (Op1 op Op2), where op is I's opcode.
+  bool matchTernaryOp(BinaryOperator *I, Value *V, Value *&Op1, Value *&Op2);
+
+  // Gets SCEV for (LHS op RHS).
+  const SCEV *getBinarySCEV(BinaryOperator *I, const SCEV *LHS,
+                            const SCEV *RHS);
 
   // Returns the closest dominator of \c Dominatee that computes
   // \c CandidateExpr. Returns null if not found.
@@ -161,11 +172,6 @@ private:
   // GEP's pointer size, i.e., whether Index needs to be sign-extended in order
   // to be an index of GEP.
   bool requiresSignExtension(Value *Index, GetElementPtrInst *GEP);
-  // Returns whether V is known to be non-negative at context \c Ctxt.
-  bool isKnownNonNegative(Value *V, Instruction *Ctxt);
-  // Returns whether AO may sign overflow at context \c Ctxt. It computes a
-  // conservative result -- it answers true when not sure.
-  bool maySignOverflow(AddOperator *AO, Instruction *Ctxt);
 
   AssumptionCache *AC;
   const DataLayout *DL;
@@ -182,7 +188,7 @@ private:
   //     foo(a + b);
   //   if (p2)
   //     bar(a + b);
-  DenseMap<const SCEV *, SmallVector<Instruction *, 2>> SeenExprs;
+  DenseMap<const SCEV *, SmallVector<WeakVH, 2>> SeenExprs;
 };
 } // anonymous namespace
@@ -191,7 +197,7 @@ INITIALIZE_PASS_BEGIN(NaryReassociate, "nary-reassociate", "Nary reassociation",
                       false, false)
 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
+INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
 INITIALIZE_PASS_END(NaryReassociate, "nary-reassociate", "Nary reassociation",
@@ -207,7 +213,7 @@ bool NaryReassociate::runOnFunction(Function &F) {
 
   AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
   DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
-  SE = &getAnalysis<ScalarEvolution>();
+  SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
   TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
   TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
@@ -224,6 +230,7 @@ static bool isPotentiallyNaryReassociable(Instruction *I) {
   switch (I->getOpcode()) {
   case Instruction::Add:
   case Instruction::GetElementPtr:
+  case Instruction::Mul:
     return true;
   default:
     return false;
@@ -239,19 +246,21 @@ bool NaryReassociate::doOneIteration(Function &F) {
        Node != GraphTraits<DominatorTree *>::nodes_end(DT); ++Node) {
     BasicBlock *BB = Node->getBlock();
     for (auto I = BB->begin(); I != BB->end(); ++I) {
-      if (SE->isSCEVable(I->getType()) && isPotentiallyNaryReassociable(I)) {
-        const SCEV *OldSCEV = SE->getSCEV(I);
-        if (Instruction *NewI = tryReassociate(I)) {
+      if (SE->isSCEVable(I->getType()) && isPotentiallyNaryReassociable(&*I)) {
+        const SCEV *OldSCEV = SE->getSCEV(&*I);
+        if (Instruction *NewI = tryReassociate(&*I)) {
           Changed = true;
-          SE->forgetValue(I);
+          SE->forgetValue(&*I);
           I->replaceAllUsesWith(NewI);
-          RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
-          I = NewI;
+          // If SeenExprs contains I's WeakVH, that entry will be replaced with
+          // nullptr.
+          RecursivelyDeleteTriviallyDeadInstructions(&*I, TLI);
+          I = NewI->getIterator();
         }
         // Add the rewritten instruction to SeenExprs; the original instruction
         // is deleted.
-        const SCEV *NewSCEV = SE->getSCEV(I);
-        SeenExprs[NewSCEV].push_back(I);
+        const SCEV *NewSCEV = SE->getSCEV(&*I);
+        SeenExprs[NewSCEV].push_back(WeakVH(&*I));
         // Ideally, NewSCEV should equal OldSCEV because tryReassociate(I)
         // is equivalent to I. However, ScalarEvolution::getSCEV may
         // weaken nsw causing NewSCEV not to equal OldSCEV. For example, suppose
@@ -271,7 +280,7 @@ bool NaryReassociate::doOneIteration(Function &F) {
         //
         // This improvement is exercised in @reassociate_gep_nsw in nary-gep.ll.
         if (NewSCEV != OldSCEV)
-          SeenExprs[OldSCEV].push_back(I);
+          SeenExprs[OldSCEV].push_back(WeakVH(&*I));
       }
     }
   }
@@ -281,7 +290,8 @@ Instruction *NaryReassociate::tryReassociate(Instruction *I) {
   switch (I->getOpcode()) {
   case Instruction::Add:
-    return tryReassociateAdd(cast<BinaryOperator>(I));
+  case Instruction::Mul:
+    return tryReassociateBinaryOp(cast<BinaryOperator>(I));
   case Instruction::GetElementPtr:
     return tryReassociateGEP(cast<GetElementPtrInst>(I));
   default:
@@ -352,27 +362,6 @@ bool NaryReassociate::requiresSignExtension(Value *Index,
   return cast<IntegerType>(Index->getType())->getBitWidth() < PointerSizeInBits;
 }
 
-bool NaryReassociate::isKnownNonNegative(Value *V, Instruction *Ctxt) {
-  bool NonNegative, Negative;
-  // TODO: ComputeSignBits is expensive. Consider caching the results.
-  ComputeSignBit(V, NonNegative, Negative, *DL, 0, AC, Ctxt, DT);
-  return NonNegative;
-}
-
-bool NaryReassociate::maySignOverflow(AddOperator *AO, Instruction *Ctxt) {
-  if (AO->hasNoSignedWrap())
-    return false;
-
-  Value *LHS = AO->getOperand(0), *RHS = AO->getOperand(1);
-  // If LHS or RHS has the same sign as the sum, AO doesn't sign overflow.
-  // TODO: handle the negative case as well.
-  if (isKnownNonNegative(AO, Ctxt) &&
-      (isKnownNonNegative(LHS, Ctxt) || isKnownNonNegative(RHS, Ctxt)))
-    return false;
-
-  return true;
-}
-
 GetElementPtrInst *
 NaryReassociate::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, unsigned I,
                                           Type *IndexedType) {
@@ -381,7 +370,7 @@ NaryReassociate::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, unsigned I,
     IndexToSplit = SExt->getOperand(0);
   } else if (ZExtInst *ZExt = dyn_cast<ZExtInst>(IndexToSplit)) {
     // zext can be treated as sext if the source is non-negative.
-    if (isKnownNonNegative(ZExt->getOperand(0), GEP))
+    if (isKnownNonNegative(ZExt->getOperand(0), *DL, 0, AC, GEP, DT))
       IndexToSplit = ZExt->getOperand(0);
   }
@@ -389,8 +378,11 @@ NaryReassociate::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, unsigned I,
     // If the I-th index needs sext and the underlying add is not equipped with
     // nsw, we cannot split the add because
     // sext(LHS + RHS) != sext(LHS) + sext(RHS).
-    if (requiresSignExtension(IndexToSplit, GEP) && maySignOverflow(AO, GEP))
+    if (requiresSignExtension(IndexToSplit, GEP) &&
+        computeOverflowForSignedAdd(AO, *DL, AC, GEP, DT) !=
+            OverflowResult::NeverOverflows)
       return nullptr;
+
     Value *LHS = AO->getOperand(0), *RHS = AO->getOperand(1);
     // IndexToSplit = LHS + RHS.
     if (auto *NewGEP = tryReassociateGEPAtIndex(GEP, I, LHS, RHS, IndexedType))
@@ -415,7 +407,7 @@ GetElementPtrInst *NaryReassociate::tryReassociateGEPAtIndex(
     IndexExprs.push_back(SE->getSCEV(*Index));
   // Replace the I-th index with LHS.
   IndexExprs[I] = SE->getSCEV(LHS);
-  if (isKnownNonNegative(LHS, GEP) &&
+  if (isKnownNonNegative(LHS, *DL, 0, AC, GEP, DT) &&
       DL->getTypeSizeInBits(LHS->getType()) <
           DL->getTypeSizeInBits(GEP->getOperand(I)->getType())) {
     // Zero-extend LHS if it is non-negative. InstCombine canonicalizes sext to
@@ -429,19 +421,20 @@ GetElementPtrInst *NaryReassociate::tryReassociateGEPAtIndex(
       GEP->getSourceElementType(), SE->getSCEV(GEP->getPointerOperand()),
       IndexExprs, GEP->isInBounds());
 
-  auto *Candidate = findClosestMatchingDominator(CandidateExpr, GEP);
+  Value *Candidate = findClosestMatchingDominator(CandidateExpr, GEP);
   if (Candidate == nullptr)
     return nullptr;
 
-  PointerType *TypeOfCandidate = dyn_cast<PointerType>(Candidate->getType());
-  // Pretty rare but theoretically possible when a numeric value happens to
-  // share CandidateExpr.
-  if (TypeOfCandidate == nullptr)
-    return nullptr;
+  IRBuilder<> Builder(GEP);
+  // Candidate does not necessarily have the same pointer type as GEP. Use
+  // bitcast or pointer cast to make sure they have the same type, so that the
+  // later RAUW doesn't complain.
+  Candidate = Builder.CreateBitOrPointerCast(Candidate, GEP->getType());
+  assert(Candidate->getType() == GEP->getType());
 
   // NewGEP = (char *)Candidate + RHS * sizeof(IndexedType)
   uint64_t IndexedSize = DL->getTypeAllocSize(IndexedType);
-  Type *ElementType = TypeOfCandidate->getElementType();
+  Type *ElementType = GEP->getType()->getElementType();
   uint64_t ElementSize = DL->getTypeAllocSize(ElementType);
   // Another less rare case: because I is not necessarily the last index of the
   // GEP, the size of the type at the I-th index (IndexedSize) is not
@@ -461,8 +454,7 @@ GetElementPtrInst *NaryReassociate::tryReassociateGEPAtIndex(
     return nullptr;
 
   // NewGEP = &Candidate[RHS * (sizeof(IndexedType) / sizeof(Candidate[0])));
-  IRBuilder<> Builder(GEP);
-  Type *IntPtrTy = DL->getIntPtrType(TypeOfCandidate);
+  Type *IntPtrTy = DL->getIntPtrType(GEP->getType());
   if (RHS->getType() != IntPtrTy)
     RHS = Builder.CreateSExtOrTrunc(RHS, IntPtrTy);
   if (IndexedSize != ElementSize) {
@@ -476,54 +468,89 @@ GetElementPtrInst *NaryReassociate::tryReassociateGEPAtIndex(
   return NewGEP;
 }
 
-Instruction *NaryReassociate::tryReassociateAdd(BinaryOperator *I) {
+Instruction *NaryReassociate::tryReassociateBinaryOp(BinaryOperator *I) {
   Value *LHS = I->getOperand(0), *RHS = I->getOperand(1);
-  if (auto *NewI = tryReassociateAdd(LHS, RHS, I))
+  if (auto *NewI = tryReassociateBinaryOp(LHS, RHS, I))
     return NewI;
-  if (auto *NewI = tryReassociateAdd(RHS, LHS, I))
+  if (auto *NewI = tryReassociateBinaryOp(RHS, LHS, I))
     return NewI;
   return nullptr;
 }
 
-Instruction *NaryReassociate::tryReassociateAdd(Value *LHS, Value *RHS,
-                                                Instruction *I) {
+Instruction *NaryReassociate::tryReassociateBinaryOp(Value *LHS, Value *RHS,
+                                                     BinaryOperator *I) {
   Value *A = nullptr, *B = nullptr;
-  // To be conservative, we reassociate I only when it is the only user of A+B.
-  if (LHS->hasOneUse() && match(LHS, m_Add(m_Value(A), m_Value(B)))) {
-    // I = (A + B) + RHS
-    //   = (A + RHS) + B or (B + RHS) + A
+  // To be conservative, we reassociate I only when it is the only user of
+  // (A op B).
+  if (LHS->hasOneUse() && matchTernaryOp(I, LHS, A, B)) {
+    // I = (A op B) op RHS
+    //   = (A op RHS) op B or (B op RHS) op A
     const SCEV *AExpr = SE->getSCEV(A), *BExpr = SE->getSCEV(B);
     const SCEV *RHSExpr = SE->getSCEV(RHS);
     if (BExpr != RHSExpr) {
-      if (auto *NewI = tryReassociatedAdd(SE->getAddExpr(AExpr, RHSExpr), B, I))
+      if (auto *NewI =
+              tryReassociatedBinaryOp(getBinarySCEV(I, AExpr, RHSExpr), B, I))
         return NewI;
     }
     if (AExpr != RHSExpr) {
-      if (auto *NewI = tryReassociatedAdd(SE->getAddExpr(BExpr, RHSExpr), A, I))
+      if (auto *NewI =
+              tryReassociatedBinaryOp(getBinarySCEV(I, BExpr, RHSExpr), A, I))
        return NewI;
     }
   }
   return nullptr;
 }
 
-Instruction *NaryReassociate::tryReassociatedAdd(const SCEV *LHSExpr,
-                                                 Value *RHS, Instruction *I) {
-  auto Pos = SeenExprs.find(LHSExpr);
-  // Bail out if LHSExpr is not previously seen.
-  if (Pos == SeenExprs.end())
-    return nullptr;
-
+Instruction *NaryReassociate::tryReassociatedBinaryOp(const SCEV *LHSExpr,
+                                                      Value *RHS,
+                                                      BinaryOperator *I) {
   // Look for the closest dominator LHS of I that computes LHSExpr, and replace
-  // I with LHS + RHS.
+  // I with LHS op RHS.
   auto *LHS = findClosestMatchingDominator(LHSExpr, I);
   if (LHS == nullptr)
     return nullptr;
 
-  Instruction *NewI = BinaryOperator::CreateAdd(LHS, RHS, "", I);
+  Instruction *NewI = nullptr;
+  switch (I->getOpcode()) {
+  case Instruction::Add:
+    NewI = BinaryOperator::CreateAdd(LHS, RHS, "", I);
+    break;
+  case Instruction::Mul:
+    NewI = BinaryOperator::CreateMul(LHS, RHS, "", I);
+    break;
+  default:
+    llvm_unreachable("Unexpected instruction.");
+  }
   NewI->takeName(I);
   return NewI;
 }
 
+bool NaryReassociate::matchTernaryOp(BinaryOperator *I, Value *V, Value *&Op1,
+                                     Value *&Op2) {
+  switch (I->getOpcode()) {
+  case Instruction::Add:
+    return match(V, m_Add(m_Value(Op1), m_Value(Op2)));
+  case Instruction::Mul:
+    return match(V, m_Mul(m_Value(Op1), m_Value(Op2)));
+  default:
+    llvm_unreachable("Unexpected instruction.");
+  }
+  return false;
+}
+
+const SCEV *NaryReassociate::getBinarySCEV(BinaryOperator *I, const SCEV *LHS,
+                                           const SCEV *RHS) {
+  switch (I->getOpcode()) {
+  case Instruction::Add:
+    return SE->getAddExpr(LHS, RHS);
+  case Instruction::Mul:
+    return SE->getMulExpr(LHS, RHS);
+  default:
+    llvm_unreachable("Unexpected instruction.");
+  }
+  return nullptr;
+}
+
 Instruction *
 NaryReassociate::findClosestMatchingDominator(const SCEV *CandidateExpr,
                                               Instruction *Dominatee) {
@@ -537,9 +564,13 @@ NaryReassociate::findClosestMatchingDominator(const SCEV *CandidateExpr,
   // future instruction either. Therefore, we pop it out of the stack. This
   // optimization makes the algorithm O(n).
   while (!Candidates.empty()) {
-    Instruction *Candidate = Candidates.back();
-    if (DT->dominates(Candidate, Dominatee))
-      return Candidate;
+    // Candidates stores WeakVHs, so a candidate can be nullptr if it's removed
+    // during rewriting.
+    if (Value *Candidate = Candidates.back()) {
+      Instruction *CandidateInstruction = cast<Instruction>(Candidate);
+      if (DT->dominates(CandidateInstruction, Dominatee))
+        return CandidateInstruction;
+    }
     Candidates.pop_back();
   }
   return nullptr;
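
The rewrite this patch generalizes can be summarized with a short sketch. The C++ below is illustrative only (plain arithmetic standing in for LLVM IR; the function and variable names are invented for this example, not taken from the patch): given a dominating computation of a op c, the pass rewrites (a op b) op c into (a op c) op b so that common subexpression elimination can reuse the earlier value. Before this patch only add was handled; now mul is reassociated the same way.

// Illustrative sketch only; NaryReassociate operates on LLVM IR, not C++ source.
int reuse_add(int a, int b, int c) {
  int t = a + c;        // dominating computation of (a + c), recorded in SeenExprs
  int i = (a + b) + c;  // I = (A + B) + RHS; rewritten to (A + RHS) + B,
                        // i.e. t + b, so t is reused instead of recomputed
  return t + i;
}

int reuse_mul(int a, int b, int c) {
  int t = a * c;        // with this patch, mul candidates are recorded too
  int i = (a * b) * c;  // rewritten to t * b by the same logic
  return t * i;
}

The switch from Instruction * to WeakVH in SeenExprs supports this safely: when RecursivelyDeleteTriviallyDeadInstructions removes a recorded candidate, its map entry becomes null, and findClosestMatchingDominator now skips such entries instead of dereferencing a dangling pointer.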