diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LoopFlatten.cpp')
-rw-r--r-- | llvm/lib/Transforms/Scalar/LoopFlatten.cpp | 245 |
1 files changed, 129 insertions, 116 deletions
diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp index aaff68436c13..f54289f85ef5 100644 --- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp @@ -63,7 +63,7 @@ static cl::opt<bool> AssumeNoOverflow("loop-flatten-assume-no-overflow", cl::Hidden, cl::init(false), cl::desc("Assume that the product of the two iteration " - "limits will never overflow")); + "trip counts will never overflow")); static cl::opt<bool> WidenIV("loop-flatten-widen-iv", cl::Hidden, @@ -74,10 +74,12 @@ static cl::opt<bool> struct FlattenInfo { Loop *OuterLoop = nullptr; Loop *InnerLoop = nullptr; + // These PHINodes correspond to loop induction variables, which are expected + // to start at zero and increment by one on each loop. PHINode *InnerInductionPHI = nullptr; PHINode *OuterInductionPHI = nullptr; - Value *InnerLimit = nullptr; - Value *OuterLimit = nullptr; + Value *InnerTripCount = nullptr; + Value *OuterTripCount = nullptr; BinaryOperator *InnerIncrement = nullptr; BinaryOperator *OuterIncrement = nullptr; BranchInst *InnerBranch = nullptr; @@ -91,12 +93,12 @@ struct FlattenInfo { FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {}; }; -// Finds the induction variable, increment and limit for a simple loop that we -// can flatten. +// Finds the induction variable, increment and trip count for a simple loop that +// we can flatten. static bool findLoopComponents( Loop *L, SmallPtrSetImpl<Instruction *> &IterationInstructions, - PHINode *&InductionPHI, Value *&Limit, BinaryOperator *&Increment, - BranchInst *&BackBranch, ScalarEvolution *SE) { + PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment, + BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) { LLVM_DEBUG(dbgs() << "Finding components of loop: " << L->getName() << "\n"); if (!L->isLoopSimplifyForm()) { @@ -104,6 +106,13 @@ static bool findLoopComponents( return false; } + // Currently, to simplify the implementation, the Loop induction variable must + // start at zero and increment with a step size of one. + if (!L->isCanonical(*SE)) { + LLVM_DEBUG(dbgs() << "Loop is not canonical\n"); + return false; + } + // There must be exactly one exiting block, and it must be the same at the // latch. BasicBlock *Latch = L->getLoopLatch(); @@ -111,33 +120,18 @@ static bool findLoopComponents( LLVM_DEBUG(dbgs() << "Exiting and latch block are different\n"); return false; } - // Latch block must end in a conditional branch. - BackBranch = dyn_cast<BranchInst>(Latch->getTerminator()); - if (!BackBranch || !BackBranch->isConditional()) { - LLVM_DEBUG(dbgs() << "Could not find back-branch\n"); - return false; - } - IterationInstructions.insert(BackBranch); - LLVM_DEBUG(dbgs() << "Found back branch: "; BackBranch->dump()); - bool ContinueOnTrue = L->contains(BackBranch->getSuccessor(0)); // Find the induction PHI. If there is no induction PHI, we can't do the // transformation. TODO: could other variables trigger this? Do we have to // search for the best one? - InductionPHI = nullptr; - for (PHINode &PHI : L->getHeader()->phis()) { - InductionDescriptor ID; - if (InductionDescriptor::isInductionPHI(&PHI, L, SE, ID)) { - InductionPHI = &PHI; - LLVM_DEBUG(dbgs() << "Found induction PHI: "; InductionPHI->dump()); - break; - } - } + InductionPHI = L->getInductionVariable(*SE); if (!InductionPHI) { LLVM_DEBUG(dbgs() << "Could not find induction PHI\n"); return false; } + LLVM_DEBUG(dbgs() << "Found induction PHI: "; InductionPHI->dump()); + bool ContinueOnTrue = L->contains(Latch->getTerminator()->getSuccessor(0)); auto IsValidPredicate = [&](ICmpInst::Predicate Pred) { if (ContinueOnTrue) return Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_ULT; @@ -145,53 +139,64 @@ static bool findLoopComponents( return Pred == CmpInst::ICMP_EQ; }; - // Find Compare and make sure it is valid - ICmpInst *Compare = dyn_cast<ICmpInst>(BackBranch->getCondition()); + // Find Compare and make sure it is valid. getLatchCmpInst checks that the + // back branch of the latch is conditional. + ICmpInst *Compare = L->getLatchCmpInst(); if (!Compare || !IsValidPredicate(Compare->getUnsignedPredicate()) || Compare->hasNUsesOrMore(2)) { LLVM_DEBUG(dbgs() << "Could not find valid comparison\n"); return false; } + BackBranch = cast<BranchInst>(Latch->getTerminator()); + IterationInstructions.insert(BackBranch); + LLVM_DEBUG(dbgs() << "Found back branch: "; BackBranch->dump()); IterationInstructions.insert(Compare); LLVM_DEBUG(dbgs() << "Found comparison: "; Compare->dump()); - // Find increment and limit from the compare - Increment = nullptr; - if (match(Compare->getOperand(0), - m_c_Add(m_Specific(InductionPHI), m_ConstantInt<1>()))) { - Increment = dyn_cast<BinaryOperator>(Compare->getOperand(0)); - Limit = Compare->getOperand(1); - } else if (Compare->getUnsignedPredicate() == CmpInst::ICMP_NE && - match(Compare->getOperand(1), - m_c_Add(m_Specific(InductionPHI), m_ConstantInt<1>()))) { - Increment = dyn_cast<BinaryOperator>(Compare->getOperand(1)); - Limit = Compare->getOperand(0); - } - if (!Increment || Increment->hasNUsesOrMore(3)) { - LLVM_DEBUG(dbgs() << "Cound not find valid increment\n"); + // Find increment and trip count. + // There are exactly 2 incoming values to the induction phi; one from the + // pre-header and one from the latch. The incoming latch value is the + // increment variable. + Increment = + dyn_cast<BinaryOperator>(InductionPHI->getIncomingValueForBlock(Latch)); + if (Increment->hasNUsesOrMore(3)) { + LLVM_DEBUG(dbgs() << "Could not find valid increment\n"); return false; } + // The trip count is the RHS of the compare. If this doesn't match the trip + // count computed by SCEV then this is either because the trip count variable + // has been widened (then leave the trip count as it is), or because it is a + // constant and another transformation has changed the compare, e.g. + // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, then we don't flatten + // the loop (yet). + TripCount = Compare->getOperand(1); + const SCEV *SCEVTripCount = + SE->getTripCountFromExitCount(SE->getBackedgeTakenCount(L)); + if (SE->getSCEV(TripCount) != SCEVTripCount) { + if (!IsWidened) { + LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); + return false; + } + auto TripCountInst = dyn_cast<Instruction>(TripCount); + if (!TripCountInst) { + LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n"); + return false; + } + if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) || + SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) { + LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n"); + return false; + } + } IterationInstructions.insert(Increment); LLVM_DEBUG(dbgs() << "Found increment: "; Increment->dump()); - LLVM_DEBUG(dbgs() << "Found limit: "; Limit->dump()); - - assert(InductionPHI->getNumIncomingValues() == 2); - assert(InductionPHI->getIncomingValueForBlock(Latch) == Increment && - "PHI value is not increment inst"); - - auto *CI = dyn_cast<ConstantInt>( - InductionPHI->getIncomingValueForBlock(L->getLoopPreheader())); - if (!CI || !CI->isZero()) { - LLVM_DEBUG(dbgs() << "PHI value is not zero: "; CI->dump()); - return false; - } + LLVM_DEBUG(dbgs() << "Found trip count: "; TripCount->dump()); LLVM_DEBUG(dbgs() << "Successfully found all loop components\n"); return true; } -static bool checkPHIs(struct FlattenInfo &FI, - const TargetTransformInfo *TTI) { +static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) { // All PHIs in the inner and outer headers must either be: // - The induction PHI, which we are going to rewrite as one induction in // the new loop. This is already checked by findLoopComponents. @@ -272,7 +277,7 @@ static bool checkPHIs(struct FlattenInfo &FI, } static bool -checkOuterLoopInsts(struct FlattenInfo &FI, +checkOuterLoopInsts(FlattenInfo &FI, SmallPtrSetImpl<Instruction *> &IterationInstructions, const TargetTransformInfo *TTI) { // Check for instructions in the outer but not inner loop. If any of these @@ -280,7 +285,7 @@ checkOuterLoopInsts(struct FlattenInfo &FI, // a significant amount of code here which can't be optimised out that it's // not profitable (as these instructions would get executed for each // iteration of the inner loop). - unsigned RepeatedInstrCost = 0; + InstructionCost RepeatedInstrCost = 0; for (auto *B : FI.OuterLoop->getBlocks()) { if (FI.InnerLoop->contains(B)) continue; @@ -308,9 +313,10 @@ checkOuterLoopInsts(struct FlattenInfo &FI, // Multiplies of the outer iteration variable and inner iteration // count will be optimised out. if (match(&I, m_c_Mul(m_Specific(FI.OuterInductionPHI), - m_Specific(FI.InnerLimit)))) + m_Specific(FI.InnerTripCount)))) continue; - int Cost = TTI->getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency); + InstructionCost Cost = + TTI->getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency); LLVM_DEBUG(dbgs() << "Cost " << Cost << ": "; I.dump()); RepeatedInstrCost += Cost; } @@ -329,19 +335,19 @@ checkOuterLoopInsts(struct FlattenInfo &FI, return true; } -static bool checkIVUsers(struct FlattenInfo &FI) { +static bool checkIVUsers(FlattenInfo &FI) { // We require all uses of both induction variables to match this pattern: // - // (OuterPHI * InnerLimit) + InnerPHI + // (OuterPHI * InnerTripCount) + InnerPHI // // Any uses of the induction variables not matching that pattern would // require a div/mod to reconstruct in the flattened loop, so the // transformation wouldn't be profitable. - Value *InnerLimit = FI.InnerLimit; + Value *InnerTripCount = FI.InnerTripCount; if (FI.Widened && - (isa<SExtInst>(InnerLimit) || isa<ZExtInst>(InnerLimit))) - InnerLimit = cast<Instruction>(InnerLimit)->getOperand(0); + (isa<SExtInst>(InnerTripCount) || isa<ZExtInst>(InnerTripCount))) + InnerTripCount = cast<Instruction>(InnerTripCount)->getOperand(0); // Check that all uses of the inner loop's induction variable match the // expected pattern, recording the uses of the outer IV. @@ -375,7 +381,7 @@ static bool checkIVUsers(struct FlattenInfo &FI) { m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)), m_Value(MatchedItCount))); - if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerLimit) { + if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) { LLVM_DEBUG(dbgs() << "Use is optimisable\n"); ValidOuterPHIUses.insert(MatchedMul); FI.LinearIVUses.insert(U); @@ -424,9 +430,9 @@ static bool checkIVUsers(struct FlattenInfo &FI) { } // Return an OverflowResult dependant on if overflow of the multiplication of -// InnerLimit and OuterLimit can be assumed not to happen. -static OverflowResult checkOverflow(struct FlattenInfo &FI, - DominatorTree *DT, AssumptionCache *AC) { +// InnerTripCount and OuterTripCount can be assumed not to happen. +static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT, + AssumptionCache *AC) { Function *F = FI.OuterLoop->getHeader()->getParent(); const DataLayout &DL = F->getParent()->getDataLayout(); @@ -437,7 +443,7 @@ static OverflowResult checkOverflow(struct FlattenInfo &FI, // Check if the multiply could not overflow due to known ranges of the // input values. OverflowResult OR = computeOverflowForUnsignedMul( - FI.InnerLimit, FI.OuterLimit, DL, AC, + FI.InnerTripCount, FI.OuterTripCount, DL, AC, FI.OuterLoop->getLoopPreheader()->getTerminator(), DT); if (OR != OverflowResult::MayOverflow) return OR; @@ -464,25 +470,27 @@ static OverflowResult checkOverflow(struct FlattenInfo &FI, return OverflowResult::MayOverflow; } -static bool CanFlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT, - LoopInfo *LI, ScalarEvolution *SE, - AssumptionCache *AC, const TargetTransformInfo *TTI) { +static bool CanFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, + ScalarEvolution *SE, AssumptionCache *AC, + const TargetTransformInfo *TTI) { SmallPtrSet<Instruction *, 8> IterationInstructions; - if (!findLoopComponents(FI.InnerLoop, IterationInstructions, FI.InnerInductionPHI, - FI.InnerLimit, FI.InnerIncrement, FI.InnerBranch, SE)) + if (!findLoopComponents(FI.InnerLoop, IterationInstructions, + FI.InnerInductionPHI, FI.InnerTripCount, + FI.InnerIncrement, FI.InnerBranch, SE, FI.Widened)) return false; - if (!findLoopComponents(FI.OuterLoop, IterationInstructions, FI.OuterInductionPHI, - FI.OuterLimit, FI.OuterIncrement, FI.OuterBranch, SE)) + if (!findLoopComponents(FI.OuterLoop, IterationInstructions, + FI.OuterInductionPHI, FI.OuterTripCount, + FI.OuterIncrement, FI.OuterBranch, SE, FI.Widened)) return false; - // Both of the loop limit values must be invariant in the outer loop + // Both of the loop trip count values must be invariant in the outer loop // (non-instructions are all inherently invariant). - if (!FI.OuterLoop->isLoopInvariant(FI.InnerLimit)) { - LLVM_DEBUG(dbgs() << "inner loop limit not invariant\n"); + if (!FI.OuterLoop->isLoopInvariant(FI.InnerTripCount)) { + LLVM_DEBUG(dbgs() << "inner loop trip count not invariant\n"); return false; } - if (!FI.OuterLoop->isLoopInvariant(FI.OuterLimit)) { - LLVM_DEBUG(dbgs() << "outer loop limit not invariant\n"); + if (!FI.OuterLoop->isLoopInvariant(FI.OuterTripCount)) { + LLVM_DEBUG(dbgs() << "outer loop trip count not invariant\n"); return false; } @@ -508,9 +516,8 @@ static bool CanFlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT, return true; } -static bool DoFlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT, - LoopInfo *LI, ScalarEvolution *SE, - AssumptionCache *AC, +static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, + ScalarEvolution *SE, AssumptionCache *AC, const TargetTransformInfo *TTI) { Function *F = FI.OuterLoop->getHeader()->getParent(); LLVM_DEBUG(dbgs() << "Checks all passed, doing the transformation\n"); @@ -523,9 +530,9 @@ static bool DoFlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT, ORE.emit(Remark); } - Value *NewTripCount = - BinaryOperator::CreateMul(FI.InnerLimit, FI.OuterLimit, "flatten.tripcount", - FI.OuterLoop->getLoopPreheader()->getTerminator()); + Value *NewTripCount = BinaryOperator::CreateMul( + FI.InnerTripCount, FI.OuterTripCount, "flatten.tripcount", + FI.OuterLoop->getLoopPreheader()->getTerminator()); LLVM_DEBUG(dbgs() << "Created new trip count in preheader: "; NewTripCount->dump()); @@ -571,9 +578,9 @@ static bool DoFlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT, return true; } -static bool CanWidenIV(struct FlattenInfo &FI, DominatorTree *DT, - LoopInfo *LI, ScalarEvolution *SE, - AssumptionCache *AC, const TargetTransformInfo *TTI) { +static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, + ScalarEvolution *SE, AssumptionCache *AC, + const TargetTransformInfo *TTI) { if (!WidenIV) { LLVM_DEBUG(dbgs() << "Widening the IVs is disabled\n"); return false; @@ -589,7 +596,7 @@ static bool CanWidenIV(struct FlattenInfo &FI, DominatorTree *DT, // If both induction types are less than the maximum legal integer width, // promote both to the widest type available so we know calculating - // (OuterLimit * InnerLimit) as the new trip count is safe. + // (OuterTripCount * InnerTripCount) as the new trip count is safe. if (InnerType != OuterType || InnerType->getScalarSizeInBits() >= MaxLegalSize || MaxLegalType->getScalarSizeInBits() < InnerType->getScalarSizeInBits() * 2) { @@ -602,28 +609,27 @@ static bool CanWidenIV(struct FlattenInfo &FI, DominatorTree *DT, SmallVector<WeakTrackingVH, 4> DeadInsts; WideIVs.push_back( {FI.InnerInductionPHI, MaxLegalType, false }); WideIVs.push_back( {FI.OuterInductionPHI, MaxLegalType, false }); - unsigned ElimExt; - unsigned Widened; + unsigned ElimExt = 0; + unsigned Widened = 0; - for (unsigned i = 0; i < WideIVs.size(); i++) { - PHINode *WidePhi = createWideIV(WideIVs[i], LI, SE, Rewriter, DT, DeadInsts, + for (const auto &WideIV : WideIVs) { + PHINode *WidePhi = createWideIV(WideIV, LI, SE, Rewriter, DT, DeadInsts, ElimExt, Widened, true /* HasGuards */, true /* UsePostIncrementRanges */); if (!WidePhi) return false; LLVM_DEBUG(dbgs() << "Created wide phi: "; WidePhi->dump()); - LLVM_DEBUG(dbgs() << "Deleting old phi: "; WideIVs[i].NarrowIV->dump()); - RecursivelyDeleteDeadPHINode(WideIVs[i].NarrowIV); + LLVM_DEBUG(dbgs() << "Deleting old phi: "; WideIV.NarrowIV->dump()); + RecursivelyDeleteDeadPHINode(WideIV.NarrowIV); } // After widening, rediscover all the loop components. - assert(Widened && "Widenend IV expected"); + assert(Widened && "Widened IV expected"); FI.Widened = true; return CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI); } -static bool FlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT, - LoopInfo *LI, ScalarEvolution *SE, - AssumptionCache *AC, +static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, + ScalarEvolution *SE, AssumptionCache *AC, const TargetTransformInfo *TTI) { LLVM_DEBUG( dbgs() << "Loop flattening running on outer loop " @@ -656,33 +662,35 @@ static bool FlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT, return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI); } -bool Flatten(DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, +bool Flatten(LoopNest &LN, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, AssumptionCache *AC, TargetTransformInfo *TTI) { bool Changed = false; - for (auto *InnerLoop : LI->getLoopsInPreorder()) { + for (Loop *InnerLoop : LN.getLoops()) { auto *OuterLoop = InnerLoop->getParentLoop(); if (!OuterLoop) continue; - struct FlattenInfo FI(OuterLoop, InnerLoop); + FlattenInfo FI(OuterLoop, InnerLoop); Changed |= FlattenLoopPair(FI, DT, LI, SE, AC, TTI); } return Changed; } -PreservedAnalyses LoopFlattenPass::run(Function &F, - FunctionAnalysisManager &AM) { - auto *DT = &AM.getResult<DominatorTreeAnalysis>(F); - auto *LI = &AM.getResult<LoopAnalysis>(F); - auto *SE = &AM.getResult<ScalarEvolutionAnalysis>(F); - auto *AC = &AM.getResult<AssumptionAnalysis>(F); - auto *TTI = &AM.getResult<TargetIRAnalysis>(F); +PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM, + LoopStandardAnalysisResults &AR, + LPMUpdater &U) { + + bool Changed = false; + + // The loop flattening pass requires loops to be + // in simplified form, and also needs LCSSA. Running + // this pass will simplify all loops that contain inner loops, + // regardless of whether anything ends up being flattened. + Changed |= Flatten(LN, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI); - if (!Flatten(DT, LI, SE, AC, TTI)) + if (!Changed) return PreservedAnalyses::all(); - PreservedAnalyses PA; - PA.preserveSet<CFGAnalyses>(); - return PA; + return PreservedAnalyses::none(); } namespace { @@ -724,5 +732,10 @@ bool LoopFlattenLegacyPass::runOnFunction(Function &F) { auto &TTIP = getAnalysis<TargetTransformInfoWrapperPass>(); auto *TTI = &TTIP.getTTI(F); auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - return Flatten(DT, LI, SE, AC, TTI); + bool Changed = false; + for (Loop *L : *LI) { + auto LN = LoopNest::getLoopNest(*L, *SE); + Changed |= Flatten(*LN, DT, LI, SE, AC, TTI); + } + return Changed; } |