aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LoopFlatten.cpp')
-rw-r--r--llvm/lib/Transforms/Scalar/LoopFlatten.cpp245
1 files changed, 129 insertions, 116 deletions
diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
index aaff68436c13..f54289f85ef5 100644
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -63,7 +63,7 @@ static cl::opt<bool>
AssumeNoOverflow("loop-flatten-assume-no-overflow", cl::Hidden,
cl::init(false),
cl::desc("Assume that the product of the two iteration "
- "limits will never overflow"));
+ "trip counts will never overflow"));
static cl::opt<bool>
WidenIV("loop-flatten-widen-iv", cl::Hidden,
@@ -74,10 +74,12 @@ static cl::opt<bool>
struct FlattenInfo {
Loop *OuterLoop = nullptr;
Loop *InnerLoop = nullptr;
+ // These PHINodes correspond to loop induction variables, which are expected
+ // to start at zero and increment by one on each loop.
PHINode *InnerInductionPHI = nullptr;
PHINode *OuterInductionPHI = nullptr;
- Value *InnerLimit = nullptr;
- Value *OuterLimit = nullptr;
+ Value *InnerTripCount = nullptr;
+ Value *OuterTripCount = nullptr;
BinaryOperator *InnerIncrement = nullptr;
BinaryOperator *OuterIncrement = nullptr;
BranchInst *InnerBranch = nullptr;
@@ -91,12 +93,12 @@ struct FlattenInfo {
FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {};
};
-// Finds the induction variable, increment and limit for a simple loop that we
-// can flatten.
+// Finds the induction variable, increment and trip count for a simple loop that
+// we can flatten.
static bool findLoopComponents(
Loop *L, SmallPtrSetImpl<Instruction *> &IterationInstructions,
- PHINode *&InductionPHI, Value *&Limit, BinaryOperator *&Increment,
- BranchInst *&BackBranch, ScalarEvolution *SE) {
+ PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment,
+ BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) {
LLVM_DEBUG(dbgs() << "Finding components of loop: " << L->getName() << "\n");
if (!L->isLoopSimplifyForm()) {
@@ -104,6 +106,13 @@ static bool findLoopComponents(
return false;
}
+ // Currently, to simplify the implementation, the Loop induction variable must
+ // start at zero and increment with a step size of one.
+ if (!L->isCanonical(*SE)) {
+ LLVM_DEBUG(dbgs() << "Loop is not canonical\n");
+ return false;
+ }
+
// There must be exactly one exiting block, and it must be the same at the
// latch.
BasicBlock *Latch = L->getLoopLatch();
@@ -111,33 +120,18 @@ static bool findLoopComponents(
LLVM_DEBUG(dbgs() << "Exiting and latch block are different\n");
return false;
}
- // Latch block must end in a conditional branch.
- BackBranch = dyn_cast<BranchInst>(Latch->getTerminator());
- if (!BackBranch || !BackBranch->isConditional()) {
- LLVM_DEBUG(dbgs() << "Could not find back-branch\n");
- return false;
- }
- IterationInstructions.insert(BackBranch);
- LLVM_DEBUG(dbgs() << "Found back branch: "; BackBranch->dump());
- bool ContinueOnTrue = L->contains(BackBranch->getSuccessor(0));
// Find the induction PHI. If there is no induction PHI, we can't do the
// transformation. TODO: could other variables trigger this? Do we have to
// search for the best one?
- InductionPHI = nullptr;
- for (PHINode &PHI : L->getHeader()->phis()) {
- InductionDescriptor ID;
- if (InductionDescriptor::isInductionPHI(&PHI, L, SE, ID)) {
- InductionPHI = &PHI;
- LLVM_DEBUG(dbgs() << "Found induction PHI: "; InductionPHI->dump());
- break;
- }
- }
+ InductionPHI = L->getInductionVariable(*SE);
if (!InductionPHI) {
LLVM_DEBUG(dbgs() << "Could not find induction PHI\n");
return false;
}
+ LLVM_DEBUG(dbgs() << "Found induction PHI: "; InductionPHI->dump());
+ bool ContinueOnTrue = L->contains(Latch->getTerminator()->getSuccessor(0));
auto IsValidPredicate = [&](ICmpInst::Predicate Pred) {
if (ContinueOnTrue)
return Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_ULT;
@@ -145,53 +139,64 @@ static bool findLoopComponents(
return Pred == CmpInst::ICMP_EQ;
};
- // Find Compare and make sure it is valid
- ICmpInst *Compare = dyn_cast<ICmpInst>(BackBranch->getCondition());
+ // Find Compare and make sure it is valid. getLatchCmpInst checks that the
+ // back branch of the latch is conditional.
+ ICmpInst *Compare = L->getLatchCmpInst();
if (!Compare || !IsValidPredicate(Compare->getUnsignedPredicate()) ||
Compare->hasNUsesOrMore(2)) {
LLVM_DEBUG(dbgs() << "Could not find valid comparison\n");
return false;
}
+ BackBranch = cast<BranchInst>(Latch->getTerminator());
+ IterationInstructions.insert(BackBranch);
+ LLVM_DEBUG(dbgs() << "Found back branch: "; BackBranch->dump());
IterationInstructions.insert(Compare);
LLVM_DEBUG(dbgs() << "Found comparison: "; Compare->dump());
- // Find increment and limit from the compare
- Increment = nullptr;
- if (match(Compare->getOperand(0),
- m_c_Add(m_Specific(InductionPHI), m_ConstantInt<1>()))) {
- Increment = dyn_cast<BinaryOperator>(Compare->getOperand(0));
- Limit = Compare->getOperand(1);
- } else if (Compare->getUnsignedPredicate() == CmpInst::ICMP_NE &&
- match(Compare->getOperand(1),
- m_c_Add(m_Specific(InductionPHI), m_ConstantInt<1>()))) {
- Increment = dyn_cast<BinaryOperator>(Compare->getOperand(1));
- Limit = Compare->getOperand(0);
- }
- if (!Increment || Increment->hasNUsesOrMore(3)) {
- LLVM_DEBUG(dbgs() << "Cound not find valid increment\n");
+ // Find increment and trip count.
+ // There are exactly 2 incoming values to the induction phi; one from the
+ // pre-header and one from the latch. The incoming latch value is the
+ // increment variable.
+ Increment =
+ dyn_cast<BinaryOperator>(InductionPHI->getIncomingValueForBlock(Latch));
+ if (Increment->hasNUsesOrMore(3)) {
+ LLVM_DEBUG(dbgs() << "Could not find valid increment\n");
return false;
}
+ // The trip count is the RHS of the compare. If this doesn't match the trip
+ // count computed by SCEV then this is either because the trip count variable
+ // has been widened (then leave the trip count as it is), or because it is a
+ // constant and another transformation has changed the compare, e.g.
+ // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, then we don't flatten
+ // the loop (yet).
+ TripCount = Compare->getOperand(1);
+ const SCEV *SCEVTripCount =
+ SE->getTripCountFromExitCount(SE->getBackedgeTakenCount(L));
+ if (SE->getSCEV(TripCount) != SCEVTripCount) {
+ if (!IsWidened) {
+ LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
+ return false;
+ }
+ auto TripCountInst = dyn_cast<Instruction>(TripCount);
+ if (!TripCountInst) {
+ LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
+ return false;
+ }
+ if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
+ SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
+ LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
+ return false;
+ }
+ }
IterationInstructions.insert(Increment);
LLVM_DEBUG(dbgs() << "Found increment: "; Increment->dump());
- LLVM_DEBUG(dbgs() << "Found limit: "; Limit->dump());
-
- assert(InductionPHI->getNumIncomingValues() == 2);
- assert(InductionPHI->getIncomingValueForBlock(Latch) == Increment &&
- "PHI value is not increment inst");
-
- auto *CI = dyn_cast<ConstantInt>(
- InductionPHI->getIncomingValueForBlock(L->getLoopPreheader()));
- if (!CI || !CI->isZero()) {
- LLVM_DEBUG(dbgs() << "PHI value is not zero: "; CI->dump());
- return false;
- }
+ LLVM_DEBUG(dbgs() << "Found trip count: "; TripCount->dump());
LLVM_DEBUG(dbgs() << "Successfully found all loop components\n");
return true;
}
-static bool checkPHIs(struct FlattenInfo &FI,
- const TargetTransformInfo *TTI) {
+static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) {
// All PHIs in the inner and outer headers must either be:
// - The induction PHI, which we are going to rewrite as one induction in
// the new loop. This is already checked by findLoopComponents.
@@ -272,7 +277,7 @@ static bool checkPHIs(struct FlattenInfo &FI,
}
static bool
-checkOuterLoopInsts(struct FlattenInfo &FI,
+checkOuterLoopInsts(FlattenInfo &FI,
SmallPtrSetImpl<Instruction *> &IterationInstructions,
const TargetTransformInfo *TTI) {
// Check for instructions in the outer but not inner loop. If any of these
@@ -280,7 +285,7 @@ checkOuterLoopInsts(struct FlattenInfo &FI,
// a significant amount of code here which can't be optimised out that it's
// not profitable (as these instructions would get executed for each
// iteration of the inner loop).
- unsigned RepeatedInstrCost = 0;
+ InstructionCost RepeatedInstrCost = 0;
for (auto *B : FI.OuterLoop->getBlocks()) {
if (FI.InnerLoop->contains(B))
continue;
@@ -308,9 +313,10 @@ checkOuterLoopInsts(struct FlattenInfo &FI,
// Multiplies of the outer iteration variable and inner iteration
// count will be optimised out.
if (match(&I, m_c_Mul(m_Specific(FI.OuterInductionPHI),
- m_Specific(FI.InnerLimit))))
+ m_Specific(FI.InnerTripCount))))
continue;
- int Cost = TTI->getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
+ InstructionCost Cost =
+ TTI->getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
LLVM_DEBUG(dbgs() << "Cost " << Cost << ": "; I.dump());
RepeatedInstrCost += Cost;
}
@@ -329,19 +335,19 @@ checkOuterLoopInsts(struct FlattenInfo &FI,
return true;
}
-static bool checkIVUsers(struct FlattenInfo &FI) {
+static bool checkIVUsers(FlattenInfo &FI) {
// We require all uses of both induction variables to match this pattern:
//
- // (OuterPHI * InnerLimit) + InnerPHI
+ // (OuterPHI * InnerTripCount) + InnerPHI
//
// Any uses of the induction variables not matching that pattern would
// require a div/mod to reconstruct in the flattened loop, so the
// transformation wouldn't be profitable.
- Value *InnerLimit = FI.InnerLimit;
+ Value *InnerTripCount = FI.InnerTripCount;
if (FI.Widened &&
- (isa<SExtInst>(InnerLimit) || isa<ZExtInst>(InnerLimit)))
- InnerLimit = cast<Instruction>(InnerLimit)->getOperand(0);
+ (isa<SExtInst>(InnerTripCount) || isa<ZExtInst>(InnerTripCount)))
+ InnerTripCount = cast<Instruction>(InnerTripCount)->getOperand(0);
// Check that all uses of the inner loop's induction variable match the
// expected pattern, recording the uses of the outer IV.
@@ -375,7 +381,7 @@ static bool checkIVUsers(struct FlattenInfo &FI) {
m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)),
m_Value(MatchedItCount)));
- if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerLimit) {
+ if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) {
LLVM_DEBUG(dbgs() << "Use is optimisable\n");
ValidOuterPHIUses.insert(MatchedMul);
FI.LinearIVUses.insert(U);
@@ -424,9 +430,9 @@ static bool checkIVUsers(struct FlattenInfo &FI) {
}
// Return an OverflowResult dependant on if overflow of the multiplication of
-// InnerLimit and OuterLimit can be assumed not to happen.
-static OverflowResult checkOverflow(struct FlattenInfo &FI,
- DominatorTree *DT, AssumptionCache *AC) {
+// InnerTripCount and OuterTripCount can be assumed not to happen.
+static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT,
+ AssumptionCache *AC) {
Function *F = FI.OuterLoop->getHeader()->getParent();
const DataLayout &DL = F->getParent()->getDataLayout();
@@ -437,7 +443,7 @@ static OverflowResult checkOverflow(struct FlattenInfo &FI,
// Check if the multiply could not overflow due to known ranges of the
// input values.
OverflowResult OR = computeOverflowForUnsignedMul(
- FI.InnerLimit, FI.OuterLimit, DL, AC,
+ FI.InnerTripCount, FI.OuterTripCount, DL, AC,
FI.OuterLoop->getLoopPreheader()->getTerminator(), DT);
if (OR != OverflowResult::MayOverflow)
return OR;
@@ -464,25 +470,27 @@ static OverflowResult checkOverflow(struct FlattenInfo &FI,
return OverflowResult::MayOverflow;
}
-static bool CanFlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT,
- LoopInfo *LI, ScalarEvolution *SE,
- AssumptionCache *AC, const TargetTransformInfo *TTI) {
+static bool CanFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
+ ScalarEvolution *SE, AssumptionCache *AC,
+ const TargetTransformInfo *TTI) {
SmallPtrSet<Instruction *, 8> IterationInstructions;
- if (!findLoopComponents(FI.InnerLoop, IterationInstructions, FI.InnerInductionPHI,
- FI.InnerLimit, FI.InnerIncrement, FI.InnerBranch, SE))
+ if (!findLoopComponents(FI.InnerLoop, IterationInstructions,
+ FI.InnerInductionPHI, FI.InnerTripCount,
+ FI.InnerIncrement, FI.InnerBranch, SE, FI.Widened))
return false;
- if (!findLoopComponents(FI.OuterLoop, IterationInstructions, FI.OuterInductionPHI,
- FI.OuterLimit, FI.OuterIncrement, FI.OuterBranch, SE))
+ if (!findLoopComponents(FI.OuterLoop, IterationInstructions,
+ FI.OuterInductionPHI, FI.OuterTripCount,
+ FI.OuterIncrement, FI.OuterBranch, SE, FI.Widened))
return false;
- // Both of the loop limit values must be invariant in the outer loop
+ // Both of the loop trip count values must be invariant in the outer loop
// (non-instructions are all inherently invariant).
- if (!FI.OuterLoop->isLoopInvariant(FI.InnerLimit)) {
- LLVM_DEBUG(dbgs() << "inner loop limit not invariant\n");
+ if (!FI.OuterLoop->isLoopInvariant(FI.InnerTripCount)) {
+ LLVM_DEBUG(dbgs() << "inner loop trip count not invariant\n");
return false;
}
- if (!FI.OuterLoop->isLoopInvariant(FI.OuterLimit)) {
- LLVM_DEBUG(dbgs() << "outer loop limit not invariant\n");
+ if (!FI.OuterLoop->isLoopInvariant(FI.OuterTripCount)) {
+ LLVM_DEBUG(dbgs() << "outer loop trip count not invariant\n");
return false;
}
@@ -508,9 +516,8 @@ static bool CanFlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT,
return true;
}
-static bool DoFlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT,
- LoopInfo *LI, ScalarEvolution *SE,
- AssumptionCache *AC,
+static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
+ ScalarEvolution *SE, AssumptionCache *AC,
const TargetTransformInfo *TTI) {
Function *F = FI.OuterLoop->getHeader()->getParent();
LLVM_DEBUG(dbgs() << "Checks all passed, doing the transformation\n");
@@ -523,9 +530,9 @@ static bool DoFlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT,
ORE.emit(Remark);
}
- Value *NewTripCount =
- BinaryOperator::CreateMul(FI.InnerLimit, FI.OuterLimit, "flatten.tripcount",
- FI.OuterLoop->getLoopPreheader()->getTerminator());
+ Value *NewTripCount = BinaryOperator::CreateMul(
+ FI.InnerTripCount, FI.OuterTripCount, "flatten.tripcount",
+ FI.OuterLoop->getLoopPreheader()->getTerminator());
LLVM_DEBUG(dbgs() << "Created new trip count in preheader: ";
NewTripCount->dump());
@@ -571,9 +578,9 @@ static bool DoFlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT,
return true;
}
-static bool CanWidenIV(struct FlattenInfo &FI, DominatorTree *DT,
- LoopInfo *LI, ScalarEvolution *SE,
- AssumptionCache *AC, const TargetTransformInfo *TTI) {
+static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
+ ScalarEvolution *SE, AssumptionCache *AC,
+ const TargetTransformInfo *TTI) {
if (!WidenIV) {
LLVM_DEBUG(dbgs() << "Widening the IVs is disabled\n");
return false;
@@ -589,7 +596,7 @@ static bool CanWidenIV(struct FlattenInfo &FI, DominatorTree *DT,
// If both induction types are less than the maximum legal integer width,
// promote both to the widest type available so we know calculating
- // (OuterLimit * InnerLimit) as the new trip count is safe.
+ // (OuterTripCount * InnerTripCount) as the new trip count is safe.
if (InnerType != OuterType ||
InnerType->getScalarSizeInBits() >= MaxLegalSize ||
MaxLegalType->getScalarSizeInBits() < InnerType->getScalarSizeInBits() * 2) {
@@ -602,28 +609,27 @@ static bool CanWidenIV(struct FlattenInfo &FI, DominatorTree *DT,
SmallVector<WeakTrackingVH, 4> DeadInsts;
WideIVs.push_back( {FI.InnerInductionPHI, MaxLegalType, false });
WideIVs.push_back( {FI.OuterInductionPHI, MaxLegalType, false });
- unsigned ElimExt;
- unsigned Widened;
+ unsigned ElimExt = 0;
+ unsigned Widened = 0;
- for (unsigned i = 0; i < WideIVs.size(); i++) {
- PHINode *WidePhi = createWideIV(WideIVs[i], LI, SE, Rewriter, DT, DeadInsts,
+ for (const auto &WideIV : WideIVs) {
+ PHINode *WidePhi = createWideIV(WideIV, LI, SE, Rewriter, DT, DeadInsts,
ElimExt, Widened, true /* HasGuards */,
true /* UsePostIncrementRanges */);
if (!WidePhi)
return false;
LLVM_DEBUG(dbgs() << "Created wide phi: "; WidePhi->dump());
- LLVM_DEBUG(dbgs() << "Deleting old phi: "; WideIVs[i].NarrowIV->dump());
- RecursivelyDeleteDeadPHINode(WideIVs[i].NarrowIV);
+ LLVM_DEBUG(dbgs() << "Deleting old phi: "; WideIV.NarrowIV->dump());
+ RecursivelyDeleteDeadPHINode(WideIV.NarrowIV);
}
// After widening, rediscover all the loop components.
- assert(Widened && "Widenend IV expected");
+ assert(Widened && "Widened IV expected");
FI.Widened = true;
return CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI);
}
-static bool FlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT,
- LoopInfo *LI, ScalarEvolution *SE,
- AssumptionCache *AC,
+static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
+ ScalarEvolution *SE, AssumptionCache *AC,
const TargetTransformInfo *TTI) {
LLVM_DEBUG(
dbgs() << "Loop flattening running on outer loop "
@@ -656,33 +662,35 @@ static bool FlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT,
return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI);
}
-bool Flatten(DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE,
+bool Flatten(LoopNest &LN, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE,
AssumptionCache *AC, TargetTransformInfo *TTI) {
bool Changed = false;
- for (auto *InnerLoop : LI->getLoopsInPreorder()) {
+ for (Loop *InnerLoop : LN.getLoops()) {
auto *OuterLoop = InnerLoop->getParentLoop();
if (!OuterLoop)
continue;
- struct FlattenInfo FI(OuterLoop, InnerLoop);
+ FlattenInfo FI(OuterLoop, InnerLoop);
Changed |= FlattenLoopPair(FI, DT, LI, SE, AC, TTI);
}
return Changed;
}
-PreservedAnalyses LoopFlattenPass::run(Function &F,
- FunctionAnalysisManager &AM) {
- auto *DT = &AM.getResult<DominatorTreeAnalysis>(F);
- auto *LI = &AM.getResult<LoopAnalysis>(F);
- auto *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
- auto *AC = &AM.getResult<AssumptionAnalysis>(F);
- auto *TTI = &AM.getResult<TargetIRAnalysis>(F);
+PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM,
+ LoopStandardAnalysisResults &AR,
+ LPMUpdater &U) {
+
+ bool Changed = false;
+
+ // The loop flattening pass requires loops to be
+ // in simplified form, and also needs LCSSA. Running
+ // this pass will simplify all loops that contain inner loops,
+ // regardless of whether anything ends up being flattened.
+ Changed |= Flatten(LN, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI);
- if (!Flatten(DT, LI, SE, AC, TTI))
+ if (!Changed)
return PreservedAnalyses::all();
- PreservedAnalyses PA;
- PA.preserveSet<CFGAnalyses>();
- return PA;
+ return PreservedAnalyses::none();
}
namespace {
@@ -724,5 +732,10 @@ bool LoopFlattenLegacyPass::runOnFunction(Function &F) {
auto &TTIP = getAnalysis<TargetTransformInfoWrapperPass>();
auto *TTI = &TTIP.getTTI(F);
auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- return Flatten(DT, LI, SE, AC, TTI);
+ bool Changed = false;
+ for (Loop *L : *LI) {
+ auto LN = LoopNest::getLoopNest(*L, *SE);
+ Changed |= Flatten(*LN, DT, LI, SE, AC, TTI);
+ }
+ return Changed;
}