Diffstat (limited to 'lib/Transforms')
120 files changed, 14145 insertions, 7793 deletions
diff --git a/lib/Transforms/IPO/ArgumentPromotion.cpp b/lib/Transforms/IPO/ArgumentPromotion.cpp index f9de54a173d1..328202293867 100644 --- a/lib/Transforms/IPO/ArgumentPromotion.cpp +++ b/lib/Transforms/IPO/ArgumentPromotion.cpp @@ -78,11 +78,15 @@ namespace { const DataLayout *DL; private: + bool isDenselyPacked(Type *type); + bool canPaddingBeAccessed(Argument *Arg); CallGraphNode *PromoteArguments(CallGraphNode *CGN); bool isSafeToPromoteArgument(Argument *Arg, bool isByVal) const; CallGraphNode *DoPromotion(Function *F, - SmallPtrSet<Argument*, 8> &ArgsToPromote, - SmallPtrSet<Argument*, 8> &ByValArgsToTransform); + SmallPtrSetImpl<Argument*> &ArgsToPromote, + SmallPtrSetImpl<Argument*> &ByValArgsToTransform); + + using llvm::Pass::doInitialization; bool doInitialization(CallGraph &CG) override; /// The maximum number of elements to expand, or 0 for unlimited. unsigned maxElements; @@ -123,6 +127,78 @@ bool ArgPromotion::runOnSCC(CallGraphSCC &SCC) { return Changed; } +/// \brief Checks if a type could have padding bytes. +bool ArgPromotion::isDenselyPacked(Type *type) { + + // There is no size information, so be conservative. + if (!type->isSized()) + return false; + + // If the alloc size is not equal to the storage size, then there are padding + // bytes. For x86_fp80 on x86-64, size: 80 alloc size: 128. + if (!DL || DL->getTypeSizeInBits(type) != DL->getTypeAllocSizeInBits(type)) + return false; + + if (!isa<CompositeType>(type)) + return true; + + // For homogenous sequential types, check for padding within members. + if (SequentialType *seqTy = dyn_cast<SequentialType>(type)) + return isa<PointerType>(seqTy) || isDenselyPacked(seqTy->getElementType()); + + // Check for padding within and between elements of a struct. + StructType *StructTy = cast<StructType>(type); + const StructLayout *Layout = DL->getStructLayout(StructTy); + uint64_t StartPos = 0; + for (unsigned i = 0, E = StructTy->getNumElements(); i < E; ++i) { + Type *ElTy = StructTy->getElementType(i); + if (!isDenselyPacked(ElTy)) + return false; + if (StartPos != Layout->getElementOffsetInBits(i)) + return false; + StartPos += DL->getTypeAllocSizeInBits(ElTy); + } + + return true; +} + +/// \brief Checks if the padding bytes of an argument could be accessed. +bool ArgPromotion::canPaddingBeAccessed(Argument *arg) { + + assert(arg->hasByValAttr()); + + // Track all the pointers to the argument to make sure they are not captured. + SmallPtrSet<Value *, 16> PtrValues; + PtrValues.insert(arg); + + // Track all of the stores. + SmallVector<StoreInst *, 16> Stores; + + // Scan through the uses recursively to make sure the pointer is always used + // sanely. + SmallVector<Value *, 16> WorkList; + WorkList.insert(WorkList.end(), arg->user_begin(), arg->user_end()); + while (!WorkList.empty()) { + Value *V = WorkList.back(); + WorkList.pop_back(); + if (isa<GetElementPtrInst>(V) || isa<PHINode>(V)) { + if (PtrValues.insert(V).second) + WorkList.insert(WorkList.end(), V->user_begin(), V->user_end()); + } else if (StoreInst *Store = dyn_cast<StoreInst>(V)) { + Stores.push_back(Store); + } else if (!isa<LoadInst>(V)) { + return true; + } + } + +// Check to make sure the pointers aren't captured + for (StoreInst *Store : Stores) + if (PtrValues.count(Store->getValueOperand())) + return true; + + return false; +} + /// PromoteArguments - This method checks the specified function to see if there /// are any promotable arguments and if it is safe to promote the function (for /// example, all callers are direct). 
If safe to promote some arguments, it @@ -154,6 +230,13 @@ CallGraphNode *ArgPromotion::PromoteArguments(CallGraphNode *CGN) { isSelfRecursive = true; } + // Don't promote arguments for variadic functions. Adding, removing, or + // changing non-pack parameters can change the classification of pack + // parameters. Frontends encode that classification at the call site in the + // IR, while in the callee the classification is determined dynamically based + // on the number of registers consumed so far. + if (F->isVarArg()) return nullptr; + // Check to see which arguments are promotable. If an argument is promotable, // add it to ArgsToPromote. SmallPtrSet<Argument*, 8> ArgsToPromote; @@ -163,9 +246,13 @@ CallGraphNode *ArgPromotion::PromoteArguments(CallGraphNode *CGN) { Type *AgTy = cast<PointerType>(PtrArg->getType())->getElementType(); // If this is a byval argument, and if the aggregate type is small, just - // pass the elements, which is always safe. This does not apply to - // inalloca. - if (PtrArg->hasByValAttr()) { + // pass the elements, which is always safe, if the passed value is densely + // packed or if we can prove the padding bytes are never accessed. This does + // not apply to inalloca. + bool isSafeToPromote = + PtrArg->hasByValAttr() && + (isDenselyPacked(AgTy) || !canPaddingBeAccessed(PtrArg)); + if (isSafeToPromote) { if (StructType *STy = dyn_cast<StructType>(AgTy)) { if (maxElements > 0 && STy->getNumElements() > maxElements) { DEBUG(dbgs() << "argpromotion disable promoting argument '" @@ -443,7 +530,7 @@ bool ArgPromotion::isSafeToPromoteArgument(Argument *Arg, // of elements of the aggregate. return false; } - ToPromote.insert(Operands); + ToPromote.insert(std::move(Operands)); } } @@ -467,7 +554,8 @@ bool ArgPromotion::isSafeToPromoteArgument(Argument *Arg, BasicBlock *BB = Load->getParent(); AliasAnalysis::Location Loc = AA.getLocation(Load); - if (AA.canInstructionRangeModify(BB->front(), *Load, Loc)) + if (AA.canInstructionRangeModRef(BB->front(), *Load, Loc, + AliasAnalysis::Mod)) return false; // Pointer is invalidated! // Now check every path from the entry block to the load for transparency. @@ -475,10 +563,8 @@ bool ArgPromotion::isSafeToPromoteArgument(Argument *Arg, // loading block. for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) { BasicBlock *P = *PI; - for (idf_ext_iterator<BasicBlock*, SmallPtrSet<BasicBlock*, 16> > - I = idf_ext_begin(P, TranspBlocks), - E = idf_ext_end(P, TranspBlocks); I != E; ++I) - if (AA.canBasicBlockModify(**I, Loc)) + for (BasicBlock *TranspBB : inverse_depth_first_ext(P, TranspBlocks)) + if (AA.canBasicBlockModify(*TranspBB, Loc)) return false; } } @@ -493,8 +579,8 @@ bool ArgPromotion::isSafeToPromoteArgument(Argument *Arg, /// arguments, and returns the new function. At this point, we know that it's /// safe to do so. CallGraphNode *ArgPromotion::DoPromotion(Function *F, - SmallPtrSet<Argument*, 8> &ArgsToPromote, - SmallPtrSet<Argument*, 8> &ByValArgsToTransform) { + SmallPtrSetImpl<Argument*> &ArgsToPromote, + SmallPtrSetImpl<Argument*> &ByValArgsToTransform) { // Start by computing a new prototype for the function, which is the same as // the old function, but has modified arguments. @@ -615,9 +701,15 @@ CallGraphNode *ArgPromotion::DoPromotion(Function *F, // Patch the pointer to LLVM function in debug info descriptor. 
auto DI = FunctionDIs.find(F); - if (DI != FunctionDIs.end()) - DI->second.replaceFunction(NF); - + if (DI != FunctionDIs.end()) { + DISubprogram SP = DI->second; + SP.replaceFunction(NF); + // Ensure the map is updated so it can be reused on subsequent argument + // promotions of the same function. + FunctionDIs.erase(DI); + FunctionDIs[NF] = SP; + } + DEBUG(dbgs() << "ARG PROMOTION: Promoting to:" << *NF << "\n" << "From: " << *F); @@ -716,9 +808,11 @@ CallGraphNode *ArgPromotion::DoPromotion(Function *F, // of the previous load. LoadInst *newLoad = new LoadInst(V, V->getName()+".val", Call); newLoad->setAlignment(OrigLoad->getAlignment()); - // Transfer the TBAA info too. - newLoad->setMetadata(LLVMContext::MD_tbaa, - OrigLoad->getMetadata(LLVMContext::MD_tbaa)); + // Transfer the AA info too. + AAMDNodes AAInfo; + OrigLoad->getAAMetadata(AAInfo); + newLoad->setAAMetadata(AAInfo); + Args.push_back(newLoad); AA.copyValue(OrigLoad, Args.back()); } diff --git a/lib/Transforms/IPO/ConstantMerge.cpp b/lib/Transforms/IPO/ConstantMerge.cpp index 23be0819e629..0b6ade9eb536 100644 --- a/lib/Transforms/IPO/ConstantMerge.cpp +++ b/lib/Transforms/IPO/ConstantMerge.cpp @@ -66,7 +66,7 @@ ModulePass *llvm::createConstantMergePass() { return new ConstantMerge(); } /// Find values that are marked as llvm.used. static void FindUsedValues(GlobalVariable *LLVMUsed, - SmallPtrSet<const GlobalValue*, 8> &UsedValues) { + SmallPtrSetImpl<const GlobalValue*> &UsedValues) { if (!LLVMUsed) return; ConstantArray *Inits = cast<ConstantArray>(LLVMUsed->getInitializer()); diff --git a/lib/Transforms/IPO/DeadArgumentElimination.cpp b/lib/Transforms/IPO/DeadArgumentElimination.cpp index ac3853dbd679..4045c09aaa2b 100644 --- a/lib/Transforms/IPO/DeadArgumentElimination.cpp +++ b/lib/Transforms/IPO/DeadArgumentElimination.cpp @@ -199,10 +199,15 @@ bool DAE::DeleteDeadVarargs(Function &Fn) { return false; // Okay, we know we can transform this function if safe. Scan its body - // looking for calls to llvm.vastart. + // looking for calls marked musttail or calls to llvm.vastart. for (Function::iterator BB = Fn.begin(), E = Fn.end(); BB != E; ++BB) { for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { + CallInst *CI = dyn_cast<CallInst>(I); + if (!CI) + continue; + if (CI->isMustTailCall()) + return false; + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) { if (II->getIntrinsicID() == Intrinsic::vastart) return false; } @@ -297,8 +302,14 @@ bool DAE::DeleteDeadVarargs(Function &Fn) { // Patch the pointer to LLVM function in debug info descriptor. auto DI = FunctionDIs.find(&Fn); - if (DI != FunctionDIs.end()) - DI->second.replaceFunction(NF); + if (DI != FunctionDIs.end()) { + DISubprogram SP = DI->second; + SP.replaceFunction(NF); + // Ensure the map is updated so it can be reused on non-varargs argument + // eliminations of the same function. + FunctionDIs.erase(DI); + FunctionDIs[NF] = SP; + } // Fix up any BlockAddresses that refer to the function. Fn.replaceAllUsesWith(ConstantExpr::getBitCast(NF, Fn.getType())); @@ -1088,8 +1099,8 @@ bool DAE::runOnModule(Module &M) { // determine that dead arguments passed into recursive functions are dead). // DEBUG(dbgs() << "DAE - Determining liveness\n"); - for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) - SurveyFunction(*I); + for (auto &F : M) + SurveyFunction(F); // Now, remove all dead arguments and return values from each function in // turn. 
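[Editorial aside] The isDenselyPacked() check that the ArgumentPromotion change above introduces treats a type as densely packed when its members tile the allocation with no gaps. A minimal standalone sketch in plain C++ (hypothetical Field type, bit-level offsets; not the LLVM API), walking member offsets against a running position the same way the pass walks StructLayout:

#include <cassert>
#include <cstdint>
#include <vector>

struct Field { uint64_t Offset, Size; };        // offsets/sizes in bits

bool isDenselyPacked(const std::vector<Field> &Fields, uint64_t AllocSize) {
  uint64_t Pos = 0;
  for (const Field &F : Fields) {
    if (F.Offset != Pos)                        // gap before a member: padding
      return false;
    Pos += F.Size;
  }
  return Pos == AllocSize;                      // reject tail padding too
}

int main() {
  // struct { int32_t a; int64_t b; } on x86-64: 4 padding bytes after 'a'.
  assert(!isDenselyPacked({{0, 32}, {64, 64}}, 128));
  // struct { int32_t a; int32_t b; }: densely packed.
  assert(isDenselyPacked({{0, 32}, {32, 32}}, 64));
}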
@@ -1102,11 +1113,8 @@ bool DAE::runOnModule(Module &M) { // Finally, look for any unused parameters in functions with non-local // linkage and replace the passed in parameters with undef. - for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) { - Function& F = *I; - + for (auto &F : M) Changed |= RemoveDeadArgumentsFromCallers(F); - } return Changed; } diff --git a/lib/Transforms/IPO/ExtractGV.cpp b/lib/Transforms/IPO/ExtractGV.cpp index 40ec9fa8c1de..2f8c7d9349b9 100644 --- a/lib/Transforms/IPO/ExtractGV.cpp +++ b/lib/Transforms/IPO/ExtractGV.cpp @@ -91,7 +91,7 @@ namespace { continue; } - makeVisible(*I, Delete); + makeVisible(*I, Delete); if (Delete) I->setInitializer(nullptr); @@ -106,7 +106,7 @@ namespace { continue; } - makeVisible(*I, Delete); + makeVisible(*I, Delete); if (Delete) I->deleteBody(); @@ -118,8 +118,8 @@ namespace { Module::alias_iterator CurI = I; ++I; - bool Delete = deleteStuff == (bool)Named.count(CurI); - makeVisible(*CurI, Delete); + bool Delete = deleteStuff == (bool)Named.count(CurI); + makeVisible(*CurI, Delete); if (Delete) { Type *Ty = CurI->getType()->getElementType(); @@ -148,7 +148,7 @@ namespace { char GVExtractorPass::ID = 0; } -ModulePass *llvm::createGVExtractionPass(std::vector<GlobalValue*>& GVs, +ModulePass *llvm::createGVExtractionPass(std::vector<GlobalValue *> &GVs, bool deleteFn) { return new GVExtractorPass(GVs, deleteFn); } diff --git a/lib/Transforms/IPO/FunctionAttrs.cpp b/lib/Transforms/IPO/FunctionAttrs.cpp index 8174df9ec069..823ae53f1e25 100644 --- a/lib/Transforms/IPO/FunctionAttrs.cpp +++ b/lib/Transforms/IPO/FunctionAttrs.cpp @@ -161,8 +161,9 @@ bool FunctionAttrs::AddReadAttrs(const CallGraphSCC &SCC) { for (CallGraphSCC::iterator I = SCC.begin(), E = SCC.end(); I != E; ++I) { Function *F = (*I)->getFunction(); - if (!F) - // External node - may write memory. Just give up. + if (!F || F->hasFnAttribute(Attribute::OptimizeNone)) + // External node or node we don't want to optimize - assume it may write + // memory and give up. return false; AliasAnalysis::ModRefBehavior MRB = AA->getModRefBehavior(F); @@ -204,9 +205,11 @@ bool FunctionAttrs::AddReadAttrs(const CallGraphSCC &SCC) { CI != CE; ++CI) { Value *Arg = *CI; if (Arg->getType()->isPointerTy()) { + AAMDNodes AAInfo; + I->getAAMetadata(AAInfo); + AliasAnalysis::Location Loc(Arg, - AliasAnalysis::UnknownSize, - I->getMetadata(LLVMContext::MD_tbaa)); + AliasAnalysis::UnknownSize, AAInfo); if (!AA->pointsToConstantMemory(Loc, /*OrLocal=*/true)) { if (MRB & AliasAnalysis::Mod) // Writes non-local memory. Give up. @@ -443,7 +446,7 @@ determinePointerReadAttrs(Argument *A, case Instruction::AddrSpaceCast: // The original value is not read/written via this if the new value isn't. for (Use &UU : I->uses()) - if (Visited.insert(&UU)) + if (Visited.insert(&UU).second) Worklist.push_back(&UU); break; @@ -457,7 +460,7 @@ determinePointerReadAttrs(Argument *A, auto AddUsersToWorklistIfCapturing = [&] { if (Captures) for (Use &UU : I->uses()) - if (Visited.insert(&UU)) + if (Visited.insert(&UU).second) Worklist.push_back(&UU); }; @@ -525,7 +528,8 @@ bool FunctionAttrs::AddArgumentAttrs(const CallGraphSCC &SCC) { // looking up whether a given CallGraphNode is in this SCC. 
for (CallGraphSCC::iterator I = SCC.begin(), E = SCC.end(); I != E; ++I) { Function *F = (*I)->getFunction(); - if (F && !F->isDeclaration() && !F->mayBeOverridden()) + if (F && !F->isDeclaration() && !F->mayBeOverridden() && + !F->hasFnAttribute(Attribute::OptimizeNone)) SCCNodes.insert(F); } @@ -539,8 +543,9 @@ bool FunctionAttrs::AddArgumentAttrs(const CallGraphSCC &SCC) { for (CallGraphSCC::iterator I = SCC.begin(), E = SCC.end(); I != E; ++I) { Function *F = (*I)->getFunction(); - if (!F) - // External node - only a problem for arguments that we pass to it. + if (!F || F->hasFnAttribute(Attribute::OptimizeNone)) + // External node or function we're trying not to optimize - only a problem + // for arguments that we pass to it. continue; // Definitions with weak linkage may be overridden at linktime with @@ -792,8 +797,8 @@ bool FunctionAttrs::AddNoAliasAttrs(const CallGraphSCC &SCC) { for (CallGraphSCC::iterator I = SCC.begin(), E = SCC.end(); I != E; ++I) { Function *F = (*I)->getFunction(); - if (!F) - // External node - skip it; + if (!F || F->hasFnAttribute(Attribute::OptimizeNone)) + // External node or node we don't want to optimize - skip it; return false; // Already noalias. @@ -832,6 +837,9 @@ bool FunctionAttrs::AddNoAliasAttrs(const CallGraphSCC &SCC) { /// given function and set any applicable attributes. Returns true /// if any attributes were set and false otherwise. bool FunctionAttrs::inferPrototypeAttributes(Function &F) { + if (F.hasFnAttribute(Attribute::OptimizeNone)) + return false; + FunctionType *FTy = F.getFunctionType(); LibFunc::Func TheLibFunc; if (!(TLI->getLibFunc(F.getName(), TheLibFunc) && TLI->has(TheLibFunc))) diff --git a/lib/Transforms/IPO/GlobalDCE.cpp b/lib/Transforms/IPO/GlobalDCE.cpp index 7e7a4c0ae835..0c844fe70650 100644 --- a/lib/Transforms/IPO/GlobalDCE.cpp +++ b/lib/Transforms/IPO/GlobalDCE.cpp @@ -22,6 +22,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/Transforms/Utils/CtorUtils.h" +#include "llvm/Transforms/Utils/GlobalStatus.h" #include "llvm/Pass.h" using namespace llvm; @@ -77,9 +78,6 @@ bool GlobalDCE::runOnModule(Module &M) { // Remove empty functions from the global ctors list. Changed |= optimizeGlobalCtorsList(M, isEmptyFunction); - typedef std::multimap<const Comdat *, GlobalValue *> ComdatGVPairsTy; - ComdatGVPairsTy ComdatGVPairs; - // Loop over the module, adding globals which are obviously necessary. for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) { Changed |= RemoveUnusedGlobalValue(*I); @@ -87,8 +85,6 @@ bool GlobalDCE::runOnModule(Module &M) { if (!I->isDeclaration() && !I->hasAvailableExternallyLinkage()) { if (!I->isDiscardableIfUnused()) GlobalIsNeeded(I); - else if (const Comdat *C = I->getComdat()) - ComdatGVPairs.insert(std::make_pair(C, I)); } } @@ -100,8 +96,6 @@ bool GlobalDCE::runOnModule(Module &M) { if (!I->isDeclaration() && !I->hasAvailableExternallyLinkage()) { if (!I->isDiscardableIfUnused()) GlobalIsNeeded(I); - else if (const Comdat *C = I->getComdat()) - ComdatGVPairs.insert(std::make_pair(C, I)); } } @@ -111,24 +105,7 @@ bool GlobalDCE::runOnModule(Module &M) { // Externally visible aliases are needed. 
if (!I->isDiscardableIfUnused()) { GlobalIsNeeded(I); - } else if (const Comdat *C = I->getComdat()) { - ComdatGVPairs.insert(std::make_pair(C, I)); - } - } - - for (ComdatGVPairsTy::iterator I = ComdatGVPairs.begin(), - E = ComdatGVPairs.end(); - I != E;) { - ComdatGVPairsTy::iterator UB = ComdatGVPairs.upper_bound(I->first); - bool CanDiscard = std::all_of(I, UB, [](ComdatGVPairsTy::value_type Pair) { - return Pair.second->isDiscardableIfUnused(); - }); - if (!CanDiscard) { - std::for_each(I, UB, [this](ComdatGVPairsTy::value_type Pair) { - GlobalIsNeeded(Pair.second); - }); } - I = UB; } // Now that all globals which are needed are in the AliveGlobals set, we loop @@ -141,7 +118,12 @@ bool GlobalDCE::runOnModule(Module &M) { I != E; ++I) if (!AliveGlobals.count(I)) { DeadGlobalVars.push_back(I); // Keep track of dead globals - I->setInitializer(nullptr); + if (I->hasInitializer()) { + Constant *Init = I->getInitializer(); + I->setInitializer(nullptr); + if (isSafeToDestroyConstant(Init)) + Init->destroyConstant(); + } } // The second pass drops the bodies of functions which are dead... @@ -203,9 +185,22 @@ bool GlobalDCE::runOnModule(Module &M) { /// recursively mark anything that it uses as also needed. void GlobalDCE::GlobalIsNeeded(GlobalValue *G) { // If the global is already in the set, no need to reprocess it. - if (!AliveGlobals.insert(G)) + if (!AliveGlobals.insert(G).second) return; - + + Module *M = G->getParent(); + if (Comdat *C = G->getComdat()) { + for (Function &F : *M) + if (F.getComdat() == C) + GlobalIsNeeded(&F); + for (GlobalVariable &GV : M->globals()) + if (GV.getComdat() == C) + GlobalIsNeeded(&GV); + for (GlobalAlias &GA : M->aliases()) + if (GA.getComdat() == C) + GlobalIsNeeded(&GA); + } + if (GlobalVariable *GV = dyn_cast<GlobalVariable>(G)) { // If this is a global variable, we must make sure to add any global values // referenced by the initializer to the alive set. @@ -224,6 +219,9 @@ void GlobalDCE::GlobalIsNeeded(GlobalValue *G) { if (F->hasPrefixData()) MarkUsedGlobalsAsNeeded(F->getPrefixData()); + if (F->hasPrologueData()) + MarkUsedGlobalsAsNeeded(F->getPrologueData()); + for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) for (User::op_iterator U = I->op_begin(), E = I->op_end(); U != E; ++U) @@ -243,7 +241,7 @@ void GlobalDCE::MarkUsedGlobalsAsNeeded(Constant *C) { for (User::op_iterator I = C->op_begin(), E = C->op_end(); I != E; ++I) { // If we've already processed this constant there's no need to do it again. Constant *Op = dyn_cast<Constant>(*I); - if (Op && SeenConstants.insert(Op)) + if (Op && SeenConstants.insert(Op).second) MarkUsedGlobalsAsNeeded(Op); } } diff --git a/lib/Transforms/IPO/GlobalOpt.cpp b/lib/Transforms/IPO/GlobalOpt.cpp index c1d0d3bcdb17..6e0ae8347bc0 100644 --- a/lib/Transforms/IPO/GlobalOpt.cpp +++ b/lib/Transforms/IPO/GlobalOpt.cpp @@ -88,6 +88,7 @@ namespace { const DataLayout *DL; TargetLibraryInfo *TLI; + SmallSet<const Comdat *, 8> NotDiscardableComdats; }; } @@ -612,7 +613,7 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) { /// value will trap if the value is dynamically null. PHIs keeps track of any /// phi nodes we've seen to avoid reprocessing them. static bool AllUsesOfValueWillTrapIfNull(const Value *V, - SmallPtrSet<const PHINode*, 8> &PHIs) { + SmallPtrSetImpl<const PHINode*> &PHIs) { for (const User *U : V->users()) if (isa<LoadInst>(U)) { // Will trap. 
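[Editorial aside] A recurring mechanical change throughout this diff (GlobalDCE, GlobalOpt, FunctionAttrs, Inliner, MergeFunctions) is that SmallPtrSet::insert() now returns std::pair<iterator, bool>, matching std::set, instead of a bare bool — so each "first time seeing this?" test gains a .second. A plain-C++ sketch of the visited-set worklist idiom this guards, with std::set standing in for SmallPtrSet:

#include <cassert>
#include <set>
#include <vector>

int main() {
  std::vector<int> Worklist = {1, 2, 1, 3, 2};
  std::set<int> Visited;
  int Processed = 0;
  while (!Worklist.empty()) {
    int V = Worklist.back();
    Worklist.pop_back();
    if (!Visited.insert(V).second)  // .second: was the element newly inserted?
      continue;                     // already seen; don't reprocess
    ++Processed;
  }
  assert(Processed == 3);           // 1, 2, 3 each processed exactly once
}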
@@ -638,7 +639,7 @@ static bool AllUsesOfValueWillTrapIfNull(const Value *V, } else if (const PHINode *PN = dyn_cast<PHINode>(U)) { // If we've already seen this phi node, ignore it, it has already been // checked. - if (PHIs.insert(PN) && !AllUsesOfValueWillTrapIfNull(PN, PHIs)) + if (PHIs.insert(PN).second && !AllUsesOfValueWillTrapIfNull(PN, PHIs)) return false; } else if (isa<ICmpInst>(U) && isa<ConstantPointerNull>(U->getOperand(1))) { @@ -957,7 +958,7 @@ static GlobalVariable *OptimizeGlobalAddressOfMalloc(GlobalVariable *GV, /// it is to the specified global. static bool ValueIsOnlyUsedLocallyOrStoredToOneGlobal(const Instruction *V, const GlobalVariable *GV, - SmallPtrSet<const PHINode*, 8> &PHIs) { + SmallPtrSetImpl<const PHINode*> &PHIs) { for (const User *U : V->users()) { const Instruction *Inst = cast<Instruction>(U); @@ -981,7 +982,7 @@ static bool ValueIsOnlyUsedLocallyOrStoredToOneGlobal(const Instruction *V, if (const PHINode *PN = dyn_cast<PHINode>(Inst)) { // PHIs are ok if all uses are ok. Don't infinitely recurse through PHI // cycles. - if (PHIs.insert(PN)) + if (PHIs.insert(PN).second) if (!ValueIsOnlyUsedLocallyOrStoredToOneGlobal(PN, GV, PHIs)) return false; continue; @@ -1047,8 +1048,8 @@ static void ReplaceUsesOfMallocWithGlobal(Instruction *Alloc, /// of a load) are simple enough to perform heap SRA on. This permits GEP's /// that index through the array and struct field, icmps of null, and PHIs. static bool LoadUsesSimpleEnoughForHeapSRA(const Value *V, - SmallPtrSet<const PHINode*, 32> &LoadUsingPHIs, - SmallPtrSet<const PHINode*, 32> &LoadUsingPHIsPerLoad) { + SmallPtrSetImpl<const PHINode*> &LoadUsingPHIs, + SmallPtrSetImpl<const PHINode*> &LoadUsingPHIsPerLoad) { // We permit two users of the load: setcc comparing against the null // pointer, and a getelementptr of a specific form. for (const User *U : V->users()) { @@ -1072,11 +1073,11 @@ static bool LoadUsesSimpleEnoughForHeapSRA(const Value *V, } if (const PHINode *PN = dyn_cast<PHINode>(UI)) { - if (!LoadUsingPHIsPerLoad.insert(PN)) + if (!LoadUsingPHIsPerLoad.insert(PN).second) // This means some phi nodes are dependent on each other. // Avoid infinite looping! return false; - if (!LoadUsingPHIs.insert(PN)) + if (!LoadUsingPHIs.insert(PN).second) // If we have already analyzed this PHI, then it is safe. continue; @@ -1115,9 +1116,7 @@ static bool AllGlobalLoadUsesSimpleEnoughForHeapSRA(const GlobalVariable *GV, // that all inputs the to the PHI nodes are in the same equivalence sets. // Check to verify that all operands of the PHIs are either PHIS that can be // transformed, loads from GV, or MI itself. - for (SmallPtrSet<const PHINode*, 32>::const_iterator I = LoadUsingPHIs.begin() - , E = LoadUsingPHIs.end(); I != E; ++I) { - const PHINode *PN = *I; + for (const PHINode *PN : LoadUsingPHIs) { for (unsigned op = 0, e = PN->getNumIncomingValues(); op != e; ++op) { Value *InVal = PN->getIncomingValue(op); @@ -1910,8 +1909,11 @@ bool GlobalOpt::OptimizeFunctions(Module &M) { // Functions without names cannot be referenced outside this module. 
if (!F->hasName() && !F->isDeclaration() && !F->hasLocalLinkage()) F->setLinkage(GlobalValue::InternalLinkage); + + const Comdat *C = F->getComdat(); + bool inComdat = C && NotDiscardableComdats.count(C); F->removeDeadConstantUsers(); - if (F->isDefTriviallyDead()) { + if ((!inComdat || F->hasLocalLinkage()) && F->isDefTriviallyDead()) { F->eraseFromParent(); Changed = true; ++NumFnDeleted; @@ -1943,12 +1945,6 @@ bool GlobalOpt::OptimizeFunctions(Module &M) { bool GlobalOpt::OptimizeGlobalVars(Module &M) { bool Changed = false; - SmallSet<const Comdat *, 8> NotDiscardableComdats; - for (const GlobalVariable &GV : M.globals()) - if (const Comdat *C = GV.getComdat()) - if (!GV.isDiscardableIfUnused()) - NotDiscardableComdats.insert(C); - for (Module::global_iterator GVI = M.global_begin(), E = M.global_end(); GVI != E; ) { GlobalVariable *GV = GVI++; @@ -1965,7 +1961,7 @@ bool GlobalOpt::OptimizeGlobalVars(Module &M) { if (GV->isDiscardableIfUnused()) { if (const Comdat *C = GV->getComdat()) - if (NotDiscardableComdats.count(C)) + if (NotDiscardableComdats.count(C) && !GV->hasLocalLinkage()) continue; Changed |= ProcessGlobal(GV, GVI); } @@ -1975,7 +1971,7 @@ bool GlobalOpt::OptimizeGlobalVars(Module &M) { static inline bool isSimpleEnoughValueToCommit(Constant *C, - SmallPtrSet<Constant*, 8> &SimpleConstants, + SmallPtrSetImpl<Constant*> &SimpleConstants, const DataLayout *DL); @@ -1988,7 +1984,7 @@ isSimpleEnoughValueToCommit(Constant *C, /// in SimpleConstants to avoid having to rescan the same constants all the /// time. static bool isSimpleEnoughValueToCommitHelper(Constant *C, - SmallPtrSet<Constant*, 8> &SimpleConstants, + SmallPtrSetImpl<Constant*> &SimpleConstants, const DataLayout *DL) { // Simple global addresses are supported, do not allow dllimport or // thread-local globals. @@ -2046,10 +2042,11 @@ static bool isSimpleEnoughValueToCommitHelper(Constant *C, static inline bool isSimpleEnoughValueToCommit(Constant *C, - SmallPtrSet<Constant*, 8> &SimpleConstants, + SmallPtrSetImpl<Constant*> &SimpleConstants, const DataLayout *DL) { // If we already checked this constant, we win. - if (!SimpleConstants.insert(C)) return true; + if (!SimpleConstants.insert(C).second) + return true; // Check the constant. return isSimpleEnoughValueToCommitHelper(C, SimpleConstants, DL); } @@ -2217,7 +2214,7 @@ public: return MutatedMemory; } - const SmallPtrSet<GlobalVariable*, 8> &getInvariants() const { + const SmallPtrSetImpl<GlobalVariable*> &getInvariants() const { return Invariants; } @@ -2394,6 +2391,17 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, getVal(SI->getOperand(2))); DEBUG(dbgs() << "Found a Select! Simplifying: " << *InstResult << "\n"); + } else if (auto *EVI = dyn_cast<ExtractValueInst>(CurInst)) { + InstResult = ConstantExpr::getExtractValue( + getVal(EVI->getAggregateOperand()), EVI->getIndices()); + DEBUG(dbgs() << "Found an ExtractValueInst! Simplifying: " << *InstResult + << "\n"); + } else if (auto *IVI = dyn_cast<InsertValueInst>(CurInst)) { + InstResult = ConstantExpr::getInsertValue( + getVal(IVI->getAggregateOperand()), + getVal(IVI->getInsertedValueOperand()), IVI->getIndices()); + DEBUG(dbgs() << "Found an InsertValueInst! Simplifying: " << *InstResult + << "\n"); } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(CurInst)) { Constant *P = getVal(GEP->getOperand(0)); SmallVector<Constant*, 8> GEPOps; @@ -2663,7 +2671,7 @@ bool Evaluator::EvaluateFunction(Function *F, Constant *&RetVal, // Okay, we succeeded in evaluating this control flow. 
See if we have // executed the new block before. If so, we have a looping function, // which we cannot evaluate in reasonable time. - if (!ExecutedBlocks.insert(NextBB)) + if (!ExecutedBlocks.insert(NextBB).second) return false; // looped! // Okay, we have never been in this block before. Check to see if there @@ -2700,10 +2708,8 @@ static bool EvaluateStaticConstructor(Function *F, const DataLayout *DL, Eval.getMutatedMemory().begin(), E = Eval.getMutatedMemory().end(); I != E; ++I) CommitValueTo(I->second, I->first); - for (SmallPtrSet<GlobalVariable*, 8>::const_iterator I = - Eval.getInvariants().begin(), E = Eval.getInvariants().end(); - I != E; ++I) - (*I)->setConstant(true); + for (GlobalVariable *GV : Eval.getInvariants()) + GV->setConstant(true); } return EvalSuccess; @@ -2714,7 +2720,7 @@ static int compareNames(Constant *const *A, Constant *const *B) { } static void setUsedInitializer(GlobalVariable &V, - SmallPtrSet<GlobalValue *, 8> Init) { + const SmallPtrSet<GlobalValue *, 8> &Init) { if (Init.empty()) { V.eraseFromParent(); return; @@ -2724,10 +2730,9 @@ static void setUsedInitializer(GlobalVariable &V, PointerType *Int8PtrTy = Type::getInt8PtrTy(V.getContext(), 0); SmallVector<llvm::Constant *, 8> UsedArray; - for (SmallPtrSet<GlobalValue *, 8>::iterator I = Init.begin(), E = Init.end(); - I != E; ++I) { + for (GlobalValue *GV : Init) { Constant *Cast - = ConstantExpr::getPointerBitCastOrAddrSpaceCast(*I, Int8PtrTy); + = ConstantExpr::getPointerBitCastOrAddrSpaceCast(GV, Int8PtrTy); UsedArray.push_back(Cast); } // Sort to get deterministic order. @@ -2758,18 +2763,27 @@ public: CompilerUsedV = collectUsedGlobalVariables(M, CompilerUsed, true); } typedef SmallPtrSet<GlobalValue *, 8>::iterator iterator; + typedef iterator_range<iterator> used_iterator_range; iterator usedBegin() { return Used.begin(); } iterator usedEnd() { return Used.end(); } + used_iterator_range used() { + return used_iterator_range(usedBegin(), usedEnd()); + } iterator compilerUsedBegin() { return CompilerUsed.begin(); } iterator compilerUsedEnd() { return CompilerUsed.end(); } + used_iterator_range compilerUsed() { + return used_iterator_range(compilerUsedBegin(), compilerUsedEnd()); + } bool usedCount(GlobalValue *GV) const { return Used.count(GV); } bool compilerUsedCount(GlobalValue *GV) const { return CompilerUsed.count(GV); } bool usedErase(GlobalValue *GV) { return Used.erase(GV); } bool compilerUsedErase(GlobalValue *GV) { return CompilerUsed.erase(GV); } - bool usedInsert(GlobalValue *GV) { return Used.insert(GV); } - bool compilerUsedInsert(GlobalValue *GV) { return CompilerUsed.insert(GV); } + bool usedInsert(GlobalValue *GV) { return Used.insert(GV).second; } + bool compilerUsedInsert(GlobalValue *GV) { + return CompilerUsed.insert(GV).second; + } void syncVariablesAndSets() { if (UsedV) @@ -2814,7 +2828,8 @@ static bool mayHaveOtherReferences(GlobalAlias &GA, const LLVMUsed &U) { return U.usedCount(&GA) || U.compilerUsedCount(&GA); } -static bool hasUsesToReplace(GlobalAlias &GA, LLVMUsed &U, bool &RenameTarget) { +static bool hasUsesToReplace(GlobalAlias &GA, const LLVMUsed &U, + bool &RenameTarget) { RenameTarget = false; bool Ret = false; if (hasUseOtherThanLLVMUsed(GA, U)) @@ -2849,10 +2864,8 @@ bool GlobalOpt::OptimizeGlobalAliases(Module &M) { bool Changed = false; LLVMUsed Used(M); - for (SmallPtrSet<GlobalValue *, 8>::iterator I = Used.usedBegin(), - E = Used.usedEnd(); - I != E; ++I) - Used.compilerUsedErase(*I); + for (GlobalValue *GV : Used.used()) + Used.compilerUsedErase(GV); 
for (Module::alias_iterator I = M.alias_begin(), E = M.alias_end(); I != E;) { @@ -2963,7 +2976,7 @@ static bool cxxDtorIsEmpty(const Function &Fn, SmallPtrSet<const Function *, 8> NewCalledFunctions(CalledFunctions); // Don't treat recursive functions as empty. - if (!NewCalledFunctions.insert(CalledFn)) + if (!NewCalledFunctions.insert(CalledFn).second) return false; if (!cxxDtorIsEmpty(*CalledFn, NewCalledFunctions)) @@ -3035,6 +3048,20 @@ bool GlobalOpt::runOnModule(Module &M) { while (LocalChange) { LocalChange = false; + NotDiscardableComdats.clear(); + for (const GlobalVariable &GV : M.globals()) + if (const Comdat *C = GV.getComdat()) + if (!GV.isDiscardableIfUnused() || !GV.use_empty()) + NotDiscardableComdats.insert(C); + for (Function &F : M) + if (const Comdat *C = F.getComdat()) + if (!F.isDefTriviallyDead()) + NotDiscardableComdats.insert(C); + for (GlobalAlias &GA : M.aliases()) + if (const Comdat *C = GA.getComdat()) + if (!GA.isDiscardableIfUnused() || !GA.use_empty()) + NotDiscardableComdats.insert(C); + // Delete functions that are trivially dead, ccc -> fastcc LocalChange |= OptimizeFunctions(M); diff --git a/lib/Transforms/IPO/InlineAlways.cpp b/lib/Transforms/IPO/InlineAlways.cpp index 624cb90c0d5c..dc56a02e7b7d 100644 --- a/lib/Transforms/IPO/InlineAlways.cpp +++ b/lib/Transforms/IPO/InlineAlways.cpp @@ -14,6 +14,8 @@ #include "llvm/Transforms/IPO.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/InlineCost.h" #include "llvm/IR/CallSite.h" @@ -65,6 +67,8 @@ public: char AlwaysInliner::ID = 0; INITIALIZE_PASS_BEGIN(AlwaysInliner, "always-inline", "Inliner for always_inline functions", false, false) +INITIALIZE_AG_DEPENDENCY(AliasAnalysis) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) INITIALIZE_PASS_DEPENDENCY(InlineCostAnalysis) INITIALIZE_PASS_END(AlwaysInliner, "always-inline", diff --git a/lib/Transforms/IPO/InlineSimple.cpp b/lib/Transforms/IPO/InlineSimple.cpp index d189756032b6..9b01d81b3c7c 100644 --- a/lib/Transforms/IPO/InlineSimple.cpp +++ b/lib/Transforms/IPO/InlineSimple.cpp @@ -12,6 +12,8 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/InlineCost.h" #include "llvm/IR/CallSite.h" @@ -73,6 +75,8 @@ static int computeThresholdFromOptLevels(unsigned OptLevel, char SimpleInliner::ID = 0; INITIALIZE_PASS_BEGIN(SimpleInliner, "inline", "Function Integration/Inlining", false, false) +INITIALIZE_AG_DEPENDENCY(AliasAnalysis) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) INITIALIZE_PASS_DEPENDENCY(InlineCostAnalysis) INITIALIZE_PASS_END(SimpleInliner, "inline", diff --git a/lib/Transforms/IPO/Inliner.cpp b/lib/Transforms/IPO/Inliner.cpp index 9087ab23bb70..66867437e1b7 100644 --- a/lib/Transforms/IPO/Inliner.cpp +++ b/lib/Transforms/IPO/Inliner.cpp @@ -16,6 +16,8 @@ #include "llvm/Transforms/IPO/InlinerPass.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/InlineCost.h" #include "llvm/IR/CallSite.h" @@ -74,6 +76,8 @@ Inliner::Inliner(char &ID, int 
Threshold, bool InsertLifetime) /// the call graph. If the derived class implements this method, it should /// always explicitly call the implementation here. void Inliner::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<AliasAnalysis>(); + AU.addRequired<AssumptionCacheTracker>(); CallGraphSCCPass::getAnalysisUsage(AU); } @@ -215,7 +219,7 @@ static bool InlineCallIfPossible(CallSite CS, InlineFunctionInfo &IFI, // If the inlined function already uses this alloca then we can't reuse // it. - if (!UsedAllocas.insert(AvailableAlloca)) + if (!UsedAllocas.insert(AvailableAlloca).second) continue; // Otherwise, we *can* reuse it, RAUW AI into AvailableAlloca and declare @@ -357,8 +361,7 @@ bool Inliner::shouldInline(CallSite CS) { // FIXME: All of this logic should be sunk into getInlineCost. It relies on // the internal implementation of the inline cost metrics rather than // treating them as truly abstract units etc. - if (Caller->hasLocalLinkage() || - Caller->getLinkage() == GlobalValue::LinkOnceODRLinkage) { + if (Caller->hasLocalLinkage() || Caller->hasLinkOnceODRLinkage()) { int TotalSecondaryCost = 0; // The candidate cost to be imposed upon the current function. int CandidateCost = IC.getCost() - (InlineConstants::CallPenalty + 1); @@ -440,9 +443,11 @@ static bool InlineHistoryIncludes(Function *F, int InlineHistoryID, bool Inliner::runOnSCC(CallGraphSCC &SCC) { CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); + AssumptionCacheTracker *ACT = &getAnalysis<AssumptionCacheTracker>(); DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); const DataLayout *DL = DLP ? &DLP->getDataLayout() : nullptr; const TargetLibraryInfo *TLI = getAnalysisIfAvailable<TargetLibraryInfo>(); + AliasAnalysis *AA = &getAnalysis<AliasAnalysis>(); SmallPtrSet<Function*, 8> SCCFunctions; DEBUG(dbgs() << "Inliner visiting SCC:"); @@ -501,8 +506,8 @@ bool Inliner::runOnSCC(CallGraphSCC &SCC) { InlinedArrayAllocasTy InlinedArrayAllocas; - InlineFunctionInfo InlineInfo(&CG, DL); - + InlineFunctionInfo InlineInfo(&CG, DL, AA, ACT); + // Now that we have all of the call sites, loop over them and inline them if // it looks profitable to do so. bool Changed = false; @@ -664,6 +669,13 @@ bool Inliner::removeDeadFunctions(CallGraph &CG, bool AlwaysInlineOnly) { if (!F->isDefTriviallyDead()) continue; + + // It is unsafe to drop a function with discardable linkage from a COMDAT + // without also dropping the other members of the COMDAT. + // The inliner doesn't visit non-function entities which are in COMDAT + // groups so it is unsafe to do so *unless* the linkage is local. + if (!F->hasLocalLinkage() && F->hasComdat()) + continue; // Remove any call graph edges from the function to its callees. CGN->removeAllCalledFunctions(); diff --git a/lib/Transforms/IPO/Internalize.cpp b/lib/Transforms/IPO/Internalize.cpp index c970a1a1c1af..7950163f757d 100644 --- a/lib/Transforms/IPO/Internalize.cpp +++ b/lib/Transforms/IPO/Internalize.cpp @@ -148,9 +148,7 @@ bool InternalizePass::runOnModule(Module &M) { // we don't see references from function local inline assembly. To be // conservative, we internalize symbols in llvm.compiler.used, but we // keep llvm.compiler.used so that the symbol is not deleted by llvm. 
- for (SmallPtrSet<GlobalValue *, 8>::iterator I = Used.begin(), E = Used.end(); - I != E; ++I) { - GlobalValue *V = *I; + for (GlobalValue *V : Used) { ExternalNames.insert(V->getName()); } diff --git a/lib/Transforms/IPO/MergeFunctions.cpp b/lib/Transforms/IPO/MergeFunctions.cpp index 2fb0ddb174a0..b91ebf2b96b0 100644 --- a/lib/Transforms/IPO/MergeFunctions.cpp +++ b/lib/Transforms/IPO/MergeFunctions.cpp @@ -286,7 +286,7 @@ private: /// 6.4.Load: range metadata (as integer numbers) /// On this stage its better to see the code, since its not more than 10-15 /// strings for particular instruction, and could change sometimes. - int cmpOperation(const Instruction *L, const Instruction *R) const; + int cmpOperations(const Instruction *L, const Instruction *R) const; /// Compare two GEPs for equivalent pointer arithmetic. /// Parts to be compared for each comparison stage, @@ -297,9 +297,9 @@ private: /// 3. Pointer operand type (using cmpType method). /// 4. Number of operands. /// 5. Compare operands, using cmpValues method. - int cmpGEP(const GEPOperator *GEPL, const GEPOperator *GEPR); - int cmpGEP(const GetElementPtrInst *GEPL, const GetElementPtrInst *GEPR) { - return cmpGEP(cast<GEPOperator>(GEPL), cast<GEPOperator>(GEPR)); + int cmpGEPs(const GEPOperator *GEPL, const GEPOperator *GEPR); + int cmpGEPs(const GetElementPtrInst *GEPL, const GetElementPtrInst *GEPR) { + return cmpGEPs(cast<GEPOperator>(GEPL), cast<GEPOperator>(GEPR)); } /// cmpType - compares two types, @@ -342,12 +342,12 @@ private: /// be checked with the same way. If we get Res != 0 on some stage, return it. /// Otherwise return 0. /// 6. For all other cases put llvm_unreachable. - int cmpType(Type *TyL, Type *TyR) const; + int cmpTypes(Type *TyL, Type *TyR) const; int cmpNumbers(uint64_t L, uint64_t R) const; - int cmpAPInt(const APInt &L, const APInt &R) const; - int cmpAPFloat(const APFloat &L, const APFloat &R) const; + int cmpAPInts(const APInt &L, const APInt &R) const; + int cmpAPFloats(const APFloat &L, const APFloat &R) const; int cmpStrings(StringRef L, StringRef R) const; int cmpAttrs(const AttributeSet L, const AttributeSet R) const; @@ -392,15 +392,15 @@ private: DenseMap<const Value*, int> sn_mapL, sn_mapR; }; -class FunctionPtr { +class FunctionNode { AssertingVH<Function> F; const DataLayout *DL; public: - FunctionPtr(Function *F, const DataLayout *DL) : F(F), DL(DL) {} + FunctionNode(Function *F, const DataLayout *DL) : F(F), DL(DL) {} Function *getFunc() const { return F; } void release() { F = 0; } - bool operator<(const FunctionPtr &RHS) const { + bool operator<(const FunctionNode &RHS) const { return (FunctionComparator(DL, F, RHS.getFunc()).compare()) == -1; } }; @@ -412,7 +412,7 @@ int FunctionComparator::cmpNumbers(uint64_t L, uint64_t R) const { return 0; } -int FunctionComparator::cmpAPInt(const APInt &L, const APInt &R) const { +int FunctionComparator::cmpAPInts(const APInt &L, const APInt &R) const { if (int Res = cmpNumbers(L.getBitWidth(), R.getBitWidth())) return Res; if (L.ugt(R)) return 1; @@ -420,11 +420,11 @@ int FunctionComparator::cmpAPInt(const APInt &L, const APInt &R) const { return 0; } -int FunctionComparator::cmpAPFloat(const APFloat &L, const APFloat &R) const { +int FunctionComparator::cmpAPFloats(const APFloat &L, const APFloat &R) const { if (int Res = cmpNumbers((uint64_t)&L.getSemantics(), (uint64_t)&R.getSemantics())) return Res; - return cmpAPInt(L.bitcastToAPInt(), R.bitcastToAPInt()); + return cmpAPInts(L.bitcastToAPInt(), R.bitcastToAPInt()); } int 
FunctionComparator::cmpStrings(StringRef L, StringRef R) const { @@ -474,7 +474,7 @@ int FunctionComparator::cmpConstants(const Constant *L, const Constant *R) { // Check whether types are bitcastable. This part is just re-factored // Type::canLosslesslyBitCastTo method, but instead of returning true/false, // we also pack into result which type is "less" for us. - int TypesRes = cmpType(TyL, TyR); + int TypesRes = cmpTypes(TyL, TyR); if (TypesRes != 0) { // Types are different, but check whether we can bitcast them. if (!TyL->isFirstClassType()) { @@ -541,12 +541,12 @@ int FunctionComparator::cmpConstants(const Constant *L, const Constant *R) { case Value::ConstantIntVal: { const APInt &LInt = cast<ConstantInt>(L)->getValue(); const APInt &RInt = cast<ConstantInt>(R)->getValue(); - return cmpAPInt(LInt, RInt); + return cmpAPInts(LInt, RInt); } case Value::ConstantFPVal: { const APFloat &LAPF = cast<ConstantFP>(L)->getValueAPF(); const APFloat &RAPF = cast<ConstantFP>(R)->getValueAPF(); - return cmpAPFloat(LAPF, RAPF); + return cmpAPFloats(LAPF, RAPF); } case Value::ConstantArrayVal: { const ConstantArray *LA = cast<ConstantArray>(L); @@ -615,7 +615,7 @@ int FunctionComparator::cmpConstants(const Constant *L, const Constant *R) { /// cmpType - compares two types, /// defines total ordering among the types set. /// See method declaration comments for more details. -int FunctionComparator::cmpType(Type *TyL, Type *TyR) const { +int FunctionComparator::cmpTypes(Type *TyL, Type *TyR) const { PointerType *PTyL = dyn_cast<PointerType>(TyL); PointerType *PTyR = dyn_cast<PointerType>(TyR); @@ -665,8 +665,7 @@ int FunctionComparator::cmpType(Type *TyL, Type *TyR) const { return cmpNumbers(STyL->isPacked(), STyR->isPacked()); for (unsigned i = 0, e = STyL->getNumElements(); i != e; ++i) { - if (int Res = cmpType(STyL->getElementType(i), - STyR->getElementType(i))) + if (int Res = cmpTypes(STyL->getElementType(i), STyR->getElementType(i))) return Res; } return 0; @@ -681,11 +680,11 @@ int FunctionComparator::cmpType(Type *TyL, Type *TyR) const { if (FTyL->isVarArg() != FTyR->isVarArg()) return cmpNumbers(FTyL->isVarArg(), FTyR->isVarArg()); - if (int Res = cmpType(FTyL->getReturnType(), FTyR->getReturnType())) + if (int Res = cmpTypes(FTyL->getReturnType(), FTyR->getReturnType())) return Res; for (unsigned i = 0, e = FTyL->getNumParams(); i != e; ++i) { - if (int Res = cmpType(FTyL->getParamType(i), FTyR->getParamType(i))) + if (int Res = cmpTypes(FTyL->getParamType(i), FTyR->getParamType(i))) return Res; } return 0; @@ -696,7 +695,7 @@ int FunctionComparator::cmpType(Type *TyL, Type *TyR) const { ArrayType *ATyR = cast<ArrayType>(TyR); if (ATyL->getNumElements() != ATyR->getNumElements()) return cmpNumbers(ATyL->getNumElements(), ATyR->getNumElements()); - return cmpType(ATyL->getElementType(), ATyR->getElementType()); + return cmpTypes(ATyL->getElementType(), ATyR->getElementType()); } } } @@ -705,8 +704,8 @@ int FunctionComparator::cmpType(Type *TyL, Type *TyR) const { // and pointer-to-B are equivalent. This should be kept in sync with // Instruction::isSameOperationAs. // Read method declaration comments for more details. -int FunctionComparator::cmpOperation(const Instruction *L, - const Instruction *R) const { +int FunctionComparator::cmpOperations(const Instruction *L, + const Instruction *R) const { // Differences from Instruction::isSameOperationAs: // * replace type comparison with calls to isEquivalentType. 
// * we test for I->hasSameSubclassOptionalData (nuw/nsw/tail) at the top @@ -717,7 +716,7 @@ int FunctionComparator::cmpOperation(const Instruction *L, if (int Res = cmpNumbers(L->getNumOperands(), R->getNumOperands())) return Res; - if (int Res = cmpType(L->getType(), R->getType())) + if (int Res = cmpTypes(L->getType(), R->getType())) return Res; if (int Res = cmpNumbers(L->getRawSubclassOptionalData(), @@ -728,7 +727,7 @@ int FunctionComparator::cmpOperation(const Instruction *L, // if all operands are the same type for (unsigned i = 0, e = L->getNumOperands(); i != e; ++i) { if (int Res = - cmpType(L->getOperand(i)->getType(), R->getOperand(i)->getType())) + cmpTypes(L->getOperand(i)->getType(), R->getOperand(i)->getType())) return Res; } @@ -845,7 +844,7 @@ int FunctionComparator::cmpOperation(const Instruction *L, // Determine whether two GEP operations perform the same underlying arithmetic. // Read method declaration comments for more details. -int FunctionComparator::cmpGEP(const GEPOperator *GEPL, +int FunctionComparator::cmpGEPs(const GEPOperator *GEPL, const GEPOperator *GEPR) { unsigned int ASL = GEPL->getPointerAddressSpace(); @@ -861,7 +860,7 @@ int FunctionComparator::cmpGEP(const GEPOperator *GEPL, APInt OffsetL(BitWidth, 0), OffsetR(BitWidth, 0); if (GEPL->accumulateConstantOffset(*DL, OffsetL) && GEPR->accumulateConstantOffset(*DL, OffsetR)) - return cmpAPInt(OffsetL, OffsetR); + return cmpAPInts(OffsetL, OffsetR); } if (int Res = cmpNumbers((uint64_t)GEPL->getPointerOperand()->getType(), @@ -945,10 +944,10 @@ int FunctionComparator::compare(const BasicBlock *BBL, const BasicBlock *BBR) { if (int Res = cmpValues(GEPL->getPointerOperand(), GEPR->getPointerOperand())) return Res; - if (int Res = cmpGEP(GEPL, GEPR)) + if (int Res = cmpGEPs(GEPL, GEPR)) return Res; } else { - if (int Res = cmpOperation(InstL, InstR)) + if (int Res = cmpOperations(InstL, InstR)) return Res; assert(InstL->getNumOperands() == InstR->getNumOperands()); @@ -960,7 +959,7 @@ int FunctionComparator::compare(const BasicBlock *BBL, const BasicBlock *BBR) { if (int Res = cmpNumbers(OpL->getValueID(), OpR->getValueID())) return Res; // TODO: Already checked in cmpOperation - if (int Res = cmpType(OpL->getType(), OpR->getType())) + if (int Res = cmpTypes(OpL->getType(), OpR->getType())) return Res; } } @@ -1008,7 +1007,7 @@ int FunctionComparator::compare() { if (int Res = cmpNumbers(FnL->getCallingConv(), FnR->getCallingConv())) return Res; - if (int Res = cmpType(FnL->getFunctionType(), FnR->getFunctionType())) + if (int Res = cmpTypes(FnL->getFunctionType(), FnR->getFunctionType())) return Res; assert(FnL->arg_size() == FnR->arg_size() && @@ -1050,7 +1049,7 @@ int FunctionComparator::compare() { assert(TermL->getNumSuccessors() == TermR->getNumSuccessors()); for (unsigned i = 0, e = TermL->getNumSuccessors(); i != e; ++i) { - if (!VisitedBBs.insert(TermL->getSuccessor(i))) + if (!VisitedBBs.insert(TermL->getSuccessor(i)).second) continue; FnLBBs.push_back(TermL->getSuccessor(i)); @@ -1078,7 +1077,7 @@ public: bool runOnModule(Module &M) override; private: - typedef std::set<FunctionPtr> FnTreeType; + typedef std::set<FunctionNode> FnTreeType; /// A work queue of functions that may have been modified and should be /// analyzed again. 
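[Editorial aside] For context on the FunctionPtr -> FunctionNode rename above: MergeFunctions keeps candidates in a std::set ordered by FunctionComparator's three-way compare, so two functions that compare equal occupy one slot and a failed insert signals a merge opportunity. A standalone sketch with hypothetical stand-ins (string bodies instead of real IR comparison):

#include <cassert>
#include <set>
#include <string>

struct Function { std::string Body; };               // stand-in for llvm::Function

int compare(const Function &L, const Function &R) {  // three-way: -1 / 0 / +1
  return L.Body < R.Body ? -1 : (L.Body == R.Body ? 0 : 1);
}

struct FunctionNode {
  Function *F;
  bool operator<(const FunctionNode &RHS) const {
    return compare(*F, *RHS.F) == -1;                // total order => set dedup
  }
};

int main() {
  Function A{"ret 0"}, B{"ret 0"};                   // equivalent bodies
  std::set<FunctionNode> FnTree;
  bool InsertedA = FnTree.insert(FunctionNode{&A}).second;
  bool InsertedB = FnTree.insert(FunctionNode{&B}).second;
  assert(InsertedA && !InsertedB);                   // B collides: merge B into A
}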
@@ -1301,11 +1300,11 @@ static Value *createCast(IRBuilder<false> &Builder, Value *V, Type *DestTy) { Value *Result = UndefValue::get(DestTy); for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) { Value *Element = createCast( - Builder, Builder.CreateExtractValue(V, ArrayRef<unsigned int>(I)), + Builder, Builder.CreateExtractValue(V, makeArrayRef(I)), DestTy->getStructElementType(I)); Result = - Builder.CreateInsertValue(Result, Element, ArrayRef<unsigned int>(I)); + Builder.CreateInsertValue(Result, Element, makeArrayRef(I)); } return Result; } @@ -1421,14 +1420,14 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) { // that was already inserted. bool MergeFunctions::insert(Function *NewFunction) { std::pair<FnTreeType::iterator, bool> Result = - FnTree.insert(FunctionPtr(NewFunction, DL)); + FnTree.insert(FunctionNode(NewFunction, DL)); if (Result.second) { DEBUG(dbgs() << "Inserting as unique: " << NewFunction->getName() << '\n'); return false; } - const FunctionPtr &OldF = *Result.first; + const FunctionNode &OldF = *Result.first; // Don't merge tiny functions, since it can just end up making the function // larger. @@ -1458,7 +1457,7 @@ bool MergeFunctions::insert(Function *NewFunction) { void MergeFunctions::remove(Function *F) { // We need to make sure we remove F, not a function "equal" to F per the // function equality comparator. - FnTreeType::iterator found = FnTree.find(FunctionPtr(F, DL)); + FnTreeType::iterator found = FnTree.find(FunctionNode(F, DL)); size_t Erased = 0; if (found != FnTree.end() && found->getFunc() == F) { Erased = 1; diff --git a/lib/Transforms/IPO/PassManagerBuilder.cpp b/lib/Transforms/IPO/PassManagerBuilder.cpp index 701fb462b4fd..0414caa61fca 100644 --- a/lib/Transforms/IPO/PassManagerBuilder.cpp +++ b/lib/Transforms/IPO/PassManagerBuilder.cpp @@ -17,11 +17,14 @@ #include "llvm-c/Transforms/PassManagerBuilder.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/Passes.h" +#include "llvm/IR/DataLayout.h" #include "llvm/IR/Verifier.h" #include "llvm/PassManager.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Target/TargetLibraryInfo.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Target/TargetSubtargetInfo.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Vectorize.h" @@ -45,6 +48,10 @@ UseGVNAfterVectorization("use-gvn-after-vectorization", cl::init(false), cl::Hidden, cl::desc("Run GVN instead of Early CSE after vectorization passes")); +static cl::opt<bool> ExtraVectorizerPasses( + "extra-vectorizer-passes", cl::init(false), cl::Hidden, + cl::desc("Run cleanup optimization passes after vectorization.")); + static cl::opt<bool> UseNewSROA("use-new-sroa", cl::init(true), cl::Hidden, cl::desc("Enable the new, experimental SROA pass")); @@ -57,6 +64,20 @@ static cl::opt<bool> RunLoadCombine("combine-loads", cl::init(false), cl::Hidden, cl::desc("Run the load combining pass")); +static cl::opt<bool> +RunSLPAfterLoopVectorization("run-slp-after-loop-vectorization", + cl::init(true), cl::Hidden, + cl::desc("Run the SLP vectorizer (and BB vectorizer) after the Loop " + "vectorizer instead of before")); + +static cl::opt<bool> UseCFLAA("use-cfl-aa", + cl::init(false), cl::Hidden, + cl::desc("Enable the new, experimental CFL alias analysis")); + +static cl::opt<bool> +EnableMLSM("mlsm", cl::init(true), cl::Hidden, + cl::desc("Enable motion of merged load and store")); + PassManagerBuilder::PassManagerBuilder() { 
OptLevel = 2; SizeLevel = 0; @@ -70,6 +91,11 @@ PassManagerBuilder::PassManagerBuilder() { LoopVectorize = RunLoopVectorization; RerollLoops = RunLoopRerolling; LoadCombine = RunLoadCombine; + DisableGVNLoadPRE = false; + VerifyInput = false; + VerifyOutput = false; + StripDebug = false; + MergeFunctions = false; } PassManagerBuilder::~PassManagerBuilder() { @@ -106,7 +132,10 @@ PassManagerBuilder::addInitialAliasAnalysisPasses(PassManagerBase &PM) const { // Add TypeBasedAliasAnalysis before BasicAliasAnalysis so that // BasicAliasAnalysis wins if they disagree. This is intended to help // support "obvious" type-punning idioms. + if (UseCFLAA) + PM.add(createCFLAliasAnalysisPass()); PM.add(createTypeBasedAliasAnalysisPass()); + PM.add(createScopedNoAliasAAPass()); PM.add(createBasicAliasAnalysisPass()); } @@ -130,18 +159,22 @@ void PassManagerBuilder::populateFunctionPassManager(FunctionPassManager &FPM) { } void PassManagerBuilder::populateModulePassManager(PassManagerBase &MPM) { - // If all optimizations are disabled, just run the always-inline pass. + // If all optimizations are disabled, just run the always-inline pass and, + // if enabled, the function merging pass. if (OptLevel == 0) { if (Inliner) { MPM.add(Inliner); Inliner = nullptr; } - // FIXME: This is a HACK! The inliner pass above implicitly creates a CGSCC - // pass manager, but we don't want to add extensions into that pass manager. - // To prevent this we must insert a no-op module pass to reset the pass - // manager to get the same behavior as EP_OptimizerLast in non-O0 builds. - if (!GlobalExtensions->empty() || !Extensions.empty()) + // FIXME: The BarrierNoopPass is a HACK! The inliner pass above implicitly + // creates a CGSCC pass manager, but we don't want to add extensions into + // that pass manager. To prevent this we insert a no-op module pass to reset + // the pass manager to get the same behavior as EP_OptimizerLast in non-O0 + // builds. The function merging pass is + if (MergeFunctions) + MPM.add(createMergeFunctionsPass()); + else if (!GlobalExtensions->empty() || !Extensions.empty()) MPM.add(createBarrierNoopPass()); addExtensionsToPM(EP_EnabledOnOptLevel0, MPM); @@ -195,7 +228,8 @@ void PassManagerBuilder::populateModulePassManager(PassManagerBase &MPM) { MPM.add(createTailCallEliminationPass()); // Eliminate tail calls MPM.add(createCFGSimplificationPass()); // Merge & remove BBs MPM.add(createReassociatePass()); // Reassociate expressions - MPM.add(createLoopRotatePass()); // Rotate Loop + // Rotate Loop - disable header duplication at -Oz + MPM.add(createLoopRotatePass(SizeLevel == 2 ? 
0 : -1)); MPM.add(createLICMPass()); // Hoist loop invariants MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3)); MPM.add(createInstructionCombiningPass()); @@ -208,8 +242,9 @@ void PassManagerBuilder::populateModulePassManager(PassManagerBase &MPM) { addExtensionsToPM(EP_LoopOptimizerEnd, MPM); if (OptLevel > 1) { - MPM.add(createMergedLoadStoreMotionPass()); // Merge load/stores in diamond - MPM.add(createGVNPass()); // Remove redundancies + if (EnableMLSM) + MPM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds + MPM.add(createGVNPass(DisableGVNLoadPRE)); // Remove redundancies } MPM.add(createMemCpyOptPass()); // Remove memcpy / form memset MPM.add(createSCCPPass()); // Constant prop with SCCP @@ -226,21 +261,23 @@ void PassManagerBuilder::populateModulePassManager(PassManagerBase &MPM) { if (RerollLoops) MPM.add(createLoopRerollPass()); - if (SLPVectorize) - MPM.add(createSLPVectorizerPass()); // Vectorize parallel scalar chains. - - if (BBVectorize) { - MPM.add(createBBVectorizePass()); - MPM.add(createInstructionCombiningPass()); - addExtensionsToPM(EP_Peephole, MPM); - if (OptLevel > 1 && UseGVNAfterVectorization) - MPM.add(createGVNPass()); // Remove redundancies - else - MPM.add(createEarlyCSEPass()); // Catch trivial redundancies - - // BBVectorize may have significantly shortened a loop body; unroll again. - if (!DisableUnrollLoops) - MPM.add(createLoopUnrollPass()); + if (!RunSLPAfterLoopVectorization) { + if (SLPVectorize) + MPM.add(createSLPVectorizerPass()); // Vectorize parallel scalar chains. + + if (BBVectorize) { + MPM.add(createBBVectorizePass()); + MPM.add(createInstructionCombiningPass()); + addExtensionsToPM(EP_Peephole, MPM); + if (OptLevel > 1 && UseGVNAfterVectorization) + MPM.add(createGVNPass(DisableGVNLoadPRE)); // Remove redundancies + else + MPM.add(createEarlyCSEPass()); // Catch trivial redundancies + + // BBVectorize may have significantly shortened a loop body; unroll again. + if (!DisableUnrollLoops) + MPM.add(createLoopUnrollPass()); + } } if (LoadCombine) @@ -255,6 +292,13 @@ void PassManagerBuilder::populateModulePassManager(PassManagerBase &MPM) { // pass manager that we are specifically trying to avoid. To prevent this // we must insert a no-op module pass to reset the pass manager. MPM.add(createBarrierNoopPass()); + + // Re-rotate loops in all our loop nests. These may have fallout out of + // rotated form due to GVN or other transformations, and the vectorizer relies + // on the rotated form. + if (ExtraVectorizerPasses) + MPM.add(createLoopRotatePass()); + MPM.add(createLoopVectorizePass(DisableUnrollLoops, LoopVectorize)); // FIXME: Because of #pragma vectorize enable, the passes below are always // inserted in the pipeline, even when the vectorizer doesn't run (ex. when @@ -262,12 +306,56 @@ void PassManagerBuilder::populateModulePassManager(PassManagerBase &MPM) { // as function calls, so that we can only pass them when the vectorizer // changed the code. MPM.add(createInstructionCombiningPass()); + if (OptLevel > 1 && ExtraVectorizerPasses) { + // At higher optimization levels, try to clean up any runtime overlap and + // alignment checks inserted by the vectorizer. We want to track correllated + // runtime checks for two inner loops in the same outer loop, fold any + // common computations, hoist loop-invariant aspects out of any outer loop, + // and unswitch the runtime checks if possible. Once hoisted, we may have + // dead (or speculatable) control flows or more combining opportunities. 
+ MPM.add(createEarlyCSEPass()); + MPM.add(createCorrelatedValuePropagationPass()); + MPM.add(createInstructionCombiningPass()); + MPM.add(createLICMPass()); + MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3)); + MPM.add(createCFGSimplificationPass()); + MPM.add(createInstructionCombiningPass()); + } + + if (RunSLPAfterLoopVectorization) { + if (SLPVectorize) { + MPM.add(createSLPVectorizerPass()); // Vectorize parallel scalar chains. + if (OptLevel > 1 && ExtraVectorizerPasses) { + MPM.add(createEarlyCSEPass()); + } + } + + if (BBVectorize) { + MPM.add(createBBVectorizePass()); + MPM.add(createInstructionCombiningPass()); + addExtensionsToPM(EP_Peephole, MPM); + if (OptLevel > 1 && UseGVNAfterVectorization) + MPM.add(createGVNPass(DisableGVNLoadPRE)); // Remove redundancies + else + MPM.add(createEarlyCSEPass()); // Catch trivial redundancies + + // BBVectorize may have significantly shortened a loop body; unroll again. + if (!DisableUnrollLoops) + MPM.add(createLoopUnrollPass()); + } + } + addExtensionsToPM(EP_Peephole, MPM); MPM.add(createCFGSimplificationPass()); + MPM.add(createInstructionCombiningPass()); if (!DisableUnrollLoops) MPM.add(createLoopUnrollPass()); // Unroll small loops + // After vectorization and unrolling, assume intrinsics may tell us more + // about pointer alignments. + MPM.add(createAlignmentFromAssumptionsPass()); + if (!DisableUnitAtATime) { // FIXME: We shouldn't bother with this anymore. MPM.add(createStripDeadPrototypesPass()); // Get rid of dead prototypes @@ -279,22 +367,17 @@ void PassManagerBuilder::populateModulePassManager(PassManagerBase &MPM) { MPM.add(createConstantMergePass()); // Merge dup global constants } } + + if (MergeFunctions) + MPM.add(createMergeFunctionsPass()); + addExtensionsToPM(EP_OptimizerLast, MPM); } -void PassManagerBuilder::populateLTOPassManager(PassManagerBase &PM, - bool Internalize, - bool RunInliner, - bool DisableGVNLoadPRE) { +void PassManagerBuilder::addLTOOptimizationPasses(PassManagerBase &PM) { // Provide AliasAnalysis services for optimizations. addInitialAliasAnalysisPasses(PM); - // Now that composite has been compiled, scan through the module, looking - // for a main function. If main is defined, mark all other functions - // internal. - if (Internalize) - PM.add(createInternalizePass("main")); - // Propagate constants at call sites into the functions they call. This // opens opportunities for globalopt (and inlining) by substituting function // pointers passed as arguments to direct uses of functions. @@ -318,8 +401,11 @@ void PassManagerBuilder::populateLTOPassManager(PassManagerBase &PM, addExtensionsToPM(EP_Peephole, PM); // Inline small functions - if (RunInliner) - PM.add(createFunctionInliningPass()); + bool RunInliner = Inliner; + if (RunInliner) { + PM.add(Inliner); + Inliner = nullptr; + } PM.add(createPruneEHPass()); // Remove dead EH info. @@ -348,7 +434,8 @@ void PassManagerBuilder::populateLTOPassManager(PassManagerBase &PM, PM.add(createGlobalsModRefPass()); // IP alias analysis. PM.add(createLICMPass()); // Hoist loop invariants. - PM.add(createMergedLoadStoreMotionPass()); // Merge load/stores in diamonds + if (EnableMLSM) + PM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds. PM.add(createGVNPass(DisableGVNLoadPRE)); // Remove redundancies. PM.add(createMemCpyOptPass()); // Remove dead memcpys. @@ -358,10 +445,16 @@ void PassManagerBuilder::populateLTOPassManager(PassManagerBase &PM, // More loops are countable; try to optimize them. 
PM.add(createIndVarSimplifyPass()); PM.add(createLoopDeletionPass()); - PM.add(createLoopVectorizePass(true, true)); + PM.add(createLoopVectorizePass(true, LoopVectorize)); // More scalar chains could be vectorized due to more alias information - PM.add(createSLPVectorizerPass()); // Vectorize parallel scalar chains. + if (RunSLPAfterLoopVectorization) + if (SLPVectorize) + PM.add(createSLPVectorizerPass()); // Vectorize parallel scalar chains. + + // After vectorization, assume intrinsics may tell us more about pointer + // alignments. + PM.add(createAlignmentFromAssumptionsPass()); if (LoadCombine) PM.add(createLoadCombinePass()); @@ -377,6 +470,39 @@ void PassManagerBuilder::populateLTOPassManager(PassManagerBase &PM, // Now that we have optimized the program, discard unreachable functions. PM.add(createGlobalDCEPass()); + + // FIXME: this is profitable (for compiler time) to do at -O0 too, but + // currently it damages debug info. + if (MergeFunctions) + PM.add(createMergeFunctionsPass()); +} + +void PassManagerBuilder::populateLTOPassManager(PassManagerBase &PM, + TargetMachine *TM) { + if (TM) { + PM.add(new DataLayoutPass()); + TM->addAnalysisPasses(PM); + } + + if (LibraryInfo) + PM.add(new TargetLibraryInfo(*LibraryInfo)); + + if (VerifyInput) + PM.add(createVerifierPass()); + + if (StripDebug) + PM.add(createStripSymbolsPass(true)); + + if (VerifyInput) + PM.add(createDebugInfoVerifierPass()); + + if (OptLevel != 0) + addLTOOptimizationPasses(PM); + + if (VerifyOutput) { + PM.add(createVerifierPass()); + PM.add(createDebugInfoVerifierPass()); + } } inline PassManagerBuilder *unwrap(LLVMPassManagerBuilderRef P) { @@ -460,5 +586,11 @@ void LLVMPassManagerBuilderPopulateLTOPassManager(LLVMPassManagerBuilderRef PMB, LLVMBool RunInliner) { PassManagerBuilder *Builder = unwrap(PMB); PassManagerBase *LPM = unwrap(PM); - Builder->populateLTOPassManager(*LPM, Internalize != 0, RunInliner != 0); + + // A small backwards-compatibility hack. populateLTOPassManager used to take + // a RunInliner option. + if (RunInliner && !Builder->Inliner) + Builder->Inliner = createFunctionInliningPass(); + + Builder->populateLTOPassManager(*LPM); } diff --git a/lib/Transforms/IPO/PruneEH.cpp b/lib/Transforms/IPO/PruneEH.cpp index b2c4a099b020..7bd4ce12860d 100644 --- a/lib/Transforms/IPO/PruneEH.cpp +++ b/lib/Transforms/IPO/PruneEH.cpp @@ -200,7 +200,7 @@ bool PruneEH::SimplifyFunction(Function *F) { BB->getInstList().pop_back(); // If the unwind block is now dead, nuke it. - if (pred_begin(UnwindBlock) == pred_end(UnwindBlock)) + if (pred_empty(UnwindBlock)) DeleteBasicBlock(UnwindBlock); // Delete the new BB. ++NumRemoved; @@ -234,7 +234,7 @@ bool PruneEH::SimplifyFunction(Function *F) { /// updating the callgraph to reflect any now-obsolete edges due to calls that /// exist in the BB. void PruneEH::DeleteBasicBlock(BasicBlock *BB) { - assert(pred_begin(BB) == pred_end(BB) && "BB is not dead!"); + assert(pred_empty(BB) && "BB is not dead!"); CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); CallGraphNode *CGN = CG[BB->getParent()]; diff --git a/lib/Transforms/IPO/StripSymbols.cpp b/lib/Transforms/IPO/StripSymbols.cpp index 1abbccc0fc82..816978ea9ce6 100644 --- a/lib/Transforms/IPO/StripSymbols.cpp +++ b/lib/Transforms/IPO/StripSymbols.cpp @@ -154,9 +154,8 @@ static void RemoveDeadConstant(Constant *C) { C->destroyConstant(); // If the constant referenced anything, see if we can delete it as well.
- for (SmallPtrSet<Constant*, 4>::iterator OI = Operands.begin(), - OE = Operands.end(); OI != OE; ++OI) - RemoveDeadConstant(*OI); + for (Constant *O : Operands) + RemoveDeadConstant(O); } // Strip the symbol table of its names. @@ -191,7 +190,7 @@ static void StripTypeNames(Module &M, bool PreserveDbgInfo) { /// Find values that are marked as llvm.used. static void findUsedValues(GlobalVariable *LLVMUsed, - SmallPtrSet<const GlobalValue*, 8> &UsedValues) { + SmallPtrSetImpl<const GlobalValue*> &UsedValues) { if (!LLVMUsed) return; UsedValues.insert(LLVMUsed); @@ -302,8 +301,8 @@ bool StripDeadDebugInfo::runOnModule(Module &M) { // For each compile unit, find the live set of global variables/functions and // replace the current list of potentially dead global variables/functions // with the live list. - SmallVector<Value *, 64> LiveGlobalVariables; - SmallVector<Value *, 64> LiveSubprograms; + SmallVector<Metadata *, 64> LiveGlobalVariables; + SmallVector<Metadata *, 64> LiveSubprograms; DenseSet<const MDNode *> VisitedSet; for (DICompileUnit DIC : F.compile_units()) { @@ -350,28 +349,12 @@ bool StripDeadDebugInfo::runOnModule(Module &M) { // subprogram list/global variable list with our new live subprogram/global // variable list. if (SubprogramChange) { - // Make sure that 9 is still the index of the subprograms. This is to make - // sure that an assert is hit if the location of the subprogram array - // changes. This is just to make sure that this is updated if such an - // event occurs. - assert(DIC->getNumOperands() >= 10 && - SPs == DIC->getOperand(9) && - "DICompileUnits is expected to store Subprograms in operand " - "9."); - DIC->replaceOperandWith(9, MDNode::get(C, LiveSubprograms)); + DIC.replaceSubprograms(DIArray(MDNode::get(C, LiveSubprograms))); Changed = true; } if (GlobalVariableChange) { - // Make sure that 10 is still the index of global variables. This is to - // make sure that an assert is hit if the location of the subprogram array - // changes. This is just to make sure that this index is updated if such - // an event occurs. 
- assert(DIC->getNumOperands() >= 11 && - GVs == DIC->getOperand(10) && - "DICompileUnits is expected to store Global Variables in operand " - "10."); - DIC->replaceOperandWith(10, MDNode::get(C, LiveGlobalVariables)); + DIC.replaceGlobalVariables(DIArray(MDNode::get(C, LiveGlobalVariables))); Changed = true; } diff --git a/lib/Transforms/InstCombine/InstCombine.h b/lib/Transforms/InstCombine/InstCombine.h index ab4dc1ce23e6..3c3c13551937 100644 --- a/lib/Transforms/InstCombine/InstCombine.h +++ b/lib/Transforms/InstCombine/InstCombine.h @@ -7,16 +7,19 @@ // //===----------------------------------------------------------------------===// -#ifndef INSTCOMBINE_INSTCOMBINE_H -#define INSTCOMBINE_INSTCOMBINE_H +#ifndef LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINE_H +#define LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINE_H #include "InstCombineWorklist.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/Pass.h" #include "llvm/Transforms/Utils/SimplifyLibCalls.h" @@ -25,6 +28,7 @@ namespace llvm { class CallSite; class DataLayout; +class DominatorTree; class TargetLibraryInfo; class DbgDeclareInst; class MemIntrinsic; @@ -71,14 +75,20 @@ static inline Constant *SubOne(Constant *C) { class LLVM_LIBRARY_VISIBILITY InstCombineIRInserter : public IRBuilderDefaultInserter<true> { InstCombineWorklist &Worklist; + AssumptionCache *AC; public: - InstCombineIRInserter(InstCombineWorklist &WL) : Worklist(WL) {} + InstCombineIRInserter(InstCombineWorklist &WL, AssumptionCache *AC) + : Worklist(WL), AC(AC) {} void InsertHelper(Instruction *I, const Twine &Name, BasicBlock *BB, BasicBlock::iterator InsertPt) const { IRBuilderDefaultInserter<true>::InsertHelper(I, Name, BB, InsertPt); Worklist.Add(I); + + using namespace llvm::PatternMatch; + if (match(I, m_Intrinsic<Intrinsic::assume>())) + AC->registerAssumption(cast<CallInst>(I)); } }; @@ -86,8 +96,10 @@ public: class LLVM_LIBRARY_VISIBILITY InstCombiner : public FunctionPass, public InstVisitor<InstCombiner, Instruction *> { + AssumptionCache *AC; const DataLayout *DL; TargetLibraryInfo *TLI; + DominatorTree *DT; bool MadeIRChange; LibCallSimplifier *Simplifier; bool MinimizeSize; @@ -102,7 +114,8 @@ public: BuilderTy *Builder; static char ID; // Pass identification, replacement for typeid - InstCombiner() : FunctionPass(ID), DL(nullptr), Builder(nullptr) { + InstCombiner() + : FunctionPass(ID), DL(nullptr), DT(nullptr), Builder(nullptr) { MinimizeSize = false; initializeInstCombinerPass(*PassRegistry::getPassRegistry()); } @@ -114,7 +127,11 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override; + AssumptionCache *getAssumptionCache() const { return AC; } + const DataLayout *getDataLayout() const { return DL; } + + DominatorTree *getDominatorTree() const { return DT; } TargetLibraryInfo *getTargetLibraryInfo() const { return TLI; } @@ -145,13 +162,16 @@ public: Instruction *visitUDiv(BinaryOperator &I); Instruction *visitSDiv(BinaryOperator &I); Instruction *visitFDiv(BinaryOperator &I); + Value *simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, bool Inverted); Value *FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS); Value *FoldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS); Instruction *visitAnd(BinaryOperator &I); - Value *FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS); + 
Value *FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction *CxtI); Value *FoldOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS); Instruction *FoldOrWithConstants(BinaryOperator &I, Value *Op, Value *A, Value *B, Value *C); + Instruction *FoldXorWithConstants(BinaryOperator &I, Value *Op, Value *A, + Value *B, Value *C); Instruction *visitOr(BinaryOperator &I); Instruction *visitXor(BinaryOperator &I); Instruction *visitShl(BinaryOperator &I); @@ -172,6 +192,10 @@ public: ConstantInt *DivRHS); Instruction *FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *DivI, ConstantInt *DivRHS); + Instruction *FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A, + ConstantInt *CI1, ConstantInt *CI2); + Instruction *FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A, + ConstantInt *CI1, ConstantInt *CI2); Instruction *FoldICmpAddOpCst(Instruction &ICI, Value *X, ConstantInt *CI, ICmpInst::Predicate Pred); Instruction *FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, @@ -213,6 +237,7 @@ public: Instruction *visitStoreInst(StoreInst &SI); Instruction *visitBranchInst(BranchInst &BI); Instruction *visitSwitchInst(SwitchInst &SI); + Instruction *visitReturnInst(ReturnInst &RI); Instruction *visitInsertValueInst(InsertValueInst &IV); Instruction *visitInsertElementInst(InsertElementInst &IE); Instruction *visitExtractElementInst(ExtractElementInst &EI); @@ -223,6 +248,16 @@ public: // visitInstruction - Specify what to return for unhandled instructions... Instruction *visitInstruction(Instruction &I) { return nullptr; } + // True when DB dominates all uses of DI except UI. + // UI must be in the same block as DI. + // The routine checks that the DI parent and DB are different. + bool dominatesAllUses(const Instruction *DI, const Instruction *UI, + const BasicBlock *DB) const; + + // Replace select with select operand SIOpd in SI-ICmp sequence when possible + bool replacedSelectWithOperand(SelectInst *SI, const ICmpInst *Icmp, + const unsigned SIOpd); + private: bool ShouldChangeType(Type *From, Type *To) const; Value *dyn_castNegVal(Value *V) const; @@ -246,8 +281,10 @@ private: Instruction *transformZExtICmp(ICmpInst *ICI, Instruction &CI, bool DoXform = true); Instruction *transformSExtICmp(ICmpInst *ICI, Instruction &CI); - bool WillNotOverflowSignedAdd(Value *LHS, Value *RHS); - bool WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS); + bool WillNotOverflowSignedAdd(Value *LHS, Value *RHS, Instruction *CxtI); + bool WillNotOverflowSignedSub(Value *LHS, Value *RHS, Instruction *CxtI); + bool WillNotOverflowUnsignedSub(Value *LHS, Value *RHS, Instruction *CxtI); + bool WillNotOverflowSignedMul(Value *LHS, Value *RHS, Instruction *CxtI); Value *EmitGEPOffset(User *GEP); Instruction *scalarizePHI(ExtractElementInst &EI, PHINode *PN); Value *EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask); @@ -294,6 +331,20 @@ public: return &I; } + /// Creates a result tuple for an overflow intrinsic \p II with a given + /// \p Result and a constant \p Overflow value. If \p ReUseName is true the + /// \p Result's name is taken from \p II. + Instruction *CreateOverflowTuple(IntrinsicInst *II, Value *Result, + bool Overflow, bool ReUseName = true) { + if (ReUseName) + Result->takeName(II); + Constant *V[] = { UndefValue::get(Result->getType()), + Overflow ?
Builder->getTrue() : Builder->getFalse() }; + StructType *ST = cast<StructType>(II->getType()); + Constant *Struct = ConstantStruct::get(ST, V); + return InsertValueInst::Create(Struct, Result, 0); + } + // EraseInstFromFunction - When dealing with an instruction that has side // effects or produces a void value, we can't rely on DCE to delete the // instruction. Instead, visit methods should return the value returned by @@ -316,16 +367,32 @@ public: } void computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, - unsigned Depth = 0) const { - return llvm::computeKnownBits(V, KnownZero, KnownOne, DL, Depth); + unsigned Depth = 0, Instruction *CxtI = nullptr) const { + return llvm::computeKnownBits(V, KnownZero, KnownOne, DL, Depth, AC, CxtI, + DT); } bool MaskedValueIsZero(Value *V, const APInt &Mask, - unsigned Depth = 0) const { - return llvm::MaskedValueIsZero(V, Mask, DL, Depth); + unsigned Depth = 0, + Instruction *CxtI = nullptr) const { + return llvm::MaskedValueIsZero(V, Mask, DL, Depth, AC, CxtI, DT); + } + unsigned ComputeNumSignBits(Value *Op, unsigned Depth = 0, + Instruction *CxtI = nullptr) const { + return llvm::ComputeNumSignBits(Op, DL, Depth, AC, CxtI, DT); + } + void ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, + unsigned Depth = 0, Instruction *CxtI = nullptr) const { + return llvm::ComputeSignBit(V, KnownZero, KnownOne, DL, Depth, AC, CxtI, + DT); + } + OverflowResult computeOverflowForUnsignedMul(Value *LHS, Value *RHS, + const Instruction *CxtI) { + return llvm::computeOverflowForUnsignedMul(LHS, RHS, DL, AC, CxtI, DT); } - unsigned ComputeNumSignBits(Value *Op, unsigned Depth = 0) const { - return llvm::ComputeNumSignBits(Op, DL, Depth); + OverflowResult computeOverflowForUnsignedAdd(Value *LHS, Value *RHS, + const Instruction *CxtI) { + return llvm::computeOverflowForUnsignedAdd(LHS, RHS, DL, AC, CxtI, DT); } private: @@ -343,7 +410,8 @@ private: /// SimplifyDemandedUseBits - Attempts to replace V with a simpler value /// based on the demanded bits. Value *SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt &KnownZero, - APInt &KnownOne, unsigned Depth); + APInt &KnownOne, unsigned Depth, + Instruction *CxtI = nullptr); bool SimplifyDemandedBits(Use &U, APInt DemandedMask, APInt &KnownZero, APInt &KnownOne, unsigned Depth = 0); /// Helper routine of SimplifyDemandedUseBits. 
It tries to simplify demanded @@ -361,6 +429,7 @@ private: APInt &UndefElts, unsigned Depth = 0); Value *SimplifyVectorOp(BinaryOperator &Inst); + Value *SimplifyBSwap(BinaryOperator &Inst); // FoldOpIntoPhi - Given a binary operator, cast instruction, or select // which has a PHI node as operand #0, see if we can fold the instruction diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp index e80d6a9ee39b..6d20384e5d17 100644 --- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -751,8 +751,7 @@ Value *FAddCombine::createNaryFAdd return LastVal; } -Value *FAddCombine::createFSub - (Value *Opnd0, Value *Opnd1) { +Value *FAddCombine::createFSub(Value *Opnd0, Value *Opnd1) { Value *V = Builder->CreateFSub(Opnd0, Opnd1); if (Instruction *I = dyn_cast<Instruction>(V)) createInstPostProc(I); @@ -760,15 +759,14 @@ Value *FAddCombine::createFSub } Value *FAddCombine::createFNeg(Value *V) { - Value *Zero = cast<Value>(ConstantFP::get(V->getType(), 0.0)); + Value *Zero = cast<Value>(ConstantFP::getZeroValueForNegation(V->getType())); Value *NewV = createFSub(Zero, V); if (Instruction *I = dyn_cast<Instruction>(NewV)) createInstPostProc(I, true); // fneg's don't receive instruction numbers. return NewV; } -Value *FAddCombine::createFAdd - (Value *Opnd0, Value *Opnd1) { +Value *FAddCombine::createFAdd(Value *Opnd0, Value *Opnd1) { Value *V = Builder->CreateFAdd(Opnd0, Opnd1); if (Instruction *I = dyn_cast<Instruction>(V)) createInstPostProc(I); @@ -789,8 +787,7 @@ Value *FAddCombine::createFDiv(Value *Opnd0, Value *Opnd1) { return V; } -void FAddCombine::createInstPostProc(Instruction *NewInstr, - bool NoNumber) { +void FAddCombine::createInstPostProc(Instruction *NewInstr, bool NoNumber) { NewInstr->setDebugLoc(Instr->getDebugLoc()); // Keep track of the number of instruction created. @@ -840,8 +837,7 @@ unsigned FAddCombine::calcInstrNumber(const AddendVect &Opnds) { // <C, V> "fmul V, C" false // // NOTE: Keep this function in sync with FAddCombine::calcInstrNumber. -Value *FAddCombine::createAddendVal - (const FAddend &Opnd, bool &NeedNeg) { +Value *FAddCombine::createAddendVal(const FAddend &Opnd, bool &NeedNeg) { const FAddendCoef &Coeff = Opnd.getCoef(); if (Opnd.isConstant()) { @@ -894,8 +890,8 @@ static bool checkRippleForAdd(const APInt &Op0KnownZero, /// (sext (add LHS, RHS)) === (add (sext LHS), (sext RHS)) /// This basically requires proving that the add in the original type would not /// overflow to change the sign bit or have a carry out. -/// TODO: Handle this for Vectors. -bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS) { +bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS, + Instruction *CxtI) { // There are different heuristics we can use for this. Here are some simple // ones. @@ -913,44 +909,76 @@ bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS) { // // Since the carry into the most significant position is always equal to // the carry out of the addition, there is no signed overflow. 
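The sign-bit argument above is easy to spot-check exhaustively at 8 bits with a standalone program (illustrative only; signBits mirrors the notion used by ComputeNumSignBits):

  #include <cassert>
  #include <cstdint>

  // Number of leading replicated sign bits in an 8-bit pattern.
  static int signBits(uint8_t v) {
    int sign = (v >> 7) & 1, n = 1;
    for (int b = 6; b >= 0 && ((v >> b) & 1) == sign; --b)
      ++n;
    return n;
  }

  int main() {
    for (int a = -128; a <= 127; ++a)
      for (int b = -128; b <= 127; ++b)
        if (signBits((uint8_t)a) > 1 && signBits((uint8_t)b) > 1) {
          int sum = a + b;                    // exact result, computed as int
          assert(-128 <= sum && sum <= 127);  // fits in i8: no signed overflow
        }
    return 0;
  }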
- if (ComputeNumSignBits(LHS) > 1 && ComputeNumSignBits(RHS) > 1) + if (ComputeNumSignBits(LHS, 0, CxtI) > 1 && + ComputeNumSignBits(RHS, 0, CxtI) > 1) return true; - if (IntegerType *IT = dyn_cast<IntegerType>(LHS->getType())) { - int BitWidth = IT->getBitWidth(); - APInt LHSKnownZero(BitWidth, 0); - APInt LHSKnownOne(BitWidth, 0); - computeKnownBits(LHS, LHSKnownZero, LHSKnownOne); - - APInt RHSKnownZero(BitWidth, 0); - APInt RHSKnownOne(BitWidth, 0); - computeKnownBits(RHS, RHSKnownZero, RHSKnownOne); - - // Addition of two 2's compliment numbers having opposite signs will never - // overflow. - if ((LHSKnownOne[BitWidth - 1] && RHSKnownZero[BitWidth - 1]) || - (LHSKnownZero[BitWidth - 1] && RHSKnownOne[BitWidth - 1])) - return true; - - // Check if carry bit of addition will not cause overflow. - if (checkRippleForAdd(LHSKnownZero, RHSKnownZero)) - return true; - if (checkRippleForAdd(RHSKnownZero, LHSKnownZero)) - return true; - } + unsigned BitWidth = LHS->getType()->getScalarSizeInBits(); + APInt LHSKnownZero(BitWidth, 0); + APInt LHSKnownOne(BitWidth, 0); + computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, CxtI); + + APInt RHSKnownZero(BitWidth, 0); + APInt RHSKnownOne(BitWidth, 0); + computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, CxtI); + + // Addition of two 2's complement numbers having opposite signs will never + // overflow. + if ((LHSKnownOne[BitWidth - 1] && RHSKnownZero[BitWidth - 1]) || + (LHSKnownZero[BitWidth - 1] && RHSKnownOne[BitWidth - 1])) + return true; + + // Check if carry bit of addition will not cause overflow. + if (checkRippleForAdd(LHSKnownZero, RHSKnownZero)) + return true; + if (checkRippleForAdd(RHSKnownZero, LHSKnownZero)) + return true; + + return false; +} + +/// \brief Return true if we can prove that: +/// (sub LHS, RHS) === (sub nsw LHS, RHS) +/// This basically requires proving that the subtraction in the original type +/// would not overflow to change the sign bit or have a carry out. +/// TODO: Handle this for Vectors. +bool InstCombiner::WillNotOverflowSignedSub(Value *LHS, Value *RHS, + Instruction *CxtI) { + // If LHS and RHS each have at least two sign bits, the subtraction + // cannot overflow. + if (ComputeNumSignBits(LHS, 0, CxtI) > 1 && + ComputeNumSignBits(RHS, 0, CxtI) > 1) + return true; + + unsigned BitWidth = LHS->getType()->getScalarSizeInBits(); + APInt LHSKnownZero(BitWidth, 0); + APInt LHSKnownOne(BitWidth, 0); + computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, CxtI); + + APInt RHSKnownZero(BitWidth, 0); + APInt RHSKnownOne(BitWidth, 0); + computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, CxtI); + + // Subtraction of two 2's complement numbers having identical signs will + // never overflow. + if ((LHSKnownOne[BitWidth - 1] && RHSKnownOne[BitWidth - 1]) || + (LHSKnownZero[BitWidth - 1] && RHSKnownZero[BitWidth - 1])) + return true; + + // TODO: implement logic similar to checkRippleForAdd return false; } -/// WillNotOverflowUnsignedAdd - Return true if we can prove that: -/// (zext (add LHS, RHS)) === (add (zext LHS), (zext RHS)) -bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS) { - // There are different heuristics we can use for this. Here is a simple one. - // If the sign bit of LHS and that of RHS are both zero, no unsigned wrap. +/// \brief Return true if we can prove that: +/// (sub LHS, RHS) === (sub nuw LHS, RHS) +bool InstCombiner::WillNotOverflowUnsignedSub(Value *LHS, Value *RHS, + Instruction *CxtI) { + // If the LHS is negative and the RHS is non-negative, no unsigned wrap.
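The same style of exhaustive 8-bit check confirms this claim: a sign-bit-set LHS is at least 128 as an unsigned value while a sign-bit-clear RHS is at most 127, so the difference never wraps (standalone illustration, not part of the patch):

  #include <cassert>

  int main() {
    for (int a = 128; a <= 255; ++a)       // LHS: sign bit known one
      for (int b = 0; b <= 127; ++b) {     // RHS: sign bit known zero
        int diff = a - b;                  // exact difference
        assert(0 <= diff && diff <= 255);  // representable in u8: no wrap
      }
    return 0;
  }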
bool LHSKnownNonNegative, LHSKnownNegative; bool RHSKnownNonNegative, RHSKnownNegative; - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, 0); - ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, 0); - if (LHSKnownNonNegative && RHSKnownNonNegative) + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, /*Depth=*/0, CxtI); + ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, /*Depth=*/0, CxtI); + if (LHSKnownNegative && RHSKnownNonNegative) return true; return false; @@ -1025,7 +1053,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { return ReplaceInstUsesWith(I, V); if (Value *V = SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(), - I.hasNoUnsignedWrap(), DL)) + I.hasNoUnsignedWrap(), DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // (A*B)+(A*C) -> A*(B+C) etc @@ -1064,7 +1092,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (ExtendAmt) { APInt Mask = APInt::getHighBitsSet(TySizeBits, ExtendAmt); - if (!MaskedValueIsZero(XorLHS, Mask)) + if (!MaskedValueIsZero(XorLHS, Mask, 0, &I)) ExtendAmt = 0; } @@ -1080,7 +1108,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { IntegerType *IT = cast<IntegerType>(I.getType()); APInt LHSKnownOne(IT->getBitWidth(), 0); APInt LHSKnownZero(IT->getBitWidth(), 0); - computeKnownBits(XorLHS, LHSKnownZero, LHSKnownOne); + computeKnownBits(XorLHS, LHSKnownZero, LHSKnownOne, 0, &I); if ((XorRHS->getValue() | LHSKnownZero).isAllOnesValue()) return BinaryOperator::CreateSub(ConstantExpr::getAdd(XorRHS, CI), XorLHS); @@ -1133,11 +1161,11 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (IntegerType *IT = dyn_cast<IntegerType>(I.getType())) { APInt LHSKnownOne(IT->getBitWidth(), 0); APInt LHSKnownZero(IT->getBitWidth(), 0); - computeKnownBits(LHS, LHSKnownZero, LHSKnownOne); + computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, &I); if (LHSKnownZero != 0) { APInt RHSKnownOne(IT->getBitWidth(), 0); APInt RHSKnownZero(IT->getBitWidth(), 0); - computeKnownBits(RHS, RHSKnownZero, RHSKnownOne); + computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &I); // No bits in common -> bitwise or. if ((LHSKnownZero|RHSKnownZero).isAllOnesValue()) @@ -1215,7 +1243,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType()); if (LHSConv->hasOneUse() && ConstantExpr::getSExt(CI, I.getType()) == RHSC && - WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI)) { + WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, &I)) { // Insert the new, smaller add. Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0), CI, "addconv"); @@ -1231,7 +1259,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (LHSConv->getOperand(0)->getType()==RHSConv->getOperand(0)->getType()&& (LHSConv->hasOneUse() || RHSConv->hasOneUse()) && WillNotOverflowSignedAdd(LHSConv->getOperand(0), - RHSConv->getOperand(0))) { + RHSConv->getOperand(0), &I)) { // Insert the new integer add. 
Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0), RHSConv->getOperand(0), "addconv"); @@ -1240,7 +1268,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } } - // Check for (x & y) + (x ^ y) + // (add (xor A, B) (and A, B)) --> (or A, B) { Value *A = nullptr, *B = nullptr; if (match(RHS, m_Xor(m_Value(A), m_Value(B))) && @@ -1254,14 +1282,38 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { return BinaryOperator::CreateOr(A, B); } + // (add (or A, B) (and A, B)) --> (add A, B) + { + Value *A = nullptr, *B = nullptr; + if (match(RHS, m_Or(m_Value(A), m_Value(B))) && + (match(LHS, m_And(m_Specific(A), m_Specific(B))) || + match(LHS, m_And(m_Specific(B), m_Specific(A))))) { + auto *New = BinaryOperator::CreateAdd(A, B); + New->setHasNoSignedWrap(I.hasNoSignedWrap()); + New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + return New; + } + + if (match(LHS, m_Or(m_Value(A), m_Value(B))) && + (match(RHS, m_And(m_Specific(A), m_Specific(B))) || + match(RHS, m_And(m_Specific(B), m_Specific(A))))) { + auto *New = BinaryOperator::CreateAdd(A, B); + New->setHasNoSignedWrap(I.hasNoSignedWrap()); + New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + return New; + } + } + // TODO(jingyue): Consider WillNotOverflowSignedAdd and // WillNotOverflowUnsignedAdd to reduce the number of invocations of // computeKnownBits. - if (!I.hasNoSignedWrap() && WillNotOverflowSignedAdd(LHS, RHS)) { + if (!I.hasNoSignedWrap() && WillNotOverflowSignedAdd(LHS, RHS, &I)) { Changed = true; I.setHasNoSignedWrap(true); } - if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedAdd(LHS, RHS)) { + if (!I.hasNoUnsignedWrap() && + computeOverflowForUnsignedAdd(LHS, RHS, &I) == + OverflowResult::NeverOverflows) { Changed = true; I.setHasNoUnsignedWrap(true); } @@ -1276,7 +1328,8 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), DL)) + if (Value *V = + SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); if (isa<Constant>(RHS)) { @@ -1318,7 +1371,7 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { ConstantExpr::getFPToSI(CFP, LHSConv->getOperand(0)->getType()); if (LHSConv->hasOneUse() && ConstantExpr::getSIToFP(CI, I.getType()) == CFP && - WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI)) { + WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, &I)) { // Insert the new integer add. Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0), CI, "addconv"); @@ -1334,7 +1387,7 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { if (LHSConv->getOperand(0)->getType()==RHSConv->getOperand(0)->getType()&& (LHSConv->hasOneUse() || RHSConv->hasOneUse()) && WillNotOverflowSignedAdd(LHSConv->getOperand(0), - RHSConv->getOperand(0))) { + RHSConv->getOperand(0), &I)) { // Insert the new integer add. 
Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0), RHSConv->getOperand(0),"addconv"); @@ -1356,11 +1409,11 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { Z2 = dyn_cast<Constant>(B2); B = B1; } else if (match(B1, m_AnyZero()) && match(A2, m_AnyZero())) { Z1 = dyn_cast<Constant>(B1); B = B2; - Z2 = dyn_cast<Constant>(A2); A = A1; + Z2 = dyn_cast<Constant>(A2); A = A1; } - - if (Z1 && Z2 && - (I.hasNoSignedZeros() || + + if (Z1 && Z2 && + (I.hasNoSignedZeros() || (Z1->isNegativeZeroValue() && Z2->isNegativeZeroValue()))) { return SelectInst::Create(C, A, B); } @@ -1447,7 +1500,6 @@ Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS, return Builder->CreateIntCast(Result, Ty, true); } - Instruction *InstCombiner::visitSub(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -1455,18 +1507,27 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return ReplaceInstUsesWith(I, V); if (Value *V = SimplifySubInst(Op0, Op1, I.hasNoSignedWrap(), - I.hasNoUnsignedWrap(), DL)) + I.hasNoUnsignedWrap(), DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // (A*B)-(A*C) -> A*(B-C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) return ReplaceInstUsesWith(I, V); - // If this is a 'B = x-(-A)', change to B = x+A. This preserves NSW/NUW. + // If this is a 'B = x-(-A)', change to B = x+A. if (Value *V = dyn_castNegVal(Op1)) { BinaryOperator *Res = BinaryOperator::CreateAdd(Op0, V); - Res->setHasNoSignedWrap(I.hasNoSignedWrap()); - Res->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + + if (const auto *BO = dyn_cast<BinaryOperator>(Op1)) { + assert(BO->getOpcode() == Instruction::Sub && + "Expected a subtraction operator!"); + if (BO->hasNoSignedWrap() && I.hasNoSignedWrap()) + Res->setHasNoSignedWrap(true); + } else { + if (cast<Constant>(Op1)->isNotMinSignedValue() && I.hasNoSignedWrap()) + Res->setHasNoSignedWrap(true); + } + return Res; } @@ -1511,21 +1572,23 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { // -(X >>u 31) -> (X >>s 31) // -(X >>s 31) -> (X >>u 31) if (C->isZero()) { - Value *X; ConstantInt *CI; + Value *X; + ConstantInt *CI; if (match(Op1, m_LShr(m_Value(X), m_ConstantInt(CI))) && // Verify we are shifting out everything but the sign bit. - CI->getValue() == I.getType()->getPrimitiveSizeInBits()-1) + CI->getValue() == I.getType()->getPrimitiveSizeInBits() - 1) return BinaryOperator::CreateAShr(X, CI); if (match(Op1, m_AShr(m_Value(X), m_ConstantInt(CI))) && // Verify we are shifting out everything but the sign bit. 
- CI->getValue() == I.getType()->getPrimitiveSizeInBits()-1) + CI->getValue() == I.getType()->getPrimitiveSizeInBits() - 1) return BinaryOperator::CreateLShr(X, CI); } } - { Value *Y; + { + Value *Y; // X-(X+Y) == -Y X-(Y+X) == -Y if (match(Op1, m_Add(m_Specific(Op0), m_Value(Y))) || match(Op1, m_Add(m_Value(Y), m_Specific(Op0)))) @@ -1536,6 +1599,24 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return BinaryOperator::CreateNeg(Y); } + // (sub (or A, B) (xor A, B)) --> (and A, B) + { + Value *A = nullptr, *B = nullptr; + if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && + (match(Op0, m_Or(m_Specific(A), m_Specific(B))) || + match(Op0, m_Or(m_Specific(B), m_Specific(A))))) + return BinaryOperator::CreateAnd(A, B); + } + + if (Op0->hasOneUse()) { + Value *Y = nullptr; + // ((X | Y) - X) --> (~X & Y) + if (match(Op0, m_Or(m_Value(Y), m_Specific(Op1))) || + match(Op0, m_Or(m_Specific(Op1), m_Value(Y)))) + return BinaryOperator::CreateAnd( + Y, Builder->CreateNot(Op1, Op1->getName() + ".not")); + } + if (Op1->hasOneUse()) { Value *X = nullptr, *Y = nullptr, *Z = nullptr; Constant *C = nullptr; @@ -1555,7 +1636,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { // 0 - (X sdiv C) -> (X sdiv -C) provided the negation doesn't overflow. if (match(Op1, m_SDiv(m_Value(X), m_Constant(C))) && match(Op0, m_Zero()) && - !C->isMinSignedValue()) + C->isNotMinSignedValue() && !C->isOneValue()) return BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(C)); // 0 - (X << Y) -> (-X << Y) when X is freely negatable. @@ -1595,7 +1676,17 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return ReplaceInstUsesWith(I, Res); } - return nullptr; + bool Changed = false; + if (!I.hasNoSignedWrap() && WillNotOverflowSignedSub(Op0, Op1, &I)) { + Changed = true; + I.setHasNoSignedWrap(true); + } + if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedSub(Op0, Op1, &I)) { + Changed = true; + I.setHasNoUnsignedWrap(true); + } + + return Changed ? &I : nullptr; } Instruction *InstCombiner::visitFSub(BinaryOperator &I) { @@ -1604,9 +1695,18 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), DL)) + if (Value *V = + SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); + // fsub nsz 0, X ==> fsub nsz -0.0, X + if (I.getFastMathFlags().noSignedZeros() && match(Op0, m_Zero())) { + // Subtraction from -0.0 is the canonical form of fneg. + Instruction *NewI = BinaryOperator::CreateFNeg(Op1); + NewI->copyFastMathFlags(&I); + return NewI; + } + if (isa<Constant>(Op0)) if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) if (Instruction *NV = FoldOpIntoSelect(I, SI)) diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index b23a606e0889..74b6970b6a53 100644 --- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -117,6 +117,61 @@ static Value *getFCmpValue(bool isordered, unsigned code, return Builder->CreateFCmp(Pred, LHS, RHS); } +/// \brief Transform BITWISE_OP(BSWAP(A),BSWAP(B)) to BSWAP(BITWISE_OP(A, B)) +/// \param I Binary operator to transform. +/// \return Pointer to node that must replace the original binary operator, or +/// null pointer if no transformation was made. 
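The implementation below rests on the fact that bswap permutes whole bytes while the bitwise operators act bit-per-bit, so the two commute. A standalone check of that identity (bswap32 is a local helper here, not the LLVM intrinsic):

  #include <cassert>
  #include <cstdint>

  static uint32_t bswap32(uint32_t v) {
    return (v >> 24) | ((v >> 8) & 0x0000FF00u) |
           ((v << 8) & 0x00FF0000u) | (v << 24);
  }

  int main() {
    uint32_t x = 1, y = 0x9E3779B9u;
    for (int i = 0; i < 100000; ++i) {
      x = x * 1664525u + 1013904223u;   // simple LCG-generated inputs
      y = y * 22695477u + 1u;
      assert((bswap32(x) & bswap32(y)) == bswap32(x & y));
      assert((bswap32(x) | bswap32(y)) == bswap32(x | y));
      assert((bswap32(x) ^ bswap32(y)) == bswap32(x ^ y));
    }
    return 0;
  }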
+Value *InstCombiner::SimplifyBSwap(BinaryOperator &I) { + IntegerType *ITy = dyn_cast<IntegerType>(I.getType()); + + // Can't do vectors. + if (I.getType()->isVectorTy()) return nullptr; + + // Can only do bitwise ops. + unsigned Op = I.getOpcode(); + if (Op != Instruction::And && Op != Instruction::Or && + Op != Instruction::Xor) + return nullptr; + + Value *OldLHS = I.getOperand(0); + Value *OldRHS = I.getOperand(1); + ConstantInt *ConstLHS = dyn_cast<ConstantInt>(OldLHS); + ConstantInt *ConstRHS = dyn_cast<ConstantInt>(OldRHS); + IntrinsicInst *IntrLHS = dyn_cast<IntrinsicInst>(OldLHS); + IntrinsicInst *IntrRHS = dyn_cast<IntrinsicInst>(OldRHS); + bool IsBswapLHS = (IntrLHS && IntrLHS->getIntrinsicID() == Intrinsic::bswap); + bool IsBswapRHS = (IntrRHS && IntrRHS->getIntrinsicID() == Intrinsic::bswap); + + if (!IsBswapLHS && !IsBswapRHS) + return nullptr; + + if (!IsBswapLHS && !ConstLHS) + return nullptr; + + if (!IsBswapRHS && !ConstRHS) + return nullptr; + + /// OP( BSWAP(x), BSWAP(y) ) -> BSWAP( OP(x, y) ) + /// OP( BSWAP(x), CONSTANT ) -> BSWAP( OP(x, BSWAP(CONSTANT) ) ) + Value *NewLHS = IsBswapLHS ? IntrLHS->getOperand(0) : + Builder->getInt(ConstLHS->getValue().byteSwap()); + + Value *NewRHS = IsBswapRHS ? IntrRHS->getOperand(0) : + Builder->getInt(ConstRHS->getValue().byteSwap()); + + Value *BinOp = nullptr; + if (Op == Instruction::And) + BinOp = Builder->CreateAnd(NewLHS, NewRHS); + else if (Op == Instruction::Or) + BinOp = Builder->CreateOr(NewLHS, NewRHS); + else //if (Op == Instruction::Xor) + BinOp = Builder->CreateXor(NewLHS, NewRHS); + + Module *M = I.getParent()->getParent()->getParent(); + Function *F = Intrinsic::getDeclaration(M, Intrinsic::bswap, ITy); + return Builder->CreateCall(F, BinOp); +} + // OptAndOp - This handles expressions of the form ((val OP C1) & C2). Where // the Op parameter is 'OP', OpRHS is 'C1', and AndRHS is 'C2'. Op is // guaranteed to be a binary operator. @@ -355,7 +410,7 @@ Value *InstCombiner::FoldLogicalPlusAnd(Value *LHS, Value *RHS, if (isRunOfOnes(Mask, MB, ME)) { // begin/end bit of run, inclusive uint32_t BitWidth = cast<IntegerType>(RHS->getType())->getBitWidth(); APInt Mask(APInt::getLowBitsSet(BitWidth, MB-1)); - if (MaskedValueIsZero(RHS, Mask)) + if (MaskedValueIsZero(RHS, Mask, 0, &I)) break; } } @@ -614,7 +669,7 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, } else if (R1->getType()->isIntegerTy()) { if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) { // As before, model no mask as a trivial mask if it'll let us do an - // optimisation. + // optimization. 
R11 = R1; R12 = Constant::getAllOnesValue(R1->getType()); } @@ -665,8 +720,8 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, /// foldLogOpOfMaskedICmps: /// try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) /// into a single (icmp(A & X) ==/!= Y) -static Value* foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, - llvm::InstCombiner::BuilderTy* Builder) { +static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, + llvm::InstCombiner::BuilderTy *Builder) { Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr; ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); unsigned mask = foldLogOpOfMaskedICmpsHelper(A, B, C, D, E, LHS, RHS, @@ -697,26 +752,26 @@ static Value* foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, if (mask & FoldMskICmp_Mask_AllZeroes) { // (icmp eq (A & B), 0) & (icmp eq (A & D), 0) // -> (icmp eq (A & (B|D)), 0) - Value* newOr = Builder->CreateOr(B, D); - Value* newAnd = Builder->CreateAnd(A, newOr); + Value *newOr = Builder->CreateOr(B, D); + Value *newAnd = Builder->CreateAnd(A, newOr); // we can't use C as zero, because we might actually handle // (icmp ne (A & B), B) & (icmp ne (A & D), D) // with B and D, having a single bit set - Value* zero = Constant::getNullValue(A->getType()); + Value *zero = Constant::getNullValue(A->getType()); return Builder->CreateICmp(NEWCC, newAnd, zero); } if (mask & FoldMskICmp_BMask_AllOnes) { // (icmp eq (A & B), B) & (icmp eq (A & D), D) // -> (icmp eq (A & (B|D)), (B|D)) - Value* newOr = Builder->CreateOr(B, D); - Value* newAnd = Builder->CreateAnd(A, newOr); + Value *newOr = Builder->CreateOr(B, D); + Value *newAnd = Builder->CreateAnd(A, newOr); return Builder->CreateICmp(NEWCC, newAnd, newOr); } if (mask & FoldMskICmp_AMask_AllOnes) { // (icmp eq (A & B), A) & (icmp eq (A & D), A) // -> (icmp eq (A & (B&D)), A) - Value* newAnd1 = Builder->CreateAnd(B, D); - Value* newAnd = Builder->CreateAnd(A, newAnd1); + Value *newAnd1 = Builder->CreateAnd(B, D); + Value *newAnd = Builder->CreateAnd(A, newAnd1); return Builder->CreateICmp(NEWCC, newAnd, A); } @@ -766,19 +821,17 @@ static Value* foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, // with B and D, having a single bit set ConstantInt *CCst = dyn_cast<ConstantInt>(C); if (!CCst) return nullptr; - if (LHSCC != NEWCC) - CCst = dyn_cast<ConstantInt>( ConstantExpr::getXor(BCst, CCst) ); ConstantInt *ECst = dyn_cast<ConstantInt>(E); if (!ECst) return nullptr; + if (LHSCC != NEWCC) + CCst = cast<ConstantInt>(ConstantExpr::getXor(BCst, CCst)); if (RHSCC != NEWCC) - ECst = dyn_cast<ConstantInt>( ConstantExpr::getXor(DCst, ECst) ); - ConstantInt* MCst = dyn_cast<ConstantInt>( - ConstantExpr::getAnd(ConstantExpr::getAnd(BCst, DCst), - ConstantExpr::getXor(CCst, ECst)) ); + ECst = cast<ConstantInt>(ConstantExpr::getXor(DCst, ECst)); // if there is a conflict we should actually return false for the // whole construct - if (!MCst->isZero()) - return nullptr; + if (((BCst->getValue() & DCst->getValue()) & + (CCst->getValue() ^ ECst->getValue())) != 0) + return ConstantInt::get(LHS->getType(), !IsAnd); Value *newOr1 = Builder->CreateOr(B, D); Value *newOr2 = ConstantExpr::getOr(CCst, ECst); Value *newAnd = Builder->CreateAnd(A, newOr1); @@ -787,6 +840,62 @@ static Value* foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, return nullptr; } +/// Try to fold a signed range check with lower bound 0 to an unsigned icmp.
+/// Example: (icmp sge x, 0) & (icmp slt x, n) --> icmp ult x, n +/// If \p Inverted is true then the check is for the inverted range, e.g. +/// (icmp slt x, 0) | (icmp sgt x, n) --> icmp ugt x, n +Value *InstCombiner::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, + bool Inverted) { + // Check the lower range comparison, e.g. x >= 0 + // InstCombine already ensured that if there is a constant it's on the RHS. + ConstantInt *RangeStart = dyn_cast<ConstantInt>(Cmp0->getOperand(1)); + if (!RangeStart) + return nullptr; + + ICmpInst::Predicate Pred0 = (Inverted ? Cmp0->getInversePredicate() : + Cmp0->getPredicate()); + + // Accept x > -1 or x >= 0 (after potentially inverting the predicate). + if (!((Pred0 == ICmpInst::ICMP_SGT && RangeStart->isMinusOne()) || + (Pred0 == ICmpInst::ICMP_SGE && RangeStart->isZero()))) + return nullptr; + + ICmpInst::Predicate Pred1 = (Inverted ? Cmp1->getInversePredicate() : + Cmp1->getPredicate()); + + Value *Input = Cmp0->getOperand(0); + Value *RangeEnd; + if (Cmp1->getOperand(0) == Input) { + // For the upper range compare we have: icmp x, n + RangeEnd = Cmp1->getOperand(1); + } else if (Cmp1->getOperand(1) == Input) { + // For the upper range compare we have: icmp n, x + RangeEnd = Cmp1->getOperand(0); + Pred1 = ICmpInst::getSwappedPredicate(Pred1); + } else { + return nullptr; + } + + // Check the upper range comparison, e.g. x < n + ICmpInst::Predicate NewPred; + switch (Pred1) { + case ICmpInst::ICMP_SLT: NewPred = ICmpInst::ICMP_ULT; break; + case ICmpInst::ICMP_SLE: NewPred = ICmpInst::ICMP_ULE; break; + default: return nullptr; + } + + // This simplification is only valid if the upper range is not negative. + bool IsNegative, IsNotNegative; + ComputeSignBit(RangeEnd, IsNotNegative, IsNegative, /*Depth=*/0, Cmp1); + if (!IsNotNegative) + return nullptr; + + if (Inverted) + NewPred = ICmpInst::getInversePredicate(NewPred); + + return Builder->CreateICmp(NewPred, Input, RangeEnd); +} + /// FoldAndOfICmps - Fold (icmp)&(icmp) if possible. Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); @@ -809,6 +918,14 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, true, Builder)) return V; + // E.g. (icmp sge x, 0) & (icmp slt x, n) --> icmp ult x, n + if (Value *V = simplifyRangeCheck(LHS, RHS, /*Inverted=*/false)) + return V; + + // E.g. (icmp slt x, n) & (icmp sge x, 0) --> icmp ult x, n + if (Value *V = simplifyRangeCheck(RHS, LHS, /*Inverted=*/false)) + return V; + // This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2). 
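The new simplifyRangeCheck fold can be verified exhaustively at 8 bits (standalone illustration; the fold requires the upper bound to be known non-negative, which the loop encodes by keeping n in [0, 127]):

  #include <cassert>
  #include <cstdint>

  int main() {
    for (int x = -128; x <= 127; ++x)
      for (int n = 0; n <= 127; ++n) {             // known-non-negative bound
        bool asSigned = (x >= 0) && (x < n);       // (sge x, 0) & (slt x, n)
        bool asUnsigned = (uint8_t)x < (uint8_t)n; // icmp ult x, n
        assert(asSigned == asUnsigned);
      }
    return 0;
  }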
Value *Val = LHS->getOperand(0), *Val2 = RHS->getOperand(0); ConstantInt *LHSCst = dyn_cast<ConstantInt>(LHS->getOperand(1)); @@ -930,6 +1047,8 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { case ICmpInst::ICMP_ULT: if (LHSCst == SubOne(RHSCst)) // (X != 13 & X u< 14) -> X < 13 return Builder->CreateICmpULT(Val, LHSCst); + if (LHSCst->isNullValue()) // (X != 0 & X u< 14) -> X-1 u< 13 + return InsertRangeTest(Val, AddOne(LHSCst), RHSCst, false, true); break; // (X != 13 & X u< 15) -> no change case ICmpInst::ICMP_SLT: if (LHSCst == SubOne(RHSCst)) // (X != 13 & X s< 14) -> X < 13 @@ -1101,7 +1220,6 @@ Value *InstCombiner::FoldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { return nullptr; } - Instruction *InstCombiner::visitAnd(BinaryOperator &I) { bool Changed = SimplifyAssociativeOrCommutative(I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -1109,7 +1227,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyAndInst(Op0, Op1, DL)) + if (Value *V = SimplifyAndInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // (A|B)&(A|C) -> A|(B&C) etc @@ -1121,6 +1239,9 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (SimplifyDemandedInstructionBits(I)) return &I; + if (Value *V = SimplifyBSwap(I)) + return ReplaceInstUsesWith(I, V); + if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(Op1)) { const APInt &AndRHSMask = AndRHS->getValue(); @@ -1136,14 +1257,14 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (!Op0I->hasOneUse()) break; APInt NotAndRHS(~AndRHSMask); - if (MaskedValueIsZero(Op0LHS, NotAndRHS)) { + if (MaskedValueIsZero(Op0LHS, NotAndRHS, 0, &I)) { // Not masking anything out for the LHS, move to RHS. Value *NewRHS = Builder->CreateAnd(Op0RHS, AndRHS, Op0RHS->getName()+".masked"); return BinaryOperator::Create(Op0I->getOpcode(), Op0LHS, NewRHS); } if (!isa<Constant>(Op0RHS) && - MaskedValueIsZero(Op0RHS, NotAndRHS)) { + MaskedValueIsZero(Op0RHS, NotAndRHS, 0, &I)) { // Not masking anything out for the RHS, move to LHS. 
Value *NewLHS = Builder->CreateAnd(Op0LHS, AndRHS, Op0LHS->getName()+".masked"); @@ -1176,7 +1297,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { uint32_t Zeros = AndRHSMask.countLeadingZeros(); APInt Mask = APInt::getLowBitsSet(BitWidth, BitWidth - Zeros); - if (MaskedValueIsZero(Op0LHS, Mask)) { + if (MaskedValueIsZero(Op0LHS, Mask, 0, &I)) { Value *NewNeg = Builder->CreateNeg(Op0RHS); return BinaryOperator::CreateAnd(NewNeg, AndRHS); } @@ -1283,13 +1404,58 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (match(Op1, m_Or(m_Not(m_Specific(Op0)), m_Value(A))) || match(Op1, m_Or(m_Value(A), m_Not(m_Specific(Op0))))) return BinaryOperator::CreateAnd(A, Op0); + + // (A ^ B) & ((B ^ C) ^ A) -> (A ^ B) & ~C + if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) + if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A)))) + if (Op1->hasOneUse() || cast<BinaryOperator>(Op1)->hasOneUse()) + return BinaryOperator::CreateAnd(Op0, Builder->CreateNot(C)); + + // ((A ^ C) ^ B) & (B ^ A) -> (B ^ A) & ~C + if (match(Op0, m_Xor(m_Xor(m_Value(A), m_Value(C)), m_Value(B)))) + if (match(Op1, m_Xor(m_Specific(B), m_Specific(A)))) + if (Op0->hasOneUse() || cast<BinaryOperator>(Op0)->hasOneUse()) + return BinaryOperator::CreateAnd(Op1, Builder->CreateNot(C)); + + // (A | B) & ((~A) ^ B) -> (A & B) + if (match(Op0, m_Or(m_Value(A), m_Value(B))) && + match(Op1, m_Xor(m_Not(m_Specific(A)), m_Specific(B)))) + return BinaryOperator::CreateAnd(A, B); + + // ((~A) ^ B) & (A | B) -> (A & B) + if (match(Op0, m_Xor(m_Not(m_Value(A)), m_Value(B))) && + match(Op1, m_Or(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateAnd(A, B); } - if (ICmpInst *RHS = dyn_cast<ICmpInst>(Op1)) - if (ICmpInst *LHS = dyn_cast<ICmpInst>(Op0)) + { + ICmpInst *LHS = dyn_cast<ICmpInst>(Op0); + ICmpInst *RHS = dyn_cast<ICmpInst>(Op1); + if (LHS && RHS) if (Value *Res = FoldAndOfICmps(LHS, RHS)) return ReplaceInstUsesWith(I, Res); + // TODO: Make this recursive; it's a little tricky because an arbitrary + // number of 'and' instructions might have to be created. + Value *X, *Y; + if (LHS && match(Op1, m_OneUse(m_And(m_Value(X), m_Value(Y))))) { + if (auto *Cmp = dyn_cast<ICmpInst>(X)) + if (Value *Res = FoldAndOfICmps(LHS, Cmp)) + return ReplaceInstUsesWith(I, Builder->CreateAnd(Res, Y)); + if (auto *Cmp = dyn_cast<ICmpInst>(Y)) + if (Value *Res = FoldAndOfICmps(LHS, Cmp)) + return ReplaceInstUsesWith(I, Builder->CreateAnd(Res, X)); + } + if (RHS && match(Op0, m_OneUse(m_And(m_Value(X), m_Value(Y))))) { + if (auto *Cmp = dyn_cast<ICmpInst>(X)) + if (Value *Res = FoldAndOfICmps(Cmp, RHS)) + return ReplaceInstUsesWith(I, Builder->CreateAnd(Res, Y)); + if (auto *Cmp = dyn_cast<ICmpInst>(Y)) + if (Value *Res = FoldAndOfICmps(Cmp, RHS)) + return ReplaceInstUsesWith(I, Builder->CreateAnd(Res, X)); + } + } + // If and'ing two fcmp, try combine them into one. if (FCmpInst *LHS = dyn_cast<FCmpInst>(I.getOperand(0))) if (FCmpInst *RHS = dyn_cast<FCmpInst>(I.getOperand(1))) @@ -1329,20 +1495,6 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { } } - // (X >> Z) & (Y >> Z) -> (X&Y) >> Z for all shifts. 
- if (BinaryOperator *SI1 = dyn_cast<BinaryOperator>(Op1)) { - if (BinaryOperator *SI0 = dyn_cast<BinaryOperator>(Op0)) - if (SI0->isShift() && SI0->getOpcode() == SI1->getOpcode() && - SI0->getOperand(1) == SI1->getOperand(1) && - (SI0->hasOneUse() || SI1->hasOneUse())) { - Value *NewOp = - Builder->CreateAnd(SI0->getOperand(0), SI1->getOperand(0), - SI0->getName()); - return BinaryOperator::Create(SI1->getOpcode(), NewOp, - SI1->getOperand(1)); - } - } - { Value *X = nullptr; bool OpsSwapped = false; @@ -1554,7 +1706,8 @@ static Instruction *MatchSelectFromAndOr(Value *A, Value *B, } /// FoldOrOfICmps - Fold (icmp)|(icmp) if possible. -Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS) { +Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, + Instruction *CxtI) { ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) @@ -1574,13 +1727,15 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS) { Value *Mask = nullptr; Value *Masked = nullptr; if (LAnd->getOperand(0) == RAnd->getOperand(0) && - isKnownToBeAPowerOfTwo(LAnd->getOperand(1)) && - isKnownToBeAPowerOfTwo(RAnd->getOperand(1))) { + isKnownToBeAPowerOfTwo(LAnd->getOperand(1), false, 0, AC, CxtI, DT) && + isKnownToBeAPowerOfTwo(RAnd->getOperand(1), false, 0, AC, CxtI, DT)) { Mask = Builder->CreateOr(LAnd->getOperand(1), RAnd->getOperand(1)); Masked = Builder->CreateAnd(LAnd->getOperand(0), Mask); } else if (LAnd->getOperand(1) == RAnd->getOperand(1) && - isKnownToBeAPowerOfTwo(LAnd->getOperand(0)) && - isKnownToBeAPowerOfTwo(RAnd->getOperand(0))) { + isKnownToBeAPowerOfTwo(LAnd->getOperand(0), false, 0, AC, CxtI, + DT) && + isKnownToBeAPowerOfTwo(RAnd->getOperand(0), false, 0, AC, CxtI, + DT)) { Mask = Builder->CreateOr(LAnd->getOperand(0), RAnd->getOperand(0)); Masked = Builder->CreateAnd(LAnd->getOperand(1), Mask); } @@ -1590,6 +1745,61 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS) { } } + // Fold (icmp ult/ule (A + C1), C3) | (icmp ult/ule (A + C2), C3) + // --> (icmp ult/ule ((A & ~(C1 ^ C2)) + max(C1, C2)), C3) + // The original condition actually refers to the following two ranges: + // [MAX_UINT-C1+1, MAX_UINT-C1+1+C3] and [MAX_UINT-C2+1, MAX_UINT-C2+1+C3] + // We can fold these two ranges if: + // 1) C1 and C2 are unsigned-greater-than C3. + // 2) The two ranges are separated. + // 3) C1 ^ C2 is a one-bit mask. + // 4) LowRange1 ^ LowRange2 and HighRange1 ^ HighRange2 are one-bit masks. + // This implies all values in the two ranges differ by exactly one bit.
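One concrete instance makes these conditions tangible: with C1 = 64, C2 = 192, and C3 = 5, C1 ^ C2 = 128 is a one-bit mask and all four conditions hold, so the union of the two ranges collapses to a single masked compare. A brute-force 8-bit check (illustrative constants, not from the patch):

  #include <cassert>
  #include <cstdint>

  int main() {
    for (unsigned a = 0; a <= 255; ++a) {
      bool orig = (uint8_t)(a + 64) < 5 || (uint8_t)(a + 192) < 5;
      // (A & ~(C1 ^ C2)) + max(C1, C2), compared against C3
      bool folded = (uint8_t)((a & ~128u) + 192) < 5;
      assert(orig == folded);
    }
    return 0;
  }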
+ + if ((LHSCC == ICmpInst::ICMP_ULT || LHSCC == ICmpInst::ICMP_ULE) && + LHSCC == RHSCC && LHSCst && RHSCst && LHS->hasOneUse() && + RHS->hasOneUse() && LHSCst->getType() == RHSCst->getType() && + LHSCst->getValue() == (RHSCst->getValue())) { + + Value *LAdd = LHS->getOperand(0); + Value *RAdd = RHS->getOperand(0); + + Value *LAddOpnd, *RAddOpnd; + ConstantInt *LAddCst, *RAddCst; + if (match(LAdd, m_Add(m_Value(LAddOpnd), m_ConstantInt(LAddCst))) && + match(RAdd, m_Add(m_Value(RAddOpnd), m_ConstantInt(RAddCst))) && + LAddCst->getValue().ugt(LHSCst->getValue()) && + RAddCst->getValue().ugt(LHSCst->getValue())) { + + APInt DiffCst = LAddCst->getValue() ^ RAddCst->getValue(); + if (LAddOpnd == RAddOpnd && DiffCst.isPowerOf2()) { + ConstantInt *MaxAddCst = nullptr; + if (LAddCst->getValue().ult(RAddCst->getValue())) + MaxAddCst = RAddCst; + else + MaxAddCst = LAddCst; + + APInt RRangeLow = -RAddCst->getValue(); + APInt RRangeHigh = RRangeLow + LHSCst->getValue(); + APInt LRangeLow = -LAddCst->getValue(); + APInt LRangeHigh = LRangeLow + LHSCst->getValue(); + APInt LowRangeDiff = RRangeLow ^ LRangeLow; + APInt HighRangeDiff = RRangeHigh ^ LRangeHigh; + APInt RangeDiff = LRangeLow.sgt(RRangeLow) ? LRangeLow - RRangeLow : RRangeLow - LRangeLow; + + if (LowRangeDiff.isPowerOf2() && LowRangeDiff == HighRangeDiff && + RangeDiff.ugt(LHSCst->getValue())) { + Value *MaskCst = ConstantInt::get(LAddCst->getType(), ~DiffCst); + + Value *NewAnd = Builder->CreateAnd(LAddOpnd, MaskCst); + Value *NewAdd = Builder->CreateAdd(NewAnd, MaxAddCst); + return (Builder->CreateICmp(LHS->getPredicate(), NewAdd, LHSCst)); + } + } + } + } + // (icmp1 A, B) | (icmp2 A, B) --> (icmp3 A, B) if (PredicatesFoldable(LHSCC, RHSCC)) { if (LHS->getOperand(0) == RHS->getOperand(1) && @@ -1636,6 +1846,14 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS) { Builder->CreateAdd(B, ConstantInt::getSigned(B->getType(), -1)), A); } + // E.g. (icmp slt x, 0) | (icmp sgt x, n) --> icmp ugt x, n + if (Value *V = simplifyRangeCheck(LHS, RHS, /*Inverted=*/true)) + return V; + + // E.g. (icmp sgt x, n) | (icmp slt x, 0) --> icmp ugt x, n + if (Value *V = simplifyRangeCheck(RHS, LHS, /*Inverted=*/true)) + return V; + // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2). if (!LHSCst || !RHSCst) return nullptr; @@ -1906,6 +2124,38 @@ Instruction *InstCombiner::FoldOrWithConstants(BinaryOperator &I, Value *Op, return nullptr; } +/// \brief This helper function folds: +/// +/// ((A ^ B) & C1) | (B & C2) +/// +/// into: +/// +/// (A & C1) ^ B +/// +/// when the XOR of the two constants is "all ones" (-1). +Instruction *InstCombiner::FoldXorWithConstants(BinaryOperator &I, Value *Op, + Value *A, Value *B, Value *C) { + ConstantInt *CI1 = dyn_cast<ConstantInt>(C); + if (!CI1) + return nullptr; + + Value *V1 = nullptr; + ConstantInt *CI2 = nullptr; + if (!match(Op, m_And(m_Value(V1), m_ConstantInt(CI2)))) + return nullptr; + + APInt Xor = CI1->getValue() ^ CI2->getValue(); + if (!Xor.isAllOnesValue()) + return nullptr; + + if (V1 == A || V1 == B) { + Value *NewOp = Builder->CreateAnd(V1 == A ?
B : A, CI1); + return BinaryOperator::CreateXor(NewOp, V1); + } + + return nullptr; +} + Instruction *InstCombiner::visitOr(BinaryOperator &I) { bool Changed = SimplifyAssociativeOrCommutative(I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -1913,7 +2163,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyOrInst(Op0, Op1, DL)) + if (Value *V = SimplifyOrInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // (A&B)|(A&C) -> A&(B|C) etc @@ -1925,6 +2175,9 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (SimplifyDemandedInstructionBits(I)) return &I; + if (Value *V = SimplifyBSwap(I)) + return ReplaceInstUsesWith(I, V); + if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { ConstantInt *C1 = nullptr; Value *X = nullptr; // (X & C1) | C2 --> (X | C2) & (C1|C2) @@ -1973,7 +2226,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { // (X^C)|Y -> (X|Y)^C iff Y&C == 0 if (Op0->hasOneUse() && match(Op0, m_Xor(m_Value(A), m_ConstantInt(C1))) && - MaskedValueIsZero(Op1, C1->getValue())) { + MaskedValueIsZero(Op1, C1->getValue(), 0, &I)) { Value *NOr = Builder->CreateOr(A, Op1); NOr->takeName(Op0); return BinaryOperator::CreateXor(NOr, C1); @@ -1982,12 +2235,32 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { // Y|(X^C) -> (X|Y)^C iff Y&C == 0 if (Op1->hasOneUse() && match(Op1, m_Xor(m_Value(A), m_ConstantInt(C1))) && - MaskedValueIsZero(Op0, C1->getValue())) { + MaskedValueIsZero(Op0, C1->getValue(), 0, &I)) { Value *NOr = Builder->CreateOr(A, Op0); NOr->takeName(Op0); return BinaryOperator::CreateXor(NOr, C1); } + // ((~A & B) | A) -> (A | B) + if (match(Op0, m_And(m_Not(m_Value(A)), m_Value(B))) && + match(Op1, m_Specific(A))) + return BinaryOperator::CreateOr(A, B); + + // ((A & B) | ~A) -> (~A | B) + if (match(Op0, m_And(m_Value(A), m_Value(B))) && + match(Op1, m_Not(m_Specific(A)))) + return BinaryOperator::CreateOr(Builder->CreateNot(A), B); + + // (A & (~B)) | (A ^ B) -> (A ^ B) + if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) && + match(Op1, m_Xor(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateXor(A, B); + + // (A ^ B) | ( A & (~B)) -> (A ^ B) + if (match(Op0, m_Xor(m_Value(A), m_Value(B))) && + match(Op1, m_And(m_Specific(A), m_Not(m_Specific(B))))) + return BinaryOperator::CreateXor(A, B); + // (A & C)|(B & D) Value *C = nullptr, *D = nullptr; if (match(Op0, m_And(m_Value(A), m_Value(C))) && @@ -2000,14 +2273,18 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { // ((V | N) & C1) | (V & C2) --> (V|N) & (C1|C2) // iff (C1&C2) == 0 and (N&~C1) == 0 if (match(A, m_Or(m_Value(V1), m_Value(V2))) && - ((V1 == B && MaskedValueIsZero(V2, ~C1->getValue())) || // (V|N) - (V2 == B && MaskedValueIsZero(V1, ~C1->getValue())))) // (N|V) + ((V1 == B && + MaskedValueIsZero(V2, ~C1->getValue(), 0, &I)) || // (V|N) + (V2 == B && + MaskedValueIsZero(V1, ~C1->getValue(), 0, &I)))) // (N|V) return BinaryOperator::CreateAnd(A, Builder->getInt(C1->getValue()|C2->getValue())); // Or commutes, try both ways. 
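The identity behind the FoldXorWithConstants helper above holds for every byte triple once C2 is the bitwise complement of C1, which a brute-force loop confirms (standalone illustration):

  #include <cassert>
  #include <cstdint>

  int main() {
    for (unsigned a = 0; a <= 255; ++a)
      for (unsigned b = 0; b <= 255; ++b)
        for (unsigned c1 = 0; c1 <= 255; ++c1) {
          unsigned c2 = ~c1 & 0xFFu;                // C1 ^ C2 == all ones
          uint8_t lhs = ((a ^ b) & c1) | (b & c2);
          uint8_t rhs = (a & c1) ^ b;
          assert(lhs == rhs);
        }
    return 0;
  }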
if (match(B, m_Or(m_Value(V1), m_Value(V2))) && - ((V1 == A && MaskedValueIsZero(V2, ~C2->getValue())) || // (V|N) - (V2 == A && MaskedValueIsZero(V1, ~C2->getValue())))) // (N|V) + ((V1 == A && + MaskedValueIsZero(V2, ~C2->getValue(), 0, &I)) || // (V|N) + (V2 == A && + MaskedValueIsZero(V1, ~C2->getValue(), 0, &I)))) // (N|V) return BinaryOperator::CreateAnd(B, Builder->getInt(C1->getValue()|C2->getValue())); @@ -2068,20 +2345,35 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { Instruction *Ret = FoldOrWithConstants(I, Op0, A, V1, D); if (Ret) return Ret; } + // ((A^B)&1)|(B&-2) -> (A&1) ^ B + if (match(A, m_Xor(m_Value(V1), m_Specific(B))) || + match(A, m_Xor(m_Specific(B), m_Value(V1)))) { + Instruction *Ret = FoldXorWithConstants(I, Op1, V1, B, C); + if (Ret) return Ret; + } + // (B&-2)|((A^B)&1) -> (A&1) ^ B + if (match(B, m_Xor(m_Specific(A), m_Value(V1))) || + match(B, m_Xor(m_Value(V1), m_Specific(A)))) { + Instruction *Ret = FoldXorWithConstants(I, Op0, A, V1, D); + if (Ret) return Ret; + } } - // (X >> Z) | (Y >> Z) -> (X|Y) >> Z for all shifts. - if (BinaryOperator *SI1 = dyn_cast<BinaryOperator>(Op1)) { - if (BinaryOperator *SI0 = dyn_cast<BinaryOperator>(Op0)) - if (SI0->isShift() && SI0->getOpcode() == SI1->getOpcode() && - SI0->getOperand(1) == SI1->getOperand(1) && - (SI0->hasOneUse() || SI1->hasOneUse())) { - Value *NewOp = Builder->CreateOr(SI0->getOperand(0), SI1->getOperand(0), - SI0->getName()); - return BinaryOperator::Create(SI1->getOpcode(), NewOp, - SI1->getOperand(1)); - } - } + // (A ^ B) | ((B ^ C) ^ A) -> (A ^ B) | C + if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) + if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A)))) + if (Op1->hasOneUse() || cast<BinaryOperator>(Op1)->hasOneUse()) + return BinaryOperator::CreateOr(Op0, C); + + // ((A ^ C) ^ B) | (B ^ A) -> (B ^ A) | C + if (match(Op0, m_Xor(m_Xor(m_Value(A), m_Value(C)), m_Value(B)))) + if (match(Op1, m_Xor(m_Specific(B), m_Specific(A)))) + if (Op0->hasOneUse() || cast<BinaryOperator>(Op0)->hasOneUse()) + return BinaryOperator::CreateOr(Op1, C); + + // ((B | C) & A) | B -> B | (A & C) + if (match(Op0, m_And(m_Or(m_Specific(Op1), m_Value(C)), m_Value(A)))) + return BinaryOperator::CreateOr(Op1, Builder->CreateAnd(A, C)); // (~A | ~B) == (~(A & B)) - De Morgan's Law if (Value *Op0NotVal = dyn_castNotVal(Op0)) @@ -2133,14 +2425,47 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { return BinaryOperator::CreateOr(Not, Op0); } + // (A & B) | ((~A) ^ B) -> (~A ^ B) + if (match(Op0, m_And(m_Value(A), m_Value(B))) && + match(Op1, m_Xor(m_Not(m_Specific(A)), m_Specific(B)))) + return BinaryOperator::CreateXor(Builder->CreateNot(A), B); + + // ((~A) ^ B) | (A & B) -> (~A ^ B) + if (match(Op0, m_Xor(m_Not(m_Value(A)), m_Value(B))) && + match(Op1, m_And(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateXor(Builder->CreateNot(A), B); + if (SwappedForXor) std::swap(Op0, Op1); - if (ICmpInst *RHS = dyn_cast<ICmpInst>(I.getOperand(1))) - if (ICmpInst *LHS = dyn_cast<ICmpInst>(I.getOperand(0))) - if (Value *Res = FoldOrOfICmps(LHS, RHS)) + { + ICmpInst *LHS = dyn_cast<ICmpInst>(Op0); + ICmpInst *RHS = dyn_cast<ICmpInst>(Op1); + if (LHS && RHS) + if (Value *Res = FoldOrOfICmps(LHS, RHS, &I)) return ReplaceInstUsesWith(I, Res); + // TODO: Make this recursive; it's a little tricky because an arbitrary + // number of 'or' instructions might have to be created. 
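A hedged illustration of the one-level recursion that the TODO above refers to and that the matching code just below implements (function and value names are invented, not from this patch's tests):

define i1 @or_chain(i32 %x, i1 %c) {
  %a = icmp eq i32 %x, 0
  %b = icmp eq i32 %x, 1
  %inner = or i1 %b, %c        ; one-use inner 'or'
  %outer = or i1 %a, %inner
  ret i1 %outer
}
; FoldOrOfICmps can combine %a and %b (e.g. to 'icmp ult i32 %x, 2'),
; after which the rewrite rebuilds the tree as: or i1 %folded, %c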
+ Value *X, *Y; + if (LHS && match(Op1, m_OneUse(m_Or(m_Value(X), m_Value(Y))))) { + if (auto *Cmp = dyn_cast<ICmpInst>(X)) + if (Value *Res = FoldOrOfICmps(LHS, Cmp, &I)) + return ReplaceInstUsesWith(I, Builder->CreateOr(Res, Y)); + if (auto *Cmp = dyn_cast<ICmpInst>(Y)) + if (Value *Res = FoldOrOfICmps(LHS, Cmp, &I)) + return ReplaceInstUsesWith(I, Builder->CreateOr(Res, X)); + } + if (RHS && match(Op0, m_OneUse(m_Or(m_Value(X), m_Value(Y))))) { + if (auto *Cmp = dyn_cast<ICmpInst>(X)) + if (Value *Res = FoldOrOfICmps(Cmp, RHS, &I)) + return ReplaceInstUsesWith(I, Builder->CreateOr(Res, Y)); + if (auto *Cmp = dyn_cast<ICmpInst>(Y)) + if (Value *Res = FoldOrOfICmps(Cmp, RHS, &I)) + return ReplaceInstUsesWith(I, Builder->CreateOr(Res, X)); + } + } + // (fcmp uno x, c) | (fcmp uno y, c) -> (fcmp uno x, y) if (FCmpInst *LHS = dyn_cast<FCmpInst>(I.getOperand(0))) if (FCmpInst *RHS = dyn_cast<FCmpInst>(I.getOperand(1))) @@ -2169,7 +2494,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { // cast is otherwise not optimizable. This happens for vector sexts. if (ICmpInst *RHS = dyn_cast<ICmpInst>(Op1COp)) if (ICmpInst *LHS = dyn_cast<ICmpInst>(Op0COp)) - if (Value *Res = FoldOrOfICmps(LHS, RHS)) + if (Value *Res = FoldOrOfICmps(LHS, RHS, &I)) return CastInst::Create(Op0C->getOpcode(), Res, I.getType()); // If this is or(cast(fcmp), cast(fcmp)), try to fold this even if the @@ -2225,7 +2550,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyXorInst(Op0, Op1, DL)) + if (Value *V = SimplifyXorInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // (A&B)^(A&C) -> A&(B^C) etc @@ -2237,6 +2562,9 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (SimplifyDemandedInstructionBits(I)) return &I; + if (Value *V = SimplifyBSwap(I)) + return ReplaceInstUsesWith(I, V); + // Is this a ~ operation? if (Value *NotOp = dyn_castNotVal(&I)) { if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(NotOp)) { @@ -2327,7 +2655,8 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { } } else if (Op0I->getOpcode() == Instruction::Or) { // (X|C1)^C2 -> X^(C1|C2) iff X&~C1 == 0 - if (MaskedValueIsZero(Op0I->getOperand(0), Op0CI->getValue())) { + if (MaskedValueIsZero(Op0I->getOperand(0), Op0CI->getValue(), + 0, &I)) { Constant *NewRHS = ConstantExpr::getOr(Op0CI, RHS); // Anything in both C1 and C2 is known to be zero, remove it from // NewRHS. @@ -2418,18 +2747,6 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { } } - // (X >> Z) ^ (Y >> Z) -> (X^Y) >> Z for all shifts. 
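Stepping back: both visitOr and visitXor above gained a SimplifyBSwap call. A hedged sketch of the pattern that helper presumably targets, namely bitwise operations distributed over byte swaps (an assumption based on its name and call sites, not on code visible in this patch):

declare i32 @llvm.bswap.i32(i32)

define i32 @xor_of_bswaps(i32 %x, i32 %y) {
  %bx = call i32 @llvm.bswap.i32(i32 %x)
  %by = call i32 @llvm.bswap.i32(i32 %y)
  %r = xor i32 %bx, %by
  ret i32 %r
}
; Expected shape after the fold, if it applies as assumed:
;   %t = xor i32 %x, %y
;   %r = call i32 @llvm.bswap.i32(i32 %t)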
- if (Op0I && Op1I && Op0I->isShift() && - Op0I->getOpcode() == Op1I->getOpcode() && - Op0I->getOperand(1) == Op1I->getOperand(1) && - (Op0I->hasOneUse() || Op1I->hasOneUse())) { - Value *NewOp = - Builder->CreateXor(Op0I->getOperand(0), Op1I->getOperand(0), - Op0I->getName()); - return BinaryOperator::Create(Op1I->getOpcode(), NewOp, - Op1I->getOperand(1)); - } - if (Op0I && Op1I) { Value *A, *B, *C, *D; // (A & B)^(A | B) -> A ^ B @@ -2444,8 +2761,62 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if ((A == C && B == D) || (A == D && B == C)) return BinaryOperator::CreateXor(A, B); } + // (A | ~B) ^ (~A | B) -> A ^ B + if (match(Op0I, m_Or(m_Value(A), m_Not(m_Value(B)))) && + match(Op1I, m_Or(m_Not(m_Specific(A)), m_Specific(B)))) { + return BinaryOperator::CreateXor(A, B); + } + // (~A | B) ^ (A | ~B) -> A ^ B + if (match(Op0I, m_Or(m_Not(m_Value(A)), m_Value(B))) && + match(Op1I, m_Or(m_Specific(A), m_Not(m_Specific(B))))) { + return BinaryOperator::CreateXor(A, B); + } + // (A & ~B) ^ (~A & B) -> A ^ B + if (match(Op0I, m_And(m_Value(A), m_Not(m_Value(B)))) && + match(Op1I, m_And(m_Not(m_Specific(A)), m_Specific(B)))) { + return BinaryOperator::CreateXor(A, B); + } + // (~A & B) ^ (A & ~B) -> A ^ B + if (match(Op0I, m_And(m_Not(m_Value(A)), m_Value(B))) && + match(Op1I, m_And(m_Specific(A), m_Not(m_Specific(B))))) { + return BinaryOperator::CreateXor(A, B); + } + // (A ^ C)^(A | B) -> ((~A) & B) ^ C + if (match(Op0I, m_Xor(m_Value(D), m_Value(C))) && + match(Op1I, m_Or(m_Value(A), m_Value(B)))) { + if (D == A) + return BinaryOperator::CreateXor( + Builder->CreateAnd(Builder->CreateNot(A), B), C); + if (D == B) + return BinaryOperator::CreateXor( + Builder->CreateAnd(Builder->CreateNot(B), A), C); + } + // (A | B)^(A ^ C) -> ((~A) & B) ^ C + if (match(Op0I, m_Or(m_Value(A), m_Value(B))) && + match(Op1I, m_Xor(m_Value(D), m_Value(C)))) { + if (D == A) + return BinaryOperator::CreateXor( + Builder->CreateAnd(Builder->CreateNot(A), B), C); + if (D == B) + return BinaryOperator::CreateXor( + Builder->CreateAnd(Builder->CreateNot(B), A), C); + } + // (A & B) ^ (A ^ B) -> (A | B) + if (match(Op0I, m_And(m_Value(A), m_Value(B))) && + match(Op1I, m_Xor(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateOr(A, B); + // (A ^ B) ^ (A & B) -> (A | B) + if (match(Op0I, m_Xor(m_Value(A), m_Value(B))) && + match(Op1I, m_And(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateOr(A, B); } + Value *A = nullptr, *B = nullptr; + // (A & ~B) ^ (~A) -> ~(A & B) + if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) && + match(Op1, m_Not(m_Specific(A)))) + return BinaryOperator::CreateNot(Builder->CreateAnd(A, B)); + // (icmp1 A, B) ^ (icmp2 A, B) --> (icmp3 A, B) if (ICmpInst *RHS = dyn_cast<ICmpInst>(I.getOperand(1))) if (ICmpInst *LHS = dyn_cast<ICmpInst>(I.getOperand(0))) diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp index 658178d5914e..dab2c4b47ad6 100644 --- a/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -16,7 +16,9 @@ #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Statepoint.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -58,8 +60,8 @@ static Type *reduceToSingleValueType(Type *T) { } Instruction 
*InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { - unsigned DstAlign = getKnownAlignment(MI->getArgOperand(0), DL); - unsigned SrcAlign = getKnownAlignment(MI->getArgOperand(1), DL); + unsigned DstAlign = getKnownAlignment(MI->getArgOperand(0), DL, AC, MI, DT); + unsigned SrcAlign = getKnownAlignment(MI->getArgOperand(1), DL, AC, MI, DT); unsigned MinAlign = std::min(DstAlign, SrcAlign); unsigned CopyAlign = MI->getAlignment(); @@ -117,15 +119,14 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { // If the memcpy has metadata describing the members, see if we can // get the TBAA tag describing our copy. if (MDNode *M = MI->getMetadata(LLVMContext::MD_tbaa_struct)) { - if (M->getNumOperands() == 3 && - M->getOperand(0) && - isa<ConstantInt>(M->getOperand(0)) && - cast<ConstantInt>(M->getOperand(0))->isNullValue() && + if (M->getNumOperands() == 3 && M->getOperand(0) && + mdconst::hasa<ConstantInt>(M->getOperand(0)) && + mdconst::extract<ConstantInt>(M->getOperand(0))->isNullValue() && M->getOperand(1) && - isa<ConstantInt>(M->getOperand(1)) && - cast<ConstantInt>(M->getOperand(1))->getValue() == Size && - M->getOperand(2) && - isa<MDNode>(M->getOperand(2))) + mdconst::hasa<ConstantInt>(M->getOperand(1)) && + mdconst::extract<ConstantInt>(M->getOperand(1))->getValue() == + Size && + M->getOperand(2) && isa<MDNode>(M->getOperand(2))) CopyMD = cast<MDNode>(M->getOperand(2)); } } @@ -154,7 +155,7 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { } Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) { - unsigned Alignment = getKnownAlignment(MI->getDest(), DL); + unsigned Alignment = getKnownAlignment(MI->getDest(), DL, AC, MI, DT); if (MI->getAlignment() < Alignment) { MI->setAlignment(ConstantInt::get(MI->getAlignmentType(), Alignment, false)); @@ -322,7 +323,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { uint32_t BitWidth = IT->getBitWidth(); APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); - computeKnownBits(II->getArgOperand(0), KnownZero, KnownOne); + computeKnownBits(II->getArgOperand(0), KnownZero, KnownOne, 0, II); unsigned TrailingZeros = KnownOne.countTrailingZeros(); APInt Mask(APInt::getLowBitsSet(BitWidth, TrailingZeros)); if ((Mask & KnownZero) == Mask) @@ -340,7 +341,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { uint32_t BitWidth = IT->getBitWidth(); APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); - computeKnownBits(II->getArgOperand(0), KnownZero, KnownOne); + computeKnownBits(II->getArgOperand(0), KnownZero, KnownOne, 0, II); unsigned LeadingZeros = KnownOne.countLeadingZeros(); APInt Mask(APInt::getHighBitsSet(BitWidth, LeadingZeros)); if ((Mask & KnownZero) == Mask) @@ -351,48 +352,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; case Intrinsic::uadd_with_overflow: { Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1); - IntegerType *IT = cast<IntegerType>(II->getArgOperand(0)->getType()); - uint32_t BitWidth = IT->getBitWidth(); - APInt LHSKnownZero(BitWidth, 0); - APInt LHSKnownOne(BitWidth, 0); - computeKnownBits(LHS, LHSKnownZero, LHSKnownOne); - bool LHSKnownNegative = LHSKnownOne[BitWidth - 1]; - bool LHSKnownPositive = LHSKnownZero[BitWidth - 1]; - - if (LHSKnownNegative || LHSKnownPositive) { - APInt RHSKnownZero(BitWidth, 0); - APInt RHSKnownOne(BitWidth, 0); - computeKnownBits(RHS, RHSKnownZero, RHSKnownOne); - bool RHSKnownNegative = RHSKnownOne[BitWidth - 1]; - bool RHSKnownPositive = RHSKnownZero[BitWidth - 1]; - if 
(LHSKnownNegative && RHSKnownNegative) { - // The sign bit is set in both cases: this MUST overflow. - // Create a simple add instruction, and insert it into the struct. - Value *Add = Builder->CreateAdd(LHS, RHS); - Add->takeName(&CI); - Constant *V[] = { - UndefValue::get(LHS->getType()), - ConstantInt::getTrue(II->getContext()) - }; - StructType *ST = cast<StructType>(II->getType()); - Constant *Struct = ConstantStruct::get(ST, V); - return InsertValueInst::Create(Struct, Add, 0); - } - - if (LHSKnownPositive && RHSKnownPositive) { - // The sign bit is clear in both cases: this CANNOT overflow. - // Create a simple add instruction, and insert it into the struct. - Value *Add = Builder->CreateNUWAdd(LHS, RHS); - Add->takeName(&CI); - Constant *V[] = { - UndefValue::get(LHS->getType()), - ConstantInt::getFalse(II->getContext()) - }; - StructType *ST = cast<StructType>(II->getType()); - Constant *Struct = ConstantStruct::get(ST, V); - return InsertValueInst::Create(Struct, Add, 0); - } - } + OverflowResult OR = computeOverflowForUnsignedAdd(LHS, RHS, II); + if (OR == OverflowResult::NeverOverflows) + return CreateOverflowTuple(II, Builder->CreateNUWAdd(LHS, RHS), false); + if (OR == OverflowResult::AlwaysOverflows) + return CreateOverflowTuple(II, Builder->CreateAdd(LHS, RHS), true); } // FALL THROUGH uadd into sadd case Intrinsic::sadd_with_overflow: @@ -412,13 +376,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (ConstantInt *RHS = dyn_cast<ConstantInt>(II->getArgOperand(1))) { // X + 0 -> {X, false} if (RHS->isZero()) { - Constant *V[] = { - UndefValue::get(II->getArgOperand(0)->getType()), - ConstantInt::getFalse(II->getContext()) - }; - Constant *Struct = - ConstantStruct::get(cast<StructType>(II->getType()), V); - return InsertValueInst::Create(Struct, II->getArgOperand(0), 0); + return CreateOverflowTuple(II, II->getArgOperand(0), false, + /*ReUseName*/false); } } @@ -426,66 +385,44 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // can prove that it will never overflow. 
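The branch just below applies that reasoning to llvm.sadd.with.overflow. A hedged IR example where WillNotOverflowSignedAdd can succeed because sign-extended operands guarantee headroom (names illustrative):

declare { i32, i1 } @llvm.sadd.with.overflow.i32(i32, i32)

define { i32, i1 } @sadd_never_overflows(i16 %a, i16 %b) {
  %xa = sext i16 %a to i32     ; at least 17 sign bits each
  %xb = sext i16 %b to i32
  %r = call { i32, i1 } @llvm.sadd.with.overflow.i32(i32 %xa, i32 %xb)
  ret { i32, i1 } %r
}
; The sum of two 17-sign-bit values fits in i32, so the call should become
; 'add nsw i32 %xa, %xb' packed with a constant-false overflow bit.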
if (II->getIntrinsicID() == Intrinsic::sadd_with_overflow) { Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1); - if (WillNotOverflowSignedAdd(LHS, RHS)) { - Value *Add = Builder->CreateNSWAdd(LHS, RHS); - Add->takeName(&CI); - Constant *V[] = {UndefValue::get(Add->getType()), Builder->getFalse()}; - StructType *ST = cast<StructType>(II->getType()); - Constant *Struct = ConstantStruct::get(ST, V); - return InsertValueInst::Create(Struct, Add, 0); + if (WillNotOverflowSignedAdd(LHS, RHS, II)) { + return CreateOverflowTuple(II, Builder->CreateNSWAdd(LHS, RHS), false); } } break; case Intrinsic::usub_with_overflow: - case Intrinsic::ssub_with_overflow: + case Intrinsic::ssub_with_overflow: { + Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1); // undef - X -> undef // X - undef -> undef - if (isa<UndefValue>(II->getArgOperand(0)) || - isa<UndefValue>(II->getArgOperand(1))) + if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS)) return ReplaceInstUsesWith(CI, UndefValue::get(II->getType())); - if (ConstantInt *RHS = dyn_cast<ConstantInt>(II->getArgOperand(1))) { + if (ConstantInt *ConstRHS = dyn_cast<ConstantInt>(RHS)) { // X - 0 -> {X, false} - if (RHS->isZero()) { - Constant *V[] = { - UndefValue::get(II->getArgOperand(0)->getType()), - ConstantInt::getFalse(II->getContext()) - }; - Constant *Struct = - ConstantStruct::get(cast<StructType>(II->getType()), V); - return InsertValueInst::Create(Struct, II->getArgOperand(0), 0); + if (ConstRHS->isZero()) { + return CreateOverflowTuple(II, LHS, false, /*ReUseName*/false); + } + } + if (II->getIntrinsicID() == Intrinsic::ssub_with_overflow) { + if (WillNotOverflowSignedSub(LHS, RHS, II)) { + return CreateOverflowTuple(II, Builder->CreateNSWSub(LHS, RHS), false); + } + } else { + if (WillNotOverflowUnsignedSub(LHS, RHS, II)) { + return CreateOverflowTuple(II, Builder->CreateNUWSub(LHS, RHS), false); } } break; + } case Intrinsic::umul_with_overflow: { Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1); - unsigned BitWidth = cast<IntegerType>(LHS->getType())->getBitWidth(); - - APInt LHSKnownZero(BitWidth, 0); - APInt LHSKnownOne(BitWidth, 0); - computeKnownBits(LHS, LHSKnownZero, LHSKnownOne); - APInt RHSKnownZero(BitWidth, 0); - APInt RHSKnownOne(BitWidth, 0); - computeKnownBits(RHS, RHSKnownZero, RHSKnownOne); - - // Get the largest possible values for each operand. - APInt LHSMax = ~LHSKnownZero; - APInt RHSMax = ~RHSKnownZero; - - // If multiplying the maximum values does not overflow then we can turn - // this into a plain NUW mul. - bool Overflow; - LHSMax.umul_ov(RHSMax, Overflow); - if (!Overflow) { - Value *Mul = Builder->CreateNUWMul(LHS, RHS, "umul_with_overflow"); - Constant *V[] = { - UndefValue::get(LHS->getType()), - Builder->getFalse() - }; - Constant *Struct = ConstantStruct::get(cast<StructType>(II->getType()),V); - return InsertValueInst::Create(Struct, Mul, 0); - } + OverflowResult OR = computeOverflowForUnsignedMul(LHS, RHS, II); + if (OR == OverflowResult::NeverOverflows) + return CreateOverflowTuple(II, Builder->CreateNUWMul(LHS, RHS), false); + if (OR == OverflowResult::AlwaysOverflows) + return CreateOverflowTuple(II, Builder->CreateMul(LHS, RHS), true); } // FALL THROUGH case Intrinsic::smul_with_overflow: // Canonicalize constants into the RHS. 
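At this hunk boundary, a hedged example of the umul.with.overflow rewrite above: when known zero bits bound both operands, computeOverflowForUnsignedMul can report NeverOverflows (values illustrative):

declare { i32, i1 } @llvm.umul.with.overflow.i32(i32, i32)

define { i32, i1 } @umul_never_overflows(i32 %a, i32 %b) {
  %la = and i32 %a, 65535      ; both operands < 2^16
  %lb = and i32 %b, 65535
  %r = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %la, i32 %lb)
  ret { i32, i1 } %r
}
; (2^16 - 1)^2 < 2^32, so this should fold to 'mul nuw i32 %la, %lb' with a
; constant-false overflow bit.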
@@ -508,40 +445,142 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // X * 1 -> {X, false} if (RHSI->equalsInt(1)) { - Constant *V[] = { - UndefValue::get(II->getArgOperand(0)->getType()), - ConstantInt::getFalse(II->getContext()) - }; - Constant *Struct = - ConstantStruct::get(cast<StructType>(II->getType()), V); - return InsertValueInst::Create(Struct, II->getArgOperand(0), 0); + return CreateOverflowTuple(II, II->getArgOperand(0), false, + /*ReUseName*/false); + } + } + if (II->getIntrinsicID() == Intrinsic::smul_with_overflow) { + Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1); + if (WillNotOverflowSignedMul(LHS, RHS, II)) { + return CreateOverflowTuple(II, Builder->CreateNSWMul(LHS, RHS), false); + } + } + break; + case Intrinsic::minnum: + case Intrinsic::maxnum: { + Value *Arg0 = II->getArgOperand(0); + Value *Arg1 = II->getArgOperand(1); + + // fmin(x, x) -> x + if (Arg0 == Arg1) + return ReplaceInstUsesWith(CI, Arg0); + + const ConstantFP *C0 = dyn_cast<ConstantFP>(Arg0); + const ConstantFP *C1 = dyn_cast<ConstantFP>(Arg1); + + // Canonicalize constants into the RHS. + if (C0 && !C1) { + II->setArgOperand(0, Arg1); + II->setArgOperand(1, Arg0); + return II; + } + + // fmin(x, nan) -> x + if (C1 && C1->isNaN()) + return ReplaceInstUsesWith(CI, Arg0); + + // This is the value because if undef were NaN, we would return the other + // value and cannot return a NaN unless both operands are. + // + // fmin(undef, x) -> x + if (isa<UndefValue>(Arg0)) + return ReplaceInstUsesWith(CI, Arg1); + + // fmin(x, undef) -> x + if (isa<UndefValue>(Arg1)) + return ReplaceInstUsesWith(CI, Arg0); + + Value *X = nullptr; + Value *Y = nullptr; + if (II->getIntrinsicID() == Intrinsic::minnum) { + // fmin(x, fmin(x, y)) -> fmin(x, y) + // fmin(y, fmin(x, y)) -> fmin(x, y) + if (match(Arg1, m_FMin(m_Value(X), m_Value(Y)))) { + if (Arg0 == X || Arg0 == Y) + return ReplaceInstUsesWith(CI, Arg1); + } + + // fmin(fmin(x, y), x) -> fmin(x, y) + // fmin(fmin(x, y), y) -> fmin(x, y) + if (match(Arg0, m_FMin(m_Value(X), m_Value(Y)))) { + if (Arg1 == X || Arg1 == Y) + return ReplaceInstUsesWith(CI, Arg0); + } + + // TODO: fmin(nnan x, inf) -> x + // TODO: fmin(nnan ninf x, flt_max) -> x + if (C1 && C1->isInfinity()) { + // fmin(x, -inf) -> -inf + if (C1->isNegative()) + return ReplaceInstUsesWith(CI, Arg1); + } + } else { + assert(II->getIntrinsicID() == Intrinsic::maxnum); + // fmax(x, fmax(x, y)) -> fmax(x, y) + // fmax(y, fmax(x, y)) -> fmax(x, y) + if (match(Arg1, m_FMax(m_Value(X), m_Value(Y)))) { + if (Arg0 == X || Arg0 == Y) + return ReplaceInstUsesWith(CI, Arg1); + } + + // fmax(fmax(x, y), x) -> fmax(x, y) + // fmax(fmax(x, y), y) -> fmax(x, y) + if (match(Arg0, m_FMax(m_Value(X), m_Value(Y)))) { + if (Arg1 == X || Arg1 == Y) + return ReplaceInstUsesWith(CI, Arg0); + } + + // TODO: fmax(nnan x, -inf) -> x + // TODO: fmax(nnan ninf x, -flt_max) -> x + if (C1 && C1->isInfinity()) { + // fmax(x, inf) -> inf + if (!C1->isNegative()) + return ReplaceInstUsesWith(CI, Arg1); } } break; + } case Intrinsic::ppc_altivec_lvx: case Intrinsic::ppc_altivec_lvxl: // Turn PPC lvx -> load if the pointer is known aligned. 
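Before the lvx handling below, a hedged sketch of the minnum folds added above (the NaN payload and names are illustrative):

declare float @llvm.minnum.f32(float, float)

define float @fmin_idioms(float %x) {
  %a = call float @llvm.minnum.f32(float %x, float %x)                 ; fmin(x, x) -> x
  %b = call float @llvm.minnum.f32(float %a, float 0x7FF8000000000000) ; fmin(x, nan) -> x
  ret float %b
}
; Both calls are expected to fold away, leaving 'ret float %x'.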
- if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL) >= 16) { + if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL, AC, II, DT) >= + 16) { Value *Ptr = Builder->CreateBitCast(II->getArgOperand(0), PointerType::getUnqual(II->getType())); return new LoadInst(Ptr); } break; + case Intrinsic::ppc_vsx_lxvw4x: + case Intrinsic::ppc_vsx_lxvd2x: { + // Turn PPC VSX loads into normal loads. + Value *Ptr = Builder->CreateBitCast(II->getArgOperand(0), + PointerType::getUnqual(II->getType())); + return new LoadInst(Ptr, Twine(""), false, 1); + } case Intrinsic::ppc_altivec_stvx: case Intrinsic::ppc_altivec_stvxl: // Turn stvx -> store if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(1), 16, DL) >= 16) { + if (getOrEnforceKnownAlignment(II->getArgOperand(1), 16, DL, AC, II, DT) >= + 16) { Type *OpPtrTy = PointerType::getUnqual(II->getArgOperand(0)->getType()); Value *Ptr = Builder->CreateBitCast(II->getArgOperand(1), OpPtrTy); return new StoreInst(II->getArgOperand(0), Ptr); } break; + case Intrinsic::ppc_vsx_stxvw4x: + case Intrinsic::ppc_vsx_stxvd2x: { + // Turn PPC VSX stores into normal stores. + Type *OpPtrTy = PointerType::getUnqual(II->getArgOperand(0)->getType()); + Value *Ptr = Builder->CreateBitCast(II->getArgOperand(1), OpPtrTy); + return new StoreInst(II->getArgOperand(0), Ptr, false, 1); + } case Intrinsic::x86_sse_storeu_ps: case Intrinsic::x86_sse2_storeu_pd: case Intrinsic::x86_sse2_storeu_dq: // Turn X86 storeu -> store if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL) >= 16) { + if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL, AC, II, DT) >= + 16) { Type *OpPtrTy = PointerType::getUnqual(II->getArgOperand(1)->getType()); Value *Ptr = Builder->CreateBitCast(II->getArgOperand(0), OpPtrTy); @@ -672,7 +711,22 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // TODO: eventually we should lower this intrinsic to IR if (auto CIWidth = dyn_cast<ConstantInt>(II->getArgOperand(2))) { if (auto CIStart = dyn_cast<ConstantInt>(II->getArgOperand(3))) { - if (CIWidth->equalsInt(64) && CIStart->isZero()) { + unsigned Index = CIStart->getZExtValue(); + // From AMD documentation: "a value of zero in the field length is + // defined as length of 64". + unsigned Length = CIWidth->equalsInt(0) ? 64 : CIWidth->getZExtValue(); + + // From AMD documentation: "If the sum of the bit index + length field + // is greater than 64, the results are undefined". + + // Note that both field index and field length are 8-bit quantities. + // Since variables 'Index' and 'Length' are unsigned values + // obtained from zero-extending field index and field length + // respectively, their sum should never wrap around. 
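Returning to the VSX-load change earlier in this hunk, a hedged before/after view: lxvw4x tolerates unaligned pointers, hence the align-1 load (the declared signature is assumed from the era's intrinsic definitions):

declare <4 x i32> @llvm.ppc.vsx.lxvw4x(i8*)

define <4 x i32> @vsx_load(i8* %p) {
  %v = call <4 x i32> @llvm.ppc.vsx.lxvw4x(i8* %p)
  ret <4 x i32> %v
}
; Expected form after the combine:
;   %cast = bitcast i8* %p to <4 x i32>*
;   %v = load <4 x i32>* %cast, align 1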
+ if ((Index + Length) > 64) + return ReplaceInstUsesWith(CI, UndefValue::get(II->getType())); + + if (Length == 64 && Index == 0) { Value *Vec = II->getArgOperand(1); Value *Undef = UndefValue::get(Vec->getType()); const uint32_t Mask[] = { 0, 2 }; @@ -680,7 +734,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { CI, Builder->CreateShuffleVector( Vec, Undef, ConstantDataVector::get( - II->getContext(), ArrayRef<uint32_t>(Mask)))); + II->getContext(), makeArrayRef(Mask)))); } else if (auto Source = dyn_cast<IntrinsicInst>(II->getArgOperand(0))) { @@ -886,7 +940,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::arm_neon_vst2lane: case Intrinsic::arm_neon_vst3lane: case Intrinsic::arm_neon_vst4lane: { - unsigned MemAlign = getKnownAlignment(II->getArgOperand(0), DL); + unsigned MemAlign = getKnownAlignment(II->getArgOperand(0), DL, AC, II, DT); unsigned AlignArg = II->getNumArgOperands() - 1; ConstantInt *IntrAlign = dyn_cast<ConstantInt>(II->getArgOperand(AlignArg)); if (IntrAlign && IntrAlign->getZExtValue() < MemAlign) { @@ -994,6 +1048,91 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return EraseInstFromFunction(CI); break; } + case Intrinsic::assume: { + // Canonicalize assume(a && b) -> assume(a); assume(b); + // Note: New assumption intrinsics created here are registered by + // the InstCombineIRInserter object. + Value *IIOperand = II->getArgOperand(0), *A, *B, + *AssumeIntrinsic = II->getCalledValue(); + if (match(IIOperand, m_And(m_Value(A), m_Value(B)))) { + Builder->CreateCall(AssumeIntrinsic, A, II->getName()); + Builder->CreateCall(AssumeIntrinsic, B, II->getName()); + return EraseInstFromFunction(*II); + } + // assume(!(a || b)) -> assume(!a); assume(!b); + if (match(IIOperand, m_Not(m_Or(m_Value(A), m_Value(B))))) { + Builder->CreateCall(AssumeIntrinsic, Builder->CreateNot(A), + II->getName()); + Builder->CreateCall(AssumeIntrinsic, Builder->CreateNot(B), + II->getName()); + return EraseInstFromFunction(*II); + } + + // assume( (load addr) != null ) -> add 'nonnull' metadata to load + // (if assume is valid at the load) + if (ICmpInst* ICmp = dyn_cast<ICmpInst>(IIOperand)) { + Value *LHS = ICmp->getOperand(0); + Value *RHS = ICmp->getOperand(1); + if (ICmpInst::ICMP_NE == ICmp->getPredicate() && + isa<LoadInst>(LHS) && + isa<Constant>(RHS) && + RHS->getType()->isPointerTy() && + cast<Constant>(RHS)->isNullValue()) { + LoadInst* LI = cast<LoadInst>(LHS); + if (isValidAssumeForContext(II, LI, DL, DT)) { + MDNode *MD = MDNode::get(II->getContext(), None); + LI->setMetadata(LLVMContext::MD_nonnull, MD); + return EraseInstFromFunction(*II); + } + } + // TODO: apply nonnull return attributes to calls and invokes + // TODO: apply range metadata for range check patterns? + } + // If there is a dominating assume with the same condition as this one, + // then this one is redundant, and should be removed. + APInt KnownZero(1, 0), KnownOne(1, 0); + computeKnownBits(IIOperand, KnownZero, KnownOne, 0, II); + if (KnownOne.isAllOnesValue()) + return EraseInstFromFunction(*II); + + break; + } + case Intrinsic::experimental_gc_relocate: { + // Translate facts known about a pointer before relocating into + // facts about the relocate value, while being careful to + // preserve relocation semantics. + GCRelocateOperands Operands(II); + Value *DerivedPtr = Operands.derivedPtr(); + + // Remove the relocation if unused, note that this check is required + // to prevent the cases below from looping forever. 
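Before the use_empty check below, a hedged sketch of the assume canonicalization added above:

declare void @llvm.assume(i1)

define void @split_assume(i1 %a, i1 %b) {
  %c = and i1 %a, %b
  call void @llvm.assume(i1 %c)
  ret void
}
; Expected to be split into two registered assumptions:
;   call void @llvm.assume(i1 %a)
;   call void @llvm.assume(i1 %b)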
+ if (II->use_empty()) + return EraseInstFromFunction(*II); + + // Undef is undef, even after relocation. + // TODO: provide a hook for this in GCStrategy. This is clearly legal for + // most practical collectors, but there was discussion in the review thread + // about whether it was legal for all possible collectors. + if (isa<UndefValue>(DerivedPtr)) + return ReplaceInstUsesWith(*II, DerivedPtr); + + // The relocation of null will be null for most any collector. + // TODO: provide a hook for this in GCStrategy. There might be some weird + // collector this property does not hold for. + if (isa<ConstantPointerNull>(DerivedPtr)) + return ReplaceInstUsesWith(*II, DerivedPtr); + + // isKnownNonNull -> nonnull attribute + if (isKnownNonNull(DerivedPtr)) + II->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull); + + // TODO: dereferenceable -> deref attribute + + // TODO: bitcast(relocate(p)) -> relocate(bitcast(p)) + // Canonicalize on the type from the uses to the defs + + // TODO: relocate((gep p, C, C2, ...)) -> gep(relocate(p), C, C2, ...) + } } return visitCallSite(II); @@ -1014,6 +1153,14 @@ static bool isSafeToEliminateVarargsCast(const CallSite CS, if (!CI->isLosslessCast()) return false; + // If this is a GC intrinsic, avoid munging types. We need types for + // statepoint reconstruction in SelectionDAG. + // TODO: This is probably something which should be expanded to all + // intrinsics since the entire point of intrinsics is that + // they are understandable by the optimizer. + if (isStatepoint(CS) || isGCRelocate(CS) || isGCResult(CS)) + return false; + // The size of ByVal or InAlloca arguments is derived from the type, so we // can't change to a type with a different size. If the size were // passed explicitly we could avoid this check. @@ -1246,14 +1393,14 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { if (NewRetTy->isStructTy()) return false; // TODO: Handle multiple return values. - if (!CastInst::isBitCastable(NewRetTy, OldRetTy)) { + if (!CastInst::isBitOrNoopPointerCastable(NewRetTy, OldRetTy, DL)) { if (Callee->isDeclaration()) return false; // Cannot transform this return value. if (!Caller->use_empty() && // void -> non-void is handled specially !NewRetTy->isVoidTy()) - return false; // Cannot transform this return value. + return false; // Cannot transform this return value. } if (!CallerPAL.isEmpty() && !Caller->use_empty()) { @@ -1281,12 +1428,21 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { unsigned NumActualArgs = CS.arg_size(); unsigned NumCommonArgs = std::min(FT->getNumParams(), NumActualArgs); + // Prevent us turning: + // declare void @takes_i32_inalloca(i32* inalloca) + // call void bitcast (void (i32*)* @takes_i32_inalloca to void (i32)*)(i32 0) + // + // into: + // call void @takes_i32_inalloca(i32* null) + if (Callee->getAttributes().hasAttrSomewhere(Attribute::InAlloca)) + return false; + CallSite::arg_iterator AI = CS.arg_begin(); for (unsigned i = 0, e = NumCommonArgs; i != e; ++i, ++AI) { Type *ParamTy = FT->getParamType(i); Type *ActTy = (*AI)->getType(); - if (!CastInst::isBitCastable(ActTy, ParamTy)) + if (!CastInst::isBitOrNoopPointerCastable(ActTy, ParamTy, DL)) return false; // Cannot transform this parameter value. if (AttrBuilder(CallerPAL.getParamAttributes(i + 1), i + 1). 
@@ -1381,7 +1537,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { if ((*AI)->getType() == ParamTy) { Args.push_back(*AI); } else { - Args.push_back(Builder->CreateBitCast(*AI, ParamTy)); + Args.push_back(Builder->CreateBitOrPointerCast(*AI, ParamTy)); } // Add any parameter attributes. @@ -1452,7 +1608,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { Value *NV = NC; if (OldRetTy != NV->getType() && !Caller->use_empty()) { if (!NV->getType()->isVoidTy()) { - NV = NC = CastInst::Create(CastInst::BitCast, NC, OldRetTy); + NV = NC = CastInst::CreateBitOrPointerCast(NC, OldRetTy); NC->setDebugLoc(Caller->getDebugLoc()); // If this is an invoke instruction, we should insert it after the first @@ -1472,8 +1628,14 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { if (!Caller->use_empty()) ReplaceInstUsesWith(*Caller, NV); - else if (Caller->hasValueHandle()) - ValueHandleBase::ValueIsRAUWd(Caller, NV); + else if (Caller->hasValueHandle()) { + if (OldRetTy == NV->getType()) + ValueHandleBase::ValueIsRAUWd(Caller, NV); + else + // We cannot call ValueIsRAUWd with a different type, and the + // actual tracked value will disappear. + ValueHandleBase::ValueIsDeleted(Caller); + } EraseInstFromFunction(*Caller); return true; diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp index b9c3d0f64718..54157268e9f6 100644 --- a/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -335,7 +335,8 @@ Instruction *InstCombiner::commonCastTransforms(CastInst &CI) { /// /// This function works on both vectors and scalars. /// -static bool CanEvaluateTruncated(Value *V, Type *Ty) { +static bool CanEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC, + Instruction *CxtI) { // We can always evaluate constants in another type. if (isa<Constant>(V)) return true; @@ -364,8 +365,8 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) { case Instruction::Or: case Instruction::Xor: // These operators can all arbitrarily be extended or truncated. 
- return CanEvaluateTruncated(I->getOperand(0), Ty) && - CanEvaluateTruncated(I->getOperand(1), Ty); + return CanEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && + CanEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); case Instruction::UDiv: case Instruction::URem: { @@ -374,10 +375,10 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) { uint32_t BitWidth = Ty->getScalarSizeInBits(); if (BitWidth < OrigBitWidth) { APInt Mask = APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth); - if (MaskedValueIsZero(I->getOperand(0), Mask) && - MaskedValueIsZero(I->getOperand(1), Mask)) { - return CanEvaluateTruncated(I->getOperand(0), Ty) && - CanEvaluateTruncated(I->getOperand(1), Ty); + if (IC.MaskedValueIsZero(I->getOperand(0), Mask, 0, CxtI) && + IC.MaskedValueIsZero(I->getOperand(1), Mask, 0, CxtI)) { + return CanEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && + CanEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); } } break; @@ -388,7 +389,7 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) { if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) { uint32_t BitWidth = Ty->getScalarSizeInBits(); if (CI->getLimitedValue(BitWidth) < BitWidth) - return CanEvaluateTruncated(I->getOperand(0), Ty); + return CanEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); } break; case Instruction::LShr: @@ -398,10 +399,10 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) { if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) { uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); uint32_t BitWidth = Ty->getScalarSizeInBits(); - if (MaskedValueIsZero(I->getOperand(0), - APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth)) && + if (IC.MaskedValueIsZero(I->getOperand(0), + APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth), 0, CxtI) && CI->getLimitedValue(BitWidth) < BitWidth) { - return CanEvaluateTruncated(I->getOperand(0), Ty); + return CanEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); } } break; @@ -415,8 +416,8 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) { return true; case Instruction::Select: { SelectInst *SI = cast<SelectInst>(I); - return CanEvaluateTruncated(SI->getTrueValue(), Ty) && - CanEvaluateTruncated(SI->getFalseValue(), Ty); + return CanEvaluateTruncated(SI->getTrueValue(), Ty, IC, CxtI) && + CanEvaluateTruncated(SI->getFalseValue(), Ty, IC, CxtI); } case Instruction::PHI: { // We can change a phi if we can change all operands. Note that we never @@ -424,7 +425,7 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) { // instructions with a single use. PHINode *PN = cast<PHINode>(I); for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - if (!CanEvaluateTruncated(PN->getIncomingValue(i), Ty)) + if (!CanEvaluateTruncated(PN->getIncomingValue(i), Ty, IC, CxtI)) return false; return true; } @@ -453,7 +454,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { // expression tree to something weird like i93 unless the source is also // strange. if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) && - CanEvaluateTruncated(Src, DestTy)) { + CanEvaluateTruncated(Src, DestTy, *this, &CI)) { // If this cast is a truncate, evaluting in a different type always // eliminates the cast, so it is always a win. 
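At this hunk boundary, a hedged example of the truncation evaluation being threaded through here: every node of the tree is acceptable to CanEvaluateTruncated, so the computation is redone in the narrow type and the casts disappear (names illustrative):

define i16 @trunc_tree(i16 %x) {
  %z = zext i16 %x to i32
  %a = and i32 %z, 255
  %t = trunc i32 %a to i16
  ret i16 %t
}
; The zext comes from the destination type and the constant retypes freely,
; so the whole function should reduce to: %t = and i16 %x, 255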
@@ -553,7 +554,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, // If Op1C some other power of two, convert: uint32_t BitWidth = Op1C->getType()->getBitWidth(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(ICI->getOperand(0), KnownZero, KnownOne); + computeKnownBits(ICI->getOperand(0), KnownZero, KnownOne, 0, &CI); APInt KnownZeroMask(~KnownZero); if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1? @@ -601,8 +602,8 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, APInt KnownZeroLHS(BitWidth, 0), KnownOneLHS(BitWidth, 0); APInt KnownZeroRHS(BitWidth, 0), KnownOneRHS(BitWidth, 0); - computeKnownBits(LHS, KnownZeroLHS, KnownOneLHS); - computeKnownBits(RHS, KnownZeroRHS, KnownOneRHS); + computeKnownBits(LHS, KnownZeroLHS, KnownOneLHS, 0, &CI); + computeKnownBits(RHS, KnownZeroRHS, KnownOneRHS, 0, &CI); if (KnownZeroLHS == KnownZeroRHS && KnownOneLHS == KnownOneRHS) { APInt KnownBits = KnownZeroLHS | KnownOneLHS; @@ -651,7 +652,8 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, /// clear the top bits anyway, doing this has no extra cost. /// /// This function works on both vectors and scalars. -static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) { +static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, + InstCombiner &IC, Instruction *CxtI) { BitsToClear = 0; if (isa<Constant>(V)) return true; @@ -680,8 +682,8 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) { case Instruction::Add: case Instruction::Sub: case Instruction::Mul: - if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear) || - !CanEvaluateZExtd(I->getOperand(1), Ty, Tmp)) + if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI) || + !CanEvaluateZExtd(I->getOperand(1), Ty, Tmp, IC, CxtI)) return false; // These can all be promoted if neither operand has 'bits to clear'. if (BitsToClear == 0 && Tmp == 0) @@ -695,8 +697,9 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) { // We use MaskedValueIsZero here for generality, but the case we care // about the most is constant RHS. unsigned VSize = V->getType()->getScalarSizeInBits(); - if (MaskedValueIsZero(I->getOperand(1), - APInt::getHighBitsSet(VSize, BitsToClear))) + if (IC.MaskedValueIsZero(I->getOperand(1), + APInt::getHighBitsSet(VSize, BitsToClear), + 0, CxtI)) return true; } @@ -707,7 +710,7 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) { // We can promote shl(x, cst) if we can promote x. Since shl overwrites the // upper bits we can reduce BitsToClear by the shift amount. if (ConstantInt *Amt = dyn_cast<ConstantInt>(I->getOperand(1))) { - if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear)) + if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI)) return false; uint64_t ShiftAmt = Amt->getZExtValue(); BitsToClear = ShiftAmt < BitsToClear ? BitsToClear - ShiftAmt : 0; @@ -718,7 +721,7 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) { // We can promote lshr(x, cst) if we can promote x. This requires the // ultimate 'and' to clear out the high zero bits we're clearing out though. 
if (ConstantInt *Amt = dyn_cast<ConstantInt>(I->getOperand(1))) { - if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear)) + if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI)) return false; BitsToClear += Amt->getZExtValue(); if (BitsToClear > V->getType()->getScalarSizeInBits()) @@ -728,8 +731,8 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) { // Cannot promote variable LSHR. return false; case Instruction::Select: - if (!CanEvaluateZExtd(I->getOperand(1), Ty, Tmp) || - !CanEvaluateZExtd(I->getOperand(2), Ty, BitsToClear) || + if (!CanEvaluateZExtd(I->getOperand(1), Ty, Tmp, IC, CxtI) || + !CanEvaluateZExtd(I->getOperand(2), Ty, BitsToClear, IC, CxtI) || // TODO: If important, we could handle the case when the BitsToClear are // known zero in the disagreeing side. Tmp != BitsToClear) @@ -741,10 +744,10 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) { // get into trouble with cyclic PHIs here because we only consider // instructions with a single use. PHINode *PN = cast<PHINode>(I); - if (!CanEvaluateZExtd(PN->getIncomingValue(0), Ty, BitsToClear)) + if (!CanEvaluateZExtd(PN->getIncomingValue(0), Ty, BitsToClear, IC, CxtI)) return false; for (unsigned i = 1, e = PN->getNumIncomingValues(); i != e; ++i) - if (!CanEvaluateZExtd(PN->getIncomingValue(i), Ty, Tmp) || + if (!CanEvaluateZExtd(PN->getIncomingValue(i), Ty, Tmp, IC, CxtI) || // TODO: If important, we could handle the case when the BitsToClear // are known zero in the disagreeing input. Tmp != BitsToClear) @@ -781,7 +784,7 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { // strange. unsigned BitsToClear; if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) && - CanEvaluateZExtd(Src, DestTy, BitsToClear)) { + CanEvaluateZExtd(Src, DestTy, BitsToClear, *this, &CI)) { assert(BitsToClear < SrcTy->getScalarSizeInBits() && "Unreasonable BitsToClear"); @@ -796,8 +799,10 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { // If the high bits are already filled with zeros, just replace this // cast with the result. - if (MaskedValueIsZero(Res, APInt::getHighBitsSet(DestBitSize, - DestBitSize-SrcBitsKept))) + if (MaskedValueIsZero(Res, + APInt::getHighBitsSet(DestBitSize, + DestBitSize-SrcBitsKept), + 0, &CI)) return ReplaceInstUsesWith(CI, Res); // We need to emit an AND to clear the high bits. @@ -895,6 +900,10 @@ Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) { Value *Op0 = ICI->getOperand(0), *Op1 = ICI->getOperand(1); ICmpInst::Predicate Pred = ICI->getPredicate(); + // Don't bother if Op1 isn't of vector or integer type. + if (!Op1->getType()->isIntOrIntVectorTy()) + return nullptr; + if (Constant *Op1C = dyn_cast<Constant>(Op1)) { // (x <s 0) ? -1 : 0 -> ashr x, 31 -> all ones if negative // (x >s -1) ? -1 : 0 -> not (ashr x, 31) -> all ones if positive @@ -921,7 +930,7 @@ Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) { ICI->isEquality() && (Op1C->isZero() || Op1C->getValue().isPowerOf2())){ unsigned BitWidth = Op1C->getType()->getBitWidth(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(Op0, KnownZero, KnownOne); + computeKnownBits(Op0, KnownZero, KnownOne, 0, &CI); APInt KnownZeroMask(~KnownZero); if (KnownZeroMask.isPowerOf2()) { @@ -1072,7 +1081,7 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { // If the high bits are already filled with sign bit, just replace this // cast with the result. 
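A hedged IR instance of the comment above, where the re-evaluated tree already carries enough sign bits (shift amount illustrative):

define i32 @sext_already_sign_extended(i32 %x) {
  %s = ashr i32 %x, 24         ; 25 sign bits
  %t = trunc i32 %s to i8
  %e = sext i8 %t to i32
  ret i32 %e
}
; Re-evaluating the tree in i32 yields %s; ComputeNumSignBits(%s) = 25 is
; greater than 32 - 8 = 24, so the sext should collapse to 'ret i32 %s'.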
- if (ComputeNumSignBits(Res) > DestBitSize - SrcBitSize) + if (ComputeNumSignBits(Res, 0, &CI) > DestBitSize - SrcBitSize) return ReplaceInstUsesWith(CI, Res); // We need to emit a shl + ashr to do the sign extend. @@ -1260,14 +1269,18 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { // type of OpI doesn't enter into things at all. We simply evaluate // in whichever source type is larger, then convert to the // destination type. + if (SrcWidth == OpWidth) + break; if (LHSWidth < SrcWidth) LHSOrig = Builder->CreateFPExt(LHSOrig, RHSOrig->getType()); else if (RHSWidth <= SrcWidth) RHSOrig = Builder->CreateFPExt(RHSOrig, LHSOrig->getType()); - Value *ExactResult = Builder->CreateFRem(LHSOrig, RHSOrig); - if (Instruction *RI = dyn_cast<Instruction>(ExactResult)) - RI->copyFastMathFlags(OpI); - return CastInst::CreateFPCast(ExactResult, CI.getType()); + if (LHSOrig != OpI->getOperand(0) || RHSOrig != OpI->getOperand(1)) { + Value *ExactResult = Builder->CreateFRem(LHSOrig, RHSOrig); + if (Instruction *RI = dyn_cast<Instruction>(ExactResult)) + RI->copyFastMathFlags(OpI); + return CastInst::CreateFPCast(ExactResult, CI.getType()); + } } // (fptrunc (fneg x)) -> (fneg (fptrunc x)) @@ -1312,42 +1325,6 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { } } - // Fold (fptrunc (sqrt (fpext x))) -> (sqrtf x) - // Note that we restrict this transformation based on - // TLI->has(LibFunc::sqrtf), even for the sqrt intrinsic, because - // TLI->has(LibFunc::sqrtf) is sufficient to guarantee that the - // single-precision intrinsic can be expanded in the backend. - CallInst *Call = dyn_cast<CallInst>(CI.getOperand(0)); - if (Call && Call->getCalledFunction() && TLI->has(LibFunc::sqrtf) && - (Call->getCalledFunction()->getName() == TLI->getName(LibFunc::sqrt) || - Call->getCalledFunction()->getIntrinsicID() == Intrinsic::sqrt) && - Call->getNumArgOperands() == 1 && - Call->hasOneUse()) { - CastInst *Arg = dyn_cast<CastInst>(Call->getArgOperand(0)); - if (Arg && Arg->getOpcode() == Instruction::FPExt && - CI.getType()->isFloatTy() && - Call->getType()->isDoubleTy() && - Arg->getType()->isDoubleTy() && - Arg->getOperand(0)->getType()->isFloatTy()) { - Function *Callee = Call->getCalledFunction(); - Module *M = CI.getParent()->getParent()->getParent(); - Constant *SqrtfFunc = (Callee->getIntrinsicID() == Intrinsic::sqrt) ? - Intrinsic::getDeclaration(M, Intrinsic::sqrt, Builder->getFloatTy()) : - M->getOrInsertFunction("sqrtf", Callee->getAttributes(), - Builder->getFloatTy(), Builder->getFloatTy(), - NULL); - CallInst *ret = CallInst::Create(SqrtfFunc, Arg->getOperand(0), - "sqrtfcall"); - ret->setAttributes(Callee->getAttributes()); - - - // Remove the old Call. With -fmath-errno, it won't get marked readnone. 
- ReplaceInstUsesWith(*Call, UndefValue::get(Call->getType())); - EraseInstFromFunction(*Call); - return ret; - } - } - return nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index 5e71c5c4b7cb..c07c96d49aab 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -12,6 +12,8 @@ //===----------------------------------------------------------------------===// #include "InstCombine.h" +#include "llvm/ADT/APSInt.h" +#include "llvm/ADT/Statistic.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryBuiltins.h" @@ -20,12 +22,20 @@ #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" #include "llvm/Target/TargetLibraryInfo.h" + using namespace llvm; using namespace PatternMatch; #define DEBUG_TYPE "instcombine" +// How many times is a select replaced by one of its operands? +STATISTIC(NumSel, "Number of select opts"); + +// Initialization Routines + static ConstantInt *getOne(Constant *C) { return ConstantInt::get(cast<IntegerType>(C->getType()), 1); } @@ -740,21 +750,6 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, Instruction *InstCombiner::FoldICmpAddOpCst(Instruction &ICI, Value *X, ConstantInt *CI, ICmpInst::Predicate Pred) { - // If we have X+0, exit early (simplifying logic below) and let it get folded - // elsewhere. icmp X+0, X -> icmp X, X - if (CI->isZero()) { - bool isTrue = ICmpInst::isTrueWhenEqual(Pred); - return ReplaceInstUsesWith(ICI, ConstantInt::get(ICI.getType(), isTrue)); - } - - // (X+4) == X -> false. - if (Pred == ICmpInst::ICMP_EQ) - return ReplaceInstUsesWith(ICI, Builder->getFalse()); - - // (X+4) != X -> true. - if (Pred == ICmpInst::ICMP_NE) - return ReplaceInstUsesWith(ICI, Builder->getTrue()); - // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0, // so the values can never be equal. Similarly for all other "or equals" // operators. @@ -1044,6 +1039,111 @@ Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr, return nullptr; } +/// FoldICmpCstShrCst - Handle "(icmp eq/ne (ashr/lshr const2, A), const1)" -> +/// (icmp eq/ne A, Log2(const2/const1)) -> +/// (icmp eq/ne A, Log2(const2) - Log2(const1)). +Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A, + ConstantInt *CI1, + ConstantInt *CI2) { + assert(I.isEquality() && "Cannot fold icmp gt/lt"); + + auto getConstant = [&I, this](bool IsTrue) { + if (I.getPredicate() == I.ICMP_NE) + IsTrue = !IsTrue; + return ReplaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); + }; + + auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { + if (I.getPredicate() == I.ICMP_NE) + Pred = CmpInst::getInversePredicate(Pred); + return new ICmpInst(Pred, LHS, RHS); + }; + + APInt AP1 = CI1->getValue(); + APInt AP2 = CI2->getValue(); + + // Don't bother doing any work for cases which InstSimplify handles. + if (AP2 == 0) + return nullptr; + bool IsAShr = isa<AShrOperator>(Op); + if (IsAShr) { + if (AP2.isAllOnesValue()) + return nullptr; + if (AP2.isNegative() != AP1.isNegative()) + return nullptr; + if (AP2.sgt(AP1)) + return nullptr; + } + + if (!AP1) + // 'A' must be large enough to shift out the highest set bit. 
+ return getICmp(I.ICMP_UGT, A, + ConstantInt::get(A->getType(), AP2.logBase2())); + + if (AP1 == AP2) + return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + + // Get the distance between the highest bit that's set. + int Shift; + // Both the constants are negative, take their positive to calculate log. + if (IsAShr && AP1.isNegative()) + // Get the ones' complement of AP2 and AP1 when computing the distance. + Shift = (~AP2).logBase2() - (~AP1).logBase2(); + else + Shift = AP2.logBase2() - AP1.logBase2(); + + if (Shift > 0) { + if (IsAShr ? AP1 == AP2.ashr(Shift) : AP1 == AP2.lshr(Shift)) + return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + } + // Shifting const2 will never be equal to const1. + return getConstant(false); +} + +/// FoldICmpCstShlCst - Handle "(icmp eq/ne (shl const2, A), const1)" -> +/// (icmp eq/ne A, TrailingZeros(const1) - TrailingZeros(const2)). +Instruction *InstCombiner::FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A, + ConstantInt *CI1, + ConstantInt *CI2) { + assert(I.isEquality() && "Cannot fold icmp gt/lt"); + + auto getConstant = [&I, this](bool IsTrue) { + if (I.getPredicate() == I.ICMP_NE) + IsTrue = !IsTrue; + return ReplaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); + }; + + auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { + if (I.getPredicate() == I.ICMP_NE) + Pred = CmpInst::getInversePredicate(Pred); + return new ICmpInst(Pred, LHS, RHS); + }; + + APInt AP1 = CI1->getValue(); + APInt AP2 = CI2->getValue(); + + // Don't bother doing any work for cases which InstSimplify handles. + if (AP2 == 0) + return nullptr; + + unsigned AP2TrailingZeros = AP2.countTrailingZeros(); + + if (!AP1 && AP2TrailingZeros != 0) + return getICmp(I.ICMP_UGE, A, + ConstantInt::get(A->getType(), AP2.getBitWidth() - AP2TrailingZeros)); + + if (AP1 == AP2) + return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + + // Get the distance between the lowest bits that are set. + int Shift = AP1.countTrailingZeros() - AP2TrailingZeros; + + if (Shift > 0 && AP2.shl(Shift) == AP1) + return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + + // Shifting const2 will never be equal to const1. + return getConstant(false); +} /// visitICmpInstWithInstAndIntCst - Handle "icmp (instr, intcst)". /// @@ -1060,7 +1160,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, unsigned DstBits = LHSI->getType()->getPrimitiveSizeInBits(), SrcBits = LHSI->getOperand(0)->getType()->getPrimitiveSizeInBits(); APInt KnownZero(SrcBits, 0), KnownOne(SrcBits, 0); - computeKnownBits(LHSI->getOperand(0), KnownZero, KnownOne); + computeKnownBits(LHSI->getOperand(0), KnownZero, KnownOne, 0, &ICI); // If all the high bits are known, we can do this xform. 
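Backing up to FoldICmpCstShrCst above, a hedged worked example (constants illustrative):

define i1 @shr_cst_cmp(i32 %a) {
  %s = lshr i32 128, %a
  %c = icmp eq i32 %s, 32
  ret i1 %c
}
; Shift = log2(128) - log2(32) = 2, and 128 >> 2 == 32, so this should
; become: %c = icmp eq i32 %a, 2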
if ((KnownZero|KnownOne).countLeadingOnes() >= SrcBits-DstBits) { @@ -1282,6 +1382,48 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, return &ICI; } + // (icmp pred (and (or (lshr X, Y), X), 1), 0) --> + // (icmp pred (and X, (or (shl 1, Y), 1), 0)) + // + // iff pred isn't signed + { + Value *X, *Y, *LShr; + if (!ICI.isSigned() && RHSV == 0) { + if (match(LHSI->getOperand(1), m_One())) { + Constant *One = cast<Constant>(LHSI->getOperand(1)); + Value *Or = LHSI->getOperand(0); + if (match(Or, m_Or(m_Value(LShr), m_Value(X))) && + match(LShr, m_LShr(m_Specific(X), m_Value(Y)))) { + unsigned UsesRemoved = 0; + if (LHSI->hasOneUse()) + ++UsesRemoved; + if (Or->hasOneUse()) + ++UsesRemoved; + if (LShr->hasOneUse()) + ++UsesRemoved; + Value *NewOr = nullptr; + // Compute X & ((1 << Y) | 1) + if (auto *C = dyn_cast<Constant>(Y)) { + if (UsesRemoved >= 1) + NewOr = + ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One); + } else { + if (UsesRemoved >= 3) + NewOr = Builder->CreateOr(Builder->CreateShl(One, Y, + LShr->getName(), + /*HasNUW=*/true), + One, Or->getName()); + } + if (NewOr) { + Value *NewAnd = Builder->CreateAnd(X, NewOr, LHSI->getName()); + ICI.setOperand(0, NewAnd); + return &ICI; + } + } + } + } + } + // Replace ((X & AndCst) > RHSV) with ((X & AndCst) != 0), if any // bit set in (X & AndCst) will produce a result greater than RHSV. if (ICI.getPredicate() == ICmpInst::ICMP_UGT) { @@ -1377,16 +1519,10 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, unsigned RHSLog2 = RHSV.logBase2(); // (1 << X) >= 2147483648 -> X >= 31 -> X == 31 - // (1 << X) > 2147483648 -> X > 31 -> false - // (1 << X) <= 2147483648 -> X <= 31 -> true // (1 << X) < 2147483648 -> X < 31 -> X != 31 if (RHSLog2 == TypeBits-1) { if (Pred == ICmpInst::ICMP_UGE) Pred = ICmpInst::ICMP_EQ; - else if (Pred == ICmpInst::ICMP_UGT) - return ReplaceInstUsesWith(ICI, Builder->getFalse()); - else if (Pred == ICmpInst::ICMP_ULE) - return ReplaceInstUsesWith(ICI, Builder->getTrue()); else if (Pred == ICmpInst::ICMP_ULT) Pred = ICmpInst::ICMP_NE; } @@ -1421,10 +1557,6 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, if (RHSVIsPowerOf2) return new ICmpInst( Pred, X, ConstantInt::get(RHS->getType(), RHSV.logBase2())); - - return ReplaceInstUsesWith( - ICI, Pred == ICmpInst::ICMP_EQ ? Builder->getFalse() - : Builder->getTrue()); } } break; @@ -1932,8 +2064,8 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, // sign-extended; check for that condition. For example, if CI2 is 2^31 and // the operands of the add are 64 bits wide, we need at least 33 sign bits. 
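Returning to the (lshr, or, and) fold added earlier in this hunk, a hedged sketch; with all three intermediates single-use, UsesRemoved reaches 3 and the variable-shift form is built:

define i1 @test_bit_or_lsb(i32 %x, i32 %y) {
  %sh = lshr i32 %x, %y
  %or = or i32 %sh, %x
  %a  = and i32 %or, 1
  %c  = icmp eq i32 %a, 0
  ret i1 %c
}
; Expected rewrite:
;   %shl = shl nuw i32 1, %y
;   %m   = or i32 %shl, 1
;   %a   = and i32 %x, %m
;   %c   = icmp eq i32 %a, 0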
unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1; - if (IC.ComputeNumSignBits(A) < NeededSignBits || - IC.ComputeNumSignBits(B) < NeededSignBits) + if (IC.ComputeNumSignBits(A, 0, &I) < NeededSignBits || + IC.ComputeNumSignBits(B, 0, &I) < NeededSignBits) return nullptr; // In order to replace the original add with a narrower @@ -2038,8 +2170,8 @@ static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal, Instruction *MulInstr = cast<Instruction>(MulVal); assert(MulInstr->getOpcode() == Instruction::Mul); - Instruction *LHS = cast<Instruction>(MulInstr->getOperand(0)), - *RHS = cast<Instruction>(MulInstr->getOperand(1)); + auto *LHS = cast<ZExtOperator>(MulInstr->getOperand(0)), + *RHS = cast<ZExtOperator>(MulInstr->getOperand(1)); assert(LHS->getOpcode() == Instruction::ZExt); assert(RHS->getOpcode() == Instruction::ZExt); Value *A = LHS->getOperand(0), *B = RHS->getOperand(0); @@ -2324,6 +2456,122 @@ static bool swapMayExposeCSEOpportunities(const Value * Op0, return GlobalSwapBenefits > 0; } +/// \brief Check that one use is in the same block as the definition and all +/// other uses are in blocks dominated by a given block +/// +/// \param DI Definition +/// \param UI Use +/// \param DB Block that must dominate all uses of \p DI outside +/// the parent block +/// \return true when \p UI is the only use of \p DI in the parent block +/// and all other uses of \p DI are in blocks dominated by \p DB. +/// +bool InstCombiner::dominatesAllUses(const Instruction *DI, + const Instruction *UI, + const BasicBlock *DB) const { + assert(DI && UI && "Instruction not defined\n"); + // ignore incomplete definitions + if (!DI->getParent()) + return false; + // DI and UI must be in the same block + if (DI->getParent() != UI->getParent()) + return false; + // Protect from self-referencing blocks + if (DI->getParent() == DB) + return false; + // DominatorTree available? + if (!DT) + return false; + for (const User *U : DI->users()) { + auto *Usr = cast<Instruction>(U); + if (Usr != UI && !DT->dominates(DB, Usr->getParent())) + return false; + } + return true; +} + +/// +/// true when the instruction sequence within a block is select-cmp-br. +/// +static bool isChainSelectCmpBranch(const SelectInst *SI) { + const BasicBlock *BB = SI->getParent(); + if (!BB) + return false; + auto *BI = dyn_cast_or_null<BranchInst>(BB->getTerminator()); + if (!BI || BI->getNumSuccessors() != 2) + return false; + auto *IC = dyn_cast<ICmpInst>(BI->getCondition()); + if (!IC || (IC->getOperand(0) != SI && IC->getOperand(1) != SI)) + return false; + return true; +} + +/// +/// \brief True when a select result is replaced by one of its operands +/// in select-icmp sequence. This will eventually result in the elimination +/// of the select. +/// +/// \param SI Select instruction +/// \param Icmp Compare instruction +/// \param SIOpd Operand that replaces the select +/// +/// Notes: +/// - The replacement is global and requires dominator information +/// - The caller is responsible for the actual replacement +/// +/// Example: +/// +/// entry: +/// %4 = select i1 %3, %C* %0, %C* null +/// %5 = icmp eq %C* %4, null +/// br i1 %5, label %9, label %7 +/// ... +/// ; <label>:7 ; preds = %entry +/// %8 = getelementptr inbounds %C* %4, i64 0, i32 0 +/// ... +/// +/// can be transformed to +/// +/// %5 = icmp eq %C* %0, null +/// %6 = select i1 %3, i1 %5, i1 true +/// br i1 %6, label %9, label %7 +/// ... +/// ; <label>:7 ; preds = %entry +/// %8 = getelementptr inbounds %C* %0, i64 0, i32 0 // replace by %0! 
+/// +/// Similar when the first operand of the select is a constant or/and +/// the compare is for not equal rather than equal. +/// +/// NOTE: The function is only called when the select and compare constants +/// are equal, the optimization can work only for EQ predicates. This is not a +/// major restriction since a NE compare should be 'normalized' to an equal +/// compare, which usually happens in the combiner and test case +/// select-cmp-br.ll +/// checks for it. +bool InstCombiner::replacedSelectWithOperand(SelectInst *SI, + const ICmpInst *Icmp, + const unsigned SIOpd) { + assert((SIOpd == 1 || SIOpd == 2) && "Invalid select operand!"); + if (isChainSelectCmpBranch(SI) && Icmp->getPredicate() == ICmpInst::ICMP_EQ) { + BasicBlock *Succ = SI->getParent()->getTerminator()->getSuccessor(1); + // The check for the unique predecessor is not the best that can be + // done. But it protects efficiently against cases like when SI's + // home block has two successors, Succ and Succ1, and Succ1 predecessor + // of Succ. Then SI can't be replaced by SIOpd because the use that gets + // replaced can be reached on either path. So the uniqueness check + // guarantees that the path all uses of SI (outside SI's parent) are on + // is disjoint from all other paths out of SI. But that information + // is more expensive to compute, and the trade-off here is in favor + // of compile-time. + if (Succ->getUniquePredecessor() && dominatesAllUses(SI, Icmp, Succ)) { + NumSel++; + SI->replaceUsesOutsideBlock(SI->getOperand(SIOpd), SI->getParent()); + return true; + } + } + return false; +} + Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { bool Changed = false; Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -2341,7 +2589,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Changed = true; } - if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL)) + if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // comparing -val or val with non-zero is the same as just comparing val @@ -2438,11 +2686,33 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { return Res; } - // (icmp ne/eq (sub A B) 0) -> (icmp ne/eq A, B) - if (I.isEquality() && CI->isZero() && - match(Op0, m_Sub(m_Value(A), m_Value(B)))) { - // (icmp cond A B) if cond is equality - return new ICmpInst(I.getPredicate(), A, B); + // The following transforms are only 'worth it' if the only user of the + // subtraction is the icmp. 
+ if (Op0->hasOneUse()) { + // (icmp ne/eq (sub A B) 0) -> (icmp ne/eq A, B) + if (I.isEquality() && CI->isZero() && + match(Op0, m_Sub(m_Value(A), m_Value(B)))) + return new ICmpInst(I.getPredicate(), A, B); + + // (icmp sgt (sub nsw A B), -1) -> (icmp sge A, B) + if (I.getPredicate() == ICmpInst::ICMP_SGT && CI->isAllOnesValue() && + match(Op0, m_NSWSub(m_Value(A), m_Value(B)))) + return new ICmpInst(ICmpInst::ICMP_SGE, A, B); + + // (icmp sgt (sub nsw A B), 0) -> (icmp sgt A, B) + if (I.getPredicate() == ICmpInst::ICMP_SGT && CI->isZero() && + match(Op0, m_NSWSub(m_Value(A), m_Value(B)))) + return new ICmpInst(ICmpInst::ICMP_SGT, A, B); + + // (icmp slt (sub nsw A B), 0) -> (icmp slt A, B) + if (I.getPredicate() == ICmpInst::ICMP_SLT && CI->isZero() && + match(Op0, m_NSWSub(m_Value(A), m_Value(B)))) + return new ICmpInst(ICmpInst::ICMP_SLT, A, B); + + // (icmp slt (sub nsw A B), 1) -> (icmp sle A, B) + if (I.getPredicate() == ICmpInst::ICMP_SLT && CI->isOne() && + match(Op0, m_NSWSub(m_Value(A), m_Value(B)))) + return new ICmpInst(ICmpInst::ICMP_SLE, A, B); } // If we have an icmp le or icmp ge instruction, turn it into the @@ -2469,6 +2739,21 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Builder->getInt(CI->getValue()-1)); } + if (I.isEquality()) { + ConstantInt *CI2; + if (match(Op0, m_AShr(m_ConstantInt(CI2), m_Value(A))) || + match(Op0, m_LShr(m_ConstantInt(CI2), m_Value(A)))) { + // (icmp eq/ne (ashr/lshr const2, A), const1) + if (Instruction *Inst = FoldICmpCstShrCst(I, Op0, A, CI, CI2)) + return Inst; + } + if (match(Op0, m_Shl(m_ConstantInt(CI2), m_Value(A)))) { + // (icmp eq/ne (shl const2, A), const1) + if (Instruction *Inst = FoldICmpCstShlCst(I, Op0, A, CI, CI2)) + return Inst; + } + } + // If this comparison is a normal comparison, it demands all // bits, if it is a sign bit comparison, it only demands the sign bit. bool UnusedBit; @@ -2761,18 +3046,39 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // comparison into the select arms, which will cause one to be // constant folded and the select turned into a bitwise or. Value *Op1 = nullptr, *Op2 = nullptr; - if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(1))) + ConstantInt *CI = 0; + if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(1))) { Op1 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); - if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(2))) + CI = dyn_cast<ConstantInt>(Op1); + } + if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(2))) { Op2 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); + CI = dyn_cast<ConstantInt>(Op2); + } // We only want to perform this transformation if it will not lead to // additional code. This is true if either both sides of the select // fold to a constant (in which case the icmp is replaced with a select // which will usually simplify) or this is the only user of the // select (in which case we are trading a select+icmp for a simpler - // select+icmp). - if ((Op1 && Op2) || (LHSI->hasOneUse() && (Op1 || Op2))) { + // select+icmp) or all uses of the select can be replaced based on + // dominance information ("Global cases"). + bool Transform = false; + if (Op1 && Op2) + Transform = true; + else if (Op1 || Op2) { + // Local case + if (LHSI->hasOneUse()) + Transform = true; + // Global cases + else if (CI && !CI->isZero()) + // When Op1 is constant try replacing select with second operand. + // Otherwise Op2 is constant and try replacing select with first + // operand. 
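The nsw-sub compare folds added above are plain integer identities once the subtraction is known not to wrap; the hasOneUse test is only a profitability heuristic. A brute-force i8 check, skipping exactly the pairs where nsw would not hold:

    #include <cassert>

    int main() {
      for (int a = -128; a <= 127; ++a)
        for (int b = -128; b <= 127; ++b) {
          int d = a - b;
          if (d < -128 || d > 127)
            continue; // the i8 sub would wrap, so nsw would not hold
          assert((d > -1) == (a >= b)); // (icmp sgt (sub nsw A B), -1) -> sge A, B
          assert((d > 0) == (a > b));   // (icmp sgt (sub nsw A B), 0)  -> sgt A, B
          assert((d < 0) == (a < b));   // (icmp slt (sub nsw A B), 0)  -> slt A, B
          assert((d < 1) == (a <= b));  // (icmp slt (sub nsw A B), 1)  -> sle A, B
        }
      return 0;
    }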
+ Transform = replacedSelectWithOperand(cast<SelectInst>(LHSI), &I, + Op1 ? 2 : 1); + } + if (Transform) { if (!Op1) Op1 = Builder->CreateICmp(I.getPredicate(), LHSI->getOperand(1), RHSC, I.getName()); @@ -2878,6 +3184,12 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (BO1 && BO1->getOpcode() == Instruction::Add) C = BO1->getOperand(0), D = BO1->getOperand(1); + // icmp (X+cst) < 0 --> X < -cst + if (NoOp0WrapProblem && ICmpInst::isSigned(Pred) && match(Op1, m_Zero())) + if (ConstantInt *RHSC = dyn_cast_or_null<ConstantInt>(B)) + if (!RHSC->isMinValue(/*isSigned=*/true)) + return new ICmpInst(Pred, A, ConstantExpr::getNeg(RHSC)); + // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow. if ((A == Op1 || B == Op1) && NoOp0WrapProblem) return new ICmpInst(Pred, A == Op1 ? B : A, @@ -3112,7 +3424,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // and (A & ~B) != 0 --> (A & B) == 0 // if A is a power of 2. if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) && - match(Op1, m_Zero()) && isKnownToBeAPowerOfTwo(A) && I.isEquality()) + match(Op1, m_Zero()) && + isKnownToBeAPowerOfTwo(A, false, 0, AC, &I, DT) && I.isEquality()) return new ICmpInst(I.getInversePredicate(), Builder->CreateAnd(A, B), Op1); @@ -3273,6 +3586,22 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { } } + // The 'cmpxchg' instruction returns an aggregate containing the old value and + // an i1 which indicates whether or not we successfully did the swap. + // + // Replace comparisons between the old value and the expected value with the + // indicator that 'cmpxchg' returns. + // + // N.B. This transform is only valid when the 'cmpxchg' is not permitted to + // spuriously fail. In those cases, the old value may equal the expected + // value but it is possible for the swap to not occur. + if (I.getPredicate() == ICmpInst::ICMP_EQ) + if (auto *EVI = dyn_cast<ExtractValueInst>(Op0)) + if (auto *ACXI = dyn_cast<AtomicCmpXchgInst>(EVI->getAggregateOperand())) + if (EVI->getIndices()[0] == 0 && ACXI->getCompareOperand() == Op1 && + !ACXI->isWeak()) + return ExtractValueInst::Create(ACXI, 1); + { Value *X; ConstantInt *Cst; // icmp X+Cst, X @@ -3287,7 +3616,6 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { } /// FoldFCmp_IntToFP_Cst - Fold fcmp ([us]itofp x, cst) if possible. -/// Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, Instruction *LHSI, Constant *RHSC) { @@ -3299,18 +3627,49 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, int MantissaWidth = LHSI->getType()->getFPMantissaWidth(); if (MantissaWidth == -1) return nullptr; // Unknown. + IntegerType *IntTy = cast<IntegerType>(LHSI->getOperand(0)->getType()); + // Check to see that the input is converted from an integer type that is small // enough that preserves all bits. TODO: check here for "known" sign bits. // This would allow us to handle (fptosi (x >>s 62) to float) if x is i64 f.e. - unsigned InputSize = LHSI->getOperand(0)->getType()->getScalarSizeInBits(); + unsigned InputSize = IntTy->getScalarSizeInBits(); // If this is a uitofp instruction, we need an extra bit to hold the sign. 
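Among the additions above is the fold icmp (X+cst) < 0 --> X < -cst, valid for signed predicates when the add cannot wrap and cst is not the minimum signed value (so -cst exists). An exhaustive i8 check, where excluding -128 mirrors the isMinValue guard:

    #include <cassert>

    int main() {
      for (int x = -128; x <= 127; ++x)
        for (int c = -127; c <= 127; ++c) { // c == -128 excluded: -c would wrap
          int sum = x + c;
          if (sum < -128 || sum > 127)
            continue; // the add is not nsw; the fold must not fire
          assert((sum < 0) == (x < -c));   // icmp slt (X+c), 0 --> icmp slt X, -c
          assert((sum >= 0) == (x >= -c)); // likewise for the other signed preds
        }
      return 0;
    }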
bool LHSUnsigned = isa<UIToFPInst>(LHSI); if (LHSUnsigned) ++InputSize; + if (I.isEquality()) { + FCmpInst::Predicate P = I.getPredicate(); + bool IsExact = false; + APSInt RHSCvt(IntTy->getBitWidth(), LHSUnsigned); + RHS.convertToInteger(RHSCvt, APFloat::rmNearestTiesToEven, &IsExact); + + // If the floating point constant isn't an integer value, we know if we will + // ever compare equal / not equal to it. + if (!IsExact) { + // TODO: Can never be -0.0 and other non-representable values + APFloat RHSRoundInt(RHS); + RHSRoundInt.roundToIntegral(APFloat::rmNearestTiesToEven); + if (RHS.compare(RHSRoundInt) != APFloat::cmpEqual) { + if (P == FCmpInst::FCMP_OEQ || P == FCmpInst::FCMP_UEQ) + return ReplaceInstUsesWith(I, Builder->getFalse()); + + assert(P == FCmpInst::FCMP_ONE || P == FCmpInst::FCMP_UNE); + return ReplaceInstUsesWith(I, Builder->getTrue()); + } + } + + // TODO: If the constant is exactly representable, is it always OK to do + // equality compares as integer? + } + + // Comparisons with zero are a special case where we know we won't lose + // information. + bool IsCmpZero = RHS.isPosZero(); + // If the conversion would lose info, don't hack on this. - if ((int)InputSize > MantissaWidth) + if ((int)InputSize > MantissaWidth && !IsCmpZero) return nullptr; // Otherwise, we can potentially simplify the comparison. We know that it @@ -3351,8 +3710,6 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, return ReplaceInstUsesWith(I, Builder->getFalse()); } - IntegerType *IntTy = cast<IntegerType>(LHSI->getOperand(0)->getType()); - // Now we know that the APFloat is a normal number, zero or inf. // See if the FP constant is too large for the integer. For example, @@ -3502,7 +3859,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1, DL)) + if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // Simplify 'fcmp pred X, X' @@ -3605,40 +3962,42 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { } break; case Instruction::Call: { + if (!RHSC->isNullValue()) + break; + CallInst *CI = cast<CallInst>(LHSI); - LibFunc::Func Func; + const Function *F = CI->getCalledFunction(); + if (!F) + break; + // Various optimization for fabs compared with zero. 
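The new equality path above exploits that a [su]itofp result is always integral, so comparing it for equality against a non-integral constant has a known answer. The same observation in plain C++, as an illustration only:

    #include <cassert>

    int main() {
      // 0.5 is not an integer, so sitofp of any i32 can never compare oeq to
      // it; the code above folds oeq to false and une to true once
      // roundToIntegral detects the mismatch.
      for (int i = -1000; i <= 1000; ++i) {
        double converted = static_cast<double>(i); // sitofp i32 -> double
        assert(!(converted == 0.5));
        assert(converted != 0.5);
      }
      // An integral constant such as 3.0 is not foldable this way.
      assert(static_cast<double>(3) == 3.0);
      return 0;
    }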
- if (RHSC->isNullValue() && CI->getCalledFunction() && - TLI->getLibFunc(CI->getCalledFunction()->getName(), Func) && - TLI->has(Func)) { - if (Func == LibFunc::fabs || Func == LibFunc::fabsf || - Func == LibFunc::fabsl) { - switch (I.getPredicate()) { - default: break; + LibFunc::Func Func; + if (F->getIntrinsicID() == Intrinsic::fabs || + (TLI->getLibFunc(F->getName(), Func) && TLI->has(Func) && + (Func == LibFunc::fabs || Func == LibFunc::fabsf || + Func == LibFunc::fabsl))) { + switch (I.getPredicate()) { + default: + break; // fabs(x) < 0 --> false - case FCmpInst::FCMP_OLT: - return ReplaceInstUsesWith(I, Builder->getFalse()); + case FCmpInst::FCMP_OLT: + return ReplaceInstUsesWith(I, Builder->getFalse()); // fabs(x) > 0 --> x != 0 - case FCmpInst::FCMP_OGT: - return new FCmpInst(FCmpInst::FCMP_ONE, CI->getArgOperand(0), - RHSC); + case FCmpInst::FCMP_OGT: + return new FCmpInst(FCmpInst::FCMP_ONE, CI->getArgOperand(0), RHSC); // fabs(x) <= 0 --> x == 0 - case FCmpInst::FCMP_OLE: - return new FCmpInst(FCmpInst::FCMP_OEQ, CI->getArgOperand(0), - RHSC); + case FCmpInst::FCMP_OLE: + return new FCmpInst(FCmpInst::FCMP_OEQ, CI->getArgOperand(0), RHSC); // fabs(x) >= 0 --> !isnan(x) - case FCmpInst::FCMP_OGE: - return new FCmpInst(FCmpInst::FCMP_ORD, CI->getArgOperand(0), - RHSC); + case FCmpInst::FCMP_OGE: + return new FCmpInst(FCmpInst::FCMP_ORD, CI->getArgOperand(0), RHSC); // fabs(x) == 0 --> x == 0 // fabs(x) != 0 --> x != 0 - case FCmpInst::FCMP_OEQ: - case FCmpInst::FCMP_UEQ: - case FCmpInst::FCMP_ONE: - case FCmpInst::FCMP_UNE: - return new FCmpInst(I.getPredicate(), CI->getArgOperand(0), - RHSC); - } + case FCmpInst::FCMP_OEQ: + case FCmpInst::FCMP_UEQ: + case FCmpInst::FCMP_ONE: + case FCmpInst::FCMP_UNE: + return new FCmpInst(I.getPredicate(), CI->getArgOperand(0), RHSC); } } } diff --git a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index e9c25d32c281..af1694d3453c 100644 --- a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/Loads.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" @@ -267,8 +268,8 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { // is only subsequently read. SmallVector<Instruction *, 4> ToDelete; if (MemTransferInst *Copy = isOnlyCopiedFromConstantGlobal(&AI, ToDelete)) { - unsigned SourceAlign = getOrEnforceKnownAlignment(Copy->getSource(), - AI.getAlignment(), DL); + unsigned SourceAlign = getOrEnforceKnownAlignment( + Copy->getSource(), AI.getAlignment(), DL, AC, &AI, DT); if (AI.getAlignment() <= SourceAlign) { DEBUG(dbgs() << "Found alloca equal to global: " << AI << '\n'); DEBUG(dbgs() << " memcpy = " << *Copy << '\n'); @@ -290,80 +291,111 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { return visitAllocSite(AI); } +/// \brief Helper to combine a load to a new type. +/// +/// This just does the work of combining a load to a new type. It handles +/// metadata, etc., and returns the new instruction. The \c NewTy should be the +/// loaded *value* type. This will convert it to a pointer, cast the operand to +/// that pointer type, load it, etc. 
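The reshuffled fabs block above now also matches the fabs intrinsic, but the folds themselves are unchanged IEEE facts. A standalone check (C++'s != is an unordered compare, so NaN is handled separately):

    #include <cassert>
    #include <cmath>

    int main() {
      // Non-NaN inputs: C++'s <, <=, >, >= are LLVM's ordered predicates here.
      const double vals[] = {0.0, -0.0, 1.5, -1.5, INFINITY, -INFINITY};
      for (double x : vals) {
        assert((std::fabs(x) > 0.0) == (x != 0.0));      // fabs(x) > 0  --> x != 0
        assert((std::fabs(x) <= 0.0) == (x == 0.0));     // fabs(x) <= 0 --> x == 0
        assert((std::fabs(x) >= 0.0) == !std::isnan(x)); // fabs(x) >= 0 --> !isnan(x)
        assert(!(std::fabs(x) < 0.0));                   // fabs(x) < 0  --> false
      }
      // The NaN case only matters for the >= fold, where both sides are false.
      assert((std::fabs(NAN) >= 0.0) == !std::isnan(NAN));
      return 0;
    }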
+/// +/// Note that this will create all of the instructions with whatever insert +/// point the \c InstCombiner currently is using. +static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewTy) { + Value *Ptr = LI.getPointerOperand(); + unsigned AS = LI.getPointerAddressSpace(); + SmallVector<std::pair<unsigned, MDNode *>, 8> MD; + LI.getAllMetadata(MD); + + LoadInst *NewLoad = IC.Builder->CreateAlignedLoad( + IC.Builder->CreateBitCast(Ptr, NewTy->getPointerTo(AS)), + LI.getAlignment(), LI.getName()); + for (const auto &MDPair : MD) { + unsigned ID = MDPair.first; + MDNode *N = MDPair.second; + // Note, essentially every kind of metadata should be preserved here! This + // routine is supposed to clone a load instruction changing *only its type*. + // The only metadata it makes sense to drop is metadata which is invalidated + // when the pointer type changes. This should essentially never be the case + // in LLVM, but we explicitly switch over only known metadata to be + // conservatively correct. If you are adding metadata to LLVM which pertains + // to loads, you almost certainly want to add it here. + switch (ID) { + case LLVMContext::MD_dbg: + case LLVMContext::MD_tbaa: + case LLVMContext::MD_prof: + case LLVMContext::MD_fpmath: + case LLVMContext::MD_tbaa_struct: + case LLVMContext::MD_invariant_load: + case LLVMContext::MD_alias_scope: + case LLVMContext::MD_noalias: + case LLVMContext::MD_nontemporal: + case LLVMContext::MD_mem_parallel_loop_access: + case LLVMContext::MD_nonnull: + // All of these directly apply. + NewLoad->setMetadata(ID, N); + break; -/// InstCombineLoadCast - Fold 'load (cast P)' -> cast (load P)' when possible. -static Instruction *InstCombineLoadCast(InstCombiner &IC, LoadInst &LI, - const DataLayout *DL) { - User *CI = cast<User>(LI.getOperand(0)); - Value *CastOp = CI->getOperand(0); - - PointerType *DestTy = cast<PointerType>(CI->getType()); - Type *DestPTy = DestTy->getElementType(); - if (PointerType *SrcTy = dyn_cast<PointerType>(CastOp->getType())) { - - // If the address spaces don't match, don't eliminate the cast. - if (DestTy->getAddressSpace() != SrcTy->getAddressSpace()) - return nullptr; - - Type *SrcPTy = SrcTy->getElementType(); - - if (DestPTy->isIntegerTy() || DestPTy->isPointerTy() || - DestPTy->isVectorTy()) { - // If the source is an array, the code below will not succeed. Check to - // see if a trivial 'gep P, 0, 0' will help matters. Only do this for - // constants. - if (ArrayType *ASrcTy = dyn_cast<ArrayType>(SrcPTy)) - if (Constant *CSrc = dyn_cast<Constant>(CastOp)) - if (ASrcTy->getNumElements() != 0) { - Type *IdxTy = DL - ? DL->getIntPtrType(SrcTy) - : Type::getInt64Ty(SrcTy->getContext()); - Value *Idx = Constant::getNullValue(IdxTy); - Value *Idxs[2] = { Idx, Idx }; - CastOp = ConstantExpr::getGetElementPtr(CSrc, Idxs); - SrcTy = cast<PointerType>(CastOp->getType()); - SrcPTy = SrcTy->getElementType(); - } - - if (IC.getDataLayout() && - (SrcPTy->isIntegerTy() || SrcPTy->isPointerTy() || - SrcPTy->isVectorTy()) && - // Do not allow turning this into a load of an integer, which is then - // casted to a pointer, this pessimizes pointer analysis a lot. - (SrcPTy->isPtrOrPtrVectorTy() == - LI.getType()->isPtrOrPtrVectorTy()) && - IC.getDataLayout()->getTypeSizeInBits(SrcPTy) == - IC.getDataLayout()->getTypeSizeInBits(DestPTy)) { - - // Okay, we are casting from one integer or pointer type to another of - // the same size. Instead of casting the pointer before the load, cast - // the result of the loaded value. 
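combineLoadToNewType, completed above, changes only the type of a load, never its width, which is what makes the combine a semantic no-op. A memcpy-based analogy (deliberately not the LLVM API) of the before and after shapes:

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    // Before: load the memory as i32, then bitcast the *value* to float.
    static float loadIntThenBitcast(const void *p) {
      uint32_t bits;
      std::memcpy(&bits, p, sizeof bits);
      float f;
      std::memcpy(&f, &bits, sizeof f); // value-level bitcast
      return f;
    }

    // After: bitcast the *pointer* and load a float directly. Same width,
    // same bytes, so nothing observable changes.
    static float loadFloatDirectly(const void *p) {
      float f;
      std::memcpy(&f, p, sizeof f);
      return f;
    }

    int main() {
      float stored = 3.25f;
      assert(loadIntThenBitcast(&stored) == loadFloatDirectly(&stored));
      return 0;
    }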
- LoadInst *NewLoad = - IC.Builder->CreateLoad(CastOp, LI.isVolatile(), CI->getName()); - NewLoad->setAlignment(LI.getAlignment()); - NewLoad->setAtomic(LI.getOrdering(), LI.getSynchScope()); - // Now cast the result of the load. - PointerType *OldTy = dyn_cast<PointerType>(NewLoad->getType()); - PointerType *NewTy = dyn_cast<PointerType>(LI.getType()); - if (OldTy && NewTy && - OldTy->getAddressSpace() != NewTy->getAddressSpace()) { - return new AddrSpaceCastInst(NewLoad, LI.getType()); - } - - return new BitCastInst(NewLoad, LI.getType()); - } + case LLVMContext::MD_range: + // FIXME: It would be nice to propagate this in some way, but the type + // conversions make it hard. + break; } } + return NewLoad; +} + +/// \brief Combine loads to match the type of value their uses after looking +/// through intervening bitcasts. +/// +/// The core idea here is that if the result of a load is used in an operation, +/// we should load the type most conducive to that operation. For example, when +/// loading an integer and converting that immediately to a pointer, we should +/// instead directly load a pointer. +/// +/// However, this routine must never change the width of a load or the number of +/// loads as that would introduce a semantic change. This combine is expected to +/// be a semantic no-op which just allows loads to more closely model the types +/// of their consuming operations. +/// +/// Currently, we also refuse to change the precise type used for an atomic load +/// or a volatile load. This is debatable, and might be reasonable to change +/// later. However, it is risky in case some backend or other part of LLVM is +/// relying on the exact type loaded to select appropriate atomic operations. +static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { + // FIXME: We could probably with some care handle both volatile and atomic + // loads here but it isn't clear that this is important. + if (!LI.isSimple()) + return nullptr; + + if (LI.use_empty()) + return nullptr; + + + // Fold away bit casts of the loaded value by loading the desired type. + if (LI.hasOneUse()) + if (auto *BC = dyn_cast<BitCastInst>(LI.user_back())) { + LoadInst *NewLoad = combineLoadToNewType(IC, LI, BC->getDestTy()); + BC->replaceAllUsesWith(NewLoad); + IC.EraseInstFromFunction(*BC); + return &LI; + } + + // FIXME: We should also canonicalize loads of vectors when their elements are + // cast to other types. return nullptr; } Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { Value *Op = LI.getOperand(0); + // Try to canonicalize the loaded type. + if (Instruction *Res = combineLoadToOperationType(*this, LI)) + return Res; + // Attempt to improve the alignment. if (DL) { - unsigned KnownAlign = - getOrEnforceKnownAlignment(Op, DL->getPrefTypeAlignment(LI.getType()),DL); + unsigned KnownAlign = getOrEnforceKnownAlignment( + Op, DL->getPrefTypeAlignment(LI.getType()), DL, AC, &LI, DT); unsigned LoadAlign = LI.getAlignment(); unsigned EffectiveLoadAlign = LoadAlign != 0 ? LoadAlign : DL->getABITypeAlignment(LI.getType()); @@ -374,11 +406,6 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { LI.setAlignment(EffectiveLoadAlign); } - // load (cast X) --> cast (load X) iff safe. - if (isa<CastInst>(Op)) - if (Instruction *Res = InstCombineLoadCast(*this, LI, DL)) - return Res; - // None of the following transforms are legal for volatile/atomic loads. // FIXME: Some of it is okay for atomic loads; needs refactoring. 
if (!LI.isSimple()) return nullptr; @@ -388,7 +415,9 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { // separated by a few arithmetic operations. BasicBlock::iterator BBI = &LI; if (Value *AvailableVal = FindAvailableLoadedValue(Op, LI.getParent(), BBI,6)) - return ReplaceInstUsesWith(LI, AvailableVal); + return ReplaceInstUsesWith( + LI, Builder->CreateBitOrPointerCast(AvailableVal, LI.getType(), + LI.getName() + ".cast")); // load(gep null, ...) -> unreachable if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Op)) { @@ -417,12 +446,6 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { return ReplaceInstUsesWith(LI, UndefValue::get(LI.getType())); } - // Instcombine load (constantexpr_cast global) -> cast (load global) - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Op)) - if (CE->isCast()) - if (Instruction *Res = InstCombineLoadCast(*this, LI, DL)) - return Res; - if (Op->hasOneUse()) { // Change select and PHI nodes to select values instead of addresses: this // helps alias analysis out a lot, allows many others simplifications, and @@ -449,119 +472,98 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { } // load (select (cond, null, P)) -> load P - if (Constant *C = dyn_cast<Constant>(SI->getOperand(1))) - if (C->isNullValue()) { - LI.setOperand(0, SI->getOperand(2)); - return &LI; - } + if (isa<ConstantPointerNull>(SI->getOperand(1)) && + LI.getPointerAddressSpace() == 0) { + LI.setOperand(0, SI->getOperand(2)); + return &LI; + } // load (select (cond, P, null)) -> load P - if (Constant *C = dyn_cast<Constant>(SI->getOperand(2))) - if (C->isNullValue()) { - LI.setOperand(0, SI->getOperand(1)); - return &LI; - } + if (isa<ConstantPointerNull>(SI->getOperand(2)) && + LI.getPointerAddressSpace() == 0) { + LI.setOperand(0, SI->getOperand(1)); + return &LI; + } } } return nullptr; } -/// InstCombineStoreToCast - Fold store V, (cast P) -> store (cast V), P -/// when possible. This makes it generally easy to do alias analysis and/or -/// SROA/mem2reg of the memory object. -static Instruction *InstCombineStoreToCast(InstCombiner &IC, StoreInst &SI) { - User *CI = cast<User>(SI.getOperand(1)); - Value *CastOp = CI->getOperand(0); - - Type *DestPTy = CI->getType()->getPointerElementType(); - PointerType *SrcTy = dyn_cast<PointerType>(CastOp->getType()); - if (!SrcTy) return nullptr; - - Type *SrcPTy = SrcTy->getElementType(); +/// \brief Combine stores to match the type of value being stored. +/// +/// The core idea here is that the memory does not have any intrinsic type and +/// where we can we should match the type of a store to the type of value being +/// stored. +/// +/// However, this routine must never change the width of a store or the number of +/// stores as that would introduce a semantic change. This combine is expected to +/// be a semantic no-op which just allows stores to more closely model the types +/// of their incoming values. +/// +/// Currently, we also refuse to change the precise type used for an atomic or +/// volatile store. This is debatable, and might be reasonable to change later. +/// However, it is risky in case some backend or other part of LLVM is relying +/// on the exact type stored to select appropriate atomic operations. +/// +/// \returns true if the store was successfully combined away. This indicates +/// the caller must erase the store instruction. 
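The tightened select folds above now insist on a literal null pointer in address space 0, since only then is the dead arm known non-dereferenceable. The source-level shape, with invented names; note that executing the null arm would be undefined behaviour, which is precisely what licenses the rewrite:

    #include <cassert>

    // load (select (cond, null, P)) -> load P: any execution taking the null
    // arm is UB, so the optimizer may assume the non-null arm is loaded.
    // Illustration only; calling this with c == true is undefined behaviour.
    static int loadSelect(bool c, int *p) {
      return *(c ? nullptr : p); // ==> return *p;
    }

    int main() {
      int v = 7;
      assert(loadSelect(false, &v) == 7);
      return 0;
    }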
We have to let the caller erase
+/// the store instruction as otherwise there is no way to signal whether it was
+/// combined or not: IC.EraseInstFromFunction returns a null pointer.
+static bool combineStoreToValueType(InstCombiner &IC, StoreInst &SI) {
+  // FIXME: We could probably with some care handle both volatile and atomic
+  // stores here but it isn't clear that this is important.
+  if (!SI.isSimple())
+    return false;
 
-  if (!DestPTy->isIntegerTy() && !DestPTy->isPointerTy())
-    return nullptr;
+  Value *Ptr = SI.getPointerOperand();
+  Value *V = SI.getValueOperand();
+  unsigned AS = SI.getPointerAddressSpace();
+  SmallVector<std::pair<unsigned, MDNode *>, 8> MD;
+  SI.getAllMetadata(MD);
+
+  // Fold away bit casts of the stored value by storing the original type.
+  if (auto *BC = dyn_cast<BitCastInst>(V)) {
+    V = BC->getOperand(0);
+    StoreInst *NewStore = IC.Builder->CreateAlignedStore(
+        V, IC.Builder->CreateBitCast(Ptr, V->getType()->getPointerTo(AS)),
+        SI.getAlignment());
+    for (const auto &MDPair : MD) {
+      unsigned ID = MDPair.first;
+      MDNode *N = MDPair.second;
+      // Note, essentially every kind of metadata should be preserved here! This
+      // routine is supposed to clone a store instruction changing *only its
+      // type*. The only metadata it makes sense to drop is metadata which is
+      // invalidated when the pointer type changes. This should essentially
+      // never be the case in LLVM, but we explicitly switch over only known
+      // metadata to be conservatively correct. If you are adding metadata to
+      // LLVM which pertains to stores, you almost certainly want to add it
+      // here.
+      switch (ID) {
+      case LLVMContext::MD_dbg:
+      case LLVMContext::MD_tbaa:
+      case LLVMContext::MD_prof:
+      case LLVMContext::MD_fpmath:
+      case LLVMContext::MD_tbaa_struct:
+      case LLVMContext::MD_alias_scope:
+      case LLVMContext::MD_noalias:
+      case LLVMContext::MD_nontemporal:
+      case LLVMContext::MD_mem_parallel_loop_access:
+      case LLVMContext::MD_nonnull:
+        // All of these directly apply.
+        NewStore->setMetadata(ID, N);
+        break;
 
-  /// NewGEPIndices - If SrcPTy is an aggregate type, we can emit a "noop gep"
-  /// to its first element. This allows us to handle things like:
-  ///   store i32 xxx, (bitcast {foo*, float}* %P to i32*)
-  /// on 32-bit hosts.
-  SmallVector<Value*, 4> NewGEPIndices;
-
-  // If the source is an array, the code below will not succeed. Check to
-  // see if a trivial 'gep P, 0, 0' will help matters. Only do this for
-  // constants.
-  if (SrcPTy->isArrayTy() || SrcPTy->isStructTy()) {
-    // Index through pointer.
-    Constant *Zero = Constant::getNullValue(Type::getInt32Ty(SI.getContext()));
-    NewGEPIndices.push_back(Zero);
-
-    while (1) {
-      if (StructType *STy = dyn_cast<StructType>(SrcPTy)) {
-        if (!STy->getNumElements()) /* Struct can be empty {} */
-          break;
-        NewGEPIndices.push_back(Zero);
-        SrcPTy = STy->getElementType(0);
-      } else if (ArrayType *ATy = dyn_cast<ArrayType>(SrcPTy)) {
-        NewGEPIndices.push_back(Zero);
-        SrcPTy = ATy->getElementType();
-      } else {
+      case LLVMContext::MD_invariant_load:
+      case LLVMContext::MD_range:
        break;
      }
    }
-
-    SrcTy = PointerType::get(SrcPTy, SrcTy->getAddressSpace());
-  }
-
-  if (!SrcPTy->isIntegerTy() && !SrcPTy->isPointerTy())
-    return nullptr;
-
-  // If the pointers point into different address spaces don't do the
-  // transformation.
-  if (SrcTy->getAddressSpace() != CI->getType()->getPointerAddressSpace())
-    return nullptr;
-
-  // If the pointers point to values of different sizes don't do the
-  // transformation.
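The store-side canonicalization above mirrors the load combine earlier in this file: store the original value type through a casted pointer instead of bitcasting the value. Same width and bytes either way, as this memcpy analogy (not the pass itself) shows:

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    // Before: bitcast a float to i32 and store the i32.
    static void storeBitcastThenInt(void *p, float f) {
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof bits); // value-level bitcast
      std::memcpy(p, &bits, sizeof bits);
    }

    // After: store the float directly through a casted pointer.
    static void storeFloatDirectly(void *p, float f) {
      std::memcpy(p, &f, sizeof f);
    }

    int main() {
      unsigned char a[4], b[4];
      storeBitcastThenInt(a, 1.75f);
      storeFloatDirectly(b, 1.75f);
      assert(std::memcmp(a, b, 4) == 0); // identical bytes written
      return 0;
    }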
- if (!IC.getDataLayout() || - IC.getDataLayout()->getTypeSizeInBits(SrcPTy) != - IC.getDataLayout()->getTypeSizeInBits(DestPTy)) - return nullptr; - - // If the pointers point to pointers to different address spaces don't do the - // transformation. It is not safe to introduce an addrspacecast instruction in - // this case since, depending on the target, addrspacecast may not be a no-op - // cast. - if (SrcPTy->isPointerTy() && DestPTy->isPointerTy() && - SrcPTy->getPointerAddressSpace() != DestPTy->getPointerAddressSpace()) - return nullptr; - - // Okay, we are casting from one integer or pointer type to another of - // the same size. Instead of casting the pointer before - // the store, cast the value to be stored. - Value *NewCast; - Instruction::CastOps opcode = Instruction::BitCast; - Type* CastSrcTy = DestPTy; - Type* CastDstTy = SrcPTy; - if (CastDstTy->isPointerTy()) { - if (CastSrcTy->isIntegerTy()) - opcode = Instruction::IntToPtr; - } else if (CastDstTy->isIntegerTy()) { - if (CastSrcTy->isPointerTy()) - opcode = Instruction::PtrToInt; + return true; } - // SIOp0 is a pointer to aggregate and this is a store to the first field, - // emit a GEP to index into its first field. - if (!NewGEPIndices.empty()) - CastOp = IC.Builder->CreateInBoundsGEP(CastOp, NewGEPIndices); - - Value *SIOp0 = SI.getOperand(0); - NewCast = IC.Builder->CreateCast(opcode, SIOp0, CastDstTy, - SIOp0->getName()+".c"); - SI.setOperand(0, NewCast); - SI.setOperand(1, CastOp); - return &SI; + // FIXME: We should also canonicalize loads of vectors when their elements are + // cast to other types. + return false; } /// equivalentAddressValues - Test if A and B will obviously have the same @@ -597,11 +599,14 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { Value *Val = SI.getOperand(0); Value *Ptr = SI.getOperand(1); + // Try to canonicalize the stored type. + if (combineStoreToValueType(*this, SI)) + return EraseInstFromFunction(SI); + // Attempt to improve the alignment. if (DL) { - unsigned KnownAlign = - getOrEnforceKnownAlignment(Ptr, DL->getPrefTypeAlignment(Val->getType()), - DL); + unsigned KnownAlign = getOrEnforceKnownAlignment( + Ptr, DL->getPrefTypeAlignment(Val->getType()), DL, AC, &SI, DT); unsigned StoreAlign = SI.getAlignment(); unsigned EffectiveStoreAlign = StoreAlign != 0 ? StoreAlign : DL->getABITypeAlignment(Val->getType()); @@ -688,17 +693,6 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { if (isa<UndefValue>(Val)) return EraseInstFromFunction(SI); - // If the pointer destination is a cast, see if we can fold the cast into the - // source instead. - if (isa<CastInst>(Ptr)) - if (Instruction *Res = InstCombineStoreToCast(*this, SI)) - return Res; - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr)) - if (CE->isCast()) - if (Instruction *Res = InstCombineStoreToCast(*this, SI)) - return Res; - - // If this store is the last instruction in the basic block (possibly // excepting debug info instructions), and if the block ends with an // unconditional branch, try to move it to the successor block. @@ -836,12 +830,13 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { InsertNewInstBefore(NewSI, *BBI); NewSI->setDebugLoc(OtherStore->getDebugLoc()); - // If the two stores had the same TBAA tag, preserve it. 
- if (MDNode *TBAATag = SI.getMetadata(LLVMContext::MD_tbaa)) - if ((TBAATag = MDNode::getMostGenericTBAA(TBAATag, - OtherStore->getMetadata(LLVMContext::MD_tbaa)))) - NewSI->setMetadata(LLVMContext::MD_tbaa, TBAATag); - + // If the two stores had AA tags, merge them. + AAMDNodes AATags; + SI.getAAMetadata(AATags); + if (AATags) { + OtherStore->getAAMetadata(AATags, /* Merge = */ true); + NewSI->setAAMetadata(AATags); + } // Nuke the old stores. EraseInstFromFunction(SI); diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 6c6e7d815163..b2ff96f401e9 100644 --- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -25,7 +25,8 @@ using namespace PatternMatch; /// simplifyValueKnownNonZero - The specific integer value is used in a context /// where it is known to be non-zero. If this allows us to simplify the /// computation, do so and return the new operand, otherwise return null. -static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC) { +static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC, + Instruction *CxtI) { // If V has multiple uses, then we would have to do more analysis to determine // if this is safe. For example, the use could be in dynamically unreached // code. @@ -35,22 +36,23 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC) { // ((1 << A) >>u B) --> (1 << (A-B)) // Because V cannot be zero, we know that B is less than A. - Value *A = nullptr, *B = nullptr, *PowerOf2 = nullptr; - if (match(V, m_LShr(m_OneUse(m_Shl(m_Value(PowerOf2), m_Value(A))), - m_Value(B))) && - // The "1" can be any value known to be a power of 2. - isKnownToBeAPowerOfTwo(PowerOf2)) { + Value *A = nullptr, *B = nullptr, *One = nullptr; + if (match(V, m_LShr(m_OneUse(m_Shl(m_Value(One), m_Value(A))), m_Value(B))) && + match(One, m_One())) { A = IC.Builder->CreateSub(A, B); - return IC.Builder->CreateShl(PowerOf2, A); + return IC.Builder->CreateShl(One, A); } // (PowerOfTwo >>u B) --> isExact since shifting out the result would make it // inexact. Similarly for <<. if (BinaryOperator *I = dyn_cast<BinaryOperator>(V)) - if (I->isLogicalShift() && isKnownToBeAPowerOfTwo(I->getOperand(0))) { + if (I->isLogicalShift() && + isKnownToBeAPowerOfTwo(I->getOperand(0), false, 0, + IC.getAssumptionCache(), CxtI, + IC.getDominatorTree())) { // We know that this is an exact/nuw shift and that the input is a // non-zero context as well. - if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC)) { + if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) { I->setOperand(0, V2); MadeChange = true; } @@ -76,25 +78,30 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC) { /// MultiplyOverflows - True if the multiply can not be expressed in an int /// this size. -static bool MultiplyOverflows(ConstantInt *C1, ConstantInt *C2, bool sign) { - uint32_t W = C1->getBitWidth(); - APInt LHSExt = C1->getValue(), RHSExt = C2->getValue(); - if (sign) { - LHSExt = LHSExt.sext(W * 2); - RHSExt = RHSExt.sext(W * 2); - } else { - LHSExt = LHSExt.zext(W * 2); - RHSExt = RHSExt.zext(W * 2); - } +static bool MultiplyOverflows(const APInt &C1, const APInt &C2, APInt &Product, + bool IsSigned) { + bool Overflow; + if (IsSigned) + Product = C1.smul_ov(C2, Overflow); + else + Product = C1.umul_ov(C2, Overflow); + + return Overflow; +} - APInt MulExt = LHSExt * RHSExt; +/// \brief True if C2 is a multiple of C1. 
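The retouched simplifyValueKnownNonZero above still relies on the same fact for ((1 << A) >>u B): a non-zero result forces B <= A, so no set bit was ever shifted out and the two shifts collapse to one left shift by the difference. Exhaustive for i32 shift amounts:

    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned a = 0; a < 32; ++a)
        for (unsigned b = 0; b < 32; ++b) {
          uint32_t v = (UINT32_C(1) << a) >> b;
          if (v == 0)
            continue; // the rewrite only fires where V is known non-zero
          assert(b <= a); // non-zero forces the shift amounts to be ordered
          assert(v == (UINT32_C(1) << (a - b)));
        }
      return 0;
    }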
Quotient contains C2/C1. +static bool IsMultiple(const APInt &C1, const APInt &C2, APInt &Quotient, + bool IsSigned) { + assert(C1.getBitWidth() == C2.getBitWidth() && + "Inconsistent width of constants!"); - if (!sign) - return MulExt.ugt(APInt::getLowBitsSet(W * 2, W)); + APInt Remainder(C1.getBitWidth(), /*Val=*/0ULL, IsSigned); + if (IsSigned) + APInt::sdivrem(C1, C2, Quotient, Remainder); + else + APInt::udivrem(C1, C2, Quotient, Remainder); - APInt Min = APInt::getSignedMinValue(W).sext(W * 2); - APInt Max = APInt::getSignedMaxValue(W).sext(W * 2); - return MulExt.slt(Min) || MulExt.sgt(Max); + return Remainder.isMinValue(); } /// \brief A helper routine of InstCombiner::visitMul(). @@ -116,6 +123,48 @@ static Constant *getLogBase2Vector(ConstantDataVector *CV) { return ConstantVector::get(Elts); } +/// \brief Return true if we can prove that: +/// (mul LHS, RHS) === (mul nsw LHS, RHS) +bool InstCombiner::WillNotOverflowSignedMul(Value *LHS, Value *RHS, + Instruction *CxtI) { + // Multiplying n * m significant bits yields a result of n + m significant + // bits. If the total number of significant bits does not exceed the + // result bit width (minus 1), there is no overflow. + // This means if we have enough leading sign bits in the operands + // we can guarantee that the result does not overflow. + // Ref: "Hacker's Delight" by Henry Warren + unsigned BitWidth = LHS->getType()->getScalarSizeInBits(); + + // Note that underestimating the number of sign bits gives a more + // conservative answer. + unsigned SignBits = ComputeNumSignBits(LHS, 0, CxtI) + + ComputeNumSignBits(RHS, 0, CxtI); + + // First handle the easy case: if we have enough sign bits there's + // definitely no overflow. + if (SignBits > BitWidth + 1) + return true; + + // There are two ambiguous cases where there can be no overflow: + // SignBits == BitWidth + 1 and + // SignBits == BitWidth + // The second case is difficult to check, therefore we only handle the + // first case. + if (SignBits == BitWidth + 1) { + // It overflows only when both arguments are negative and the true + // product is exactly the minimum negative number. + // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000 + // For simplicity we just check if at least one side is not negative. + bool LHSNonNegative, LHSNegative; + bool RHSNonNegative, RHSNegative; + ComputeSignBit(LHS, LHSNonNegative, LHSNegative, /*Depth=*/0, CxtI); + ComputeSignBit(RHS, RHSNonNegative, RHSNegative, /*Depth=*/0, CxtI); + if (LHSNonNegative || RHSNonNegative) + return true; + } + return false; +} + Instruction *InstCombiner::visitMul(BinaryOperator &I) { bool Changed = SimplifyAssociativeOrCommutative(I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -123,14 +172,19 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyMulInst(Op0, Op1, DL)) + if (Value *V = SimplifyMulInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); if (Value *V = SimplifyUsingDistributiveLaws(I)) return ReplaceInstUsesWith(I, V); - if (match(Op1, m_AllOnes())) // X * -1 == 0 - X - return BinaryOperator::CreateNeg(Op0, I.getName()); + // X * -1 == 0 - X + if (match(Op1, m_AllOnes())) { + BinaryOperator *BO = BinaryOperator::CreateNeg(Op0, I.getName()); + if (I.hasNoSignedWrap()) + BO->setHasNoSignedWrap(); + return BO; + } // Also allow combining multiply instructions on vectors. 
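WillNotOverflowSignedMul, added above, applies the sign-bit argument credited to Hacker's Delight: more than BitWidth + 1 sign bits across the two operands rules out signed overflow, and at exactly BitWidth + 1 the only failing case needs both operands negative. Replayed on i8 below; numSignBits is a hand-written stand-in for ComputeNumSignBits:

    #include <cassert>
    #include <cstdint>

    static unsigned numSignBits(int8_t v) {
      unsigned n = 0;
      for (int i = 7; i >= 0 && (((v >> i) & 1) == ((v >> 7) & 1)); --i)
        ++n;
      return n;
    }

    int main() {
      for (int a = -128; a <= 127; ++a)
        for (int b = -128; b <= 127; ++b) {
          unsigned signBits = numSignBits(static_cast<int8_t>(a)) +
                              numSignBits(static_cast<int8_t>(b));
          int wide = a * b; // exact product, computed in a wider type
          bool overflows = wide < -128 || wide > 127;
          // SignBits > BitWidth + 1 (here, > 9) guarantees no signed overflow.
          if (signBits > 9)
            assert(!overflows);
          // The SignBits == 9 boundary is safe once one operand is non-negative
          // (e.g. -8 * -16 == 128 is the bad both-negative case on i8).
          if (signBits == 9 && (a >= 0 || b >= 0))
            assert(!overflows);
        }
      return 0;
    }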
{ @@ -139,9 +193,18 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { const APInt *IVal; if (match(&I, m_Mul(m_Shl(m_Value(NewOp), m_Constant(C2)), m_Constant(C1))) && - match(C1, m_APInt(IVal))) - // ((X << C1)*C2) == (X * (C2 << C1)) - return BinaryOperator::CreateMul(NewOp, ConstantExpr::getShl(C1, C2)); + match(C1, m_APInt(IVal))) { + // ((X << C2)*C1) == (X * (C1 << C2)) + Constant *Shl = ConstantExpr::getShl(C1, C2); + BinaryOperator *Mul = cast<BinaryOperator>(I.getOperand(0)); + BinaryOperator *BO = BinaryOperator::CreateMul(NewOp, Shl); + if (I.hasNoUnsignedWrap() && Mul->hasNoUnsignedWrap()) + BO->setHasNoUnsignedWrap(); + if (I.hasNoSignedWrap() && Mul->hasNoSignedWrap() && + Shl->isNotMinSignedValue()) + BO->setHasNoSignedWrap(); + return BO; + } if (match(&I, m_Mul(m_Value(NewOp), m_Constant(C1)))) { Constant *NewCst = nullptr; @@ -155,8 +218,12 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (NewCst) { BinaryOperator *Shl = BinaryOperator::CreateShl(NewOp, NewCst); - if (I.hasNoSignedWrap()) Shl->setHasNoSignedWrap(); - if (I.hasNoUnsignedWrap()) Shl->setHasNoUnsignedWrap(); + + if (I.hasNoUnsignedWrap()) + Shl->setHasNoUnsignedWrap(); + if (I.hasNoSignedWrap() && NewCst->isNotMinSignedValue()) + Shl->setHasNoSignedWrap(); + return Shl; } } @@ -212,9 +279,16 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } } - if (Value *Op0v = dyn_castNegVal(Op0)) // -X * -Y = X*Y - if (Value *Op1v = dyn_castNegVal(Op1)) - return BinaryOperator::CreateMul(Op0v, Op1v); + if (Value *Op0v = dyn_castNegVal(Op0)) { // -X * -Y = X*Y + if (Value *Op1v = dyn_castNegVal(Op1)) { + BinaryOperator *BO = BinaryOperator::CreateMul(Op0v, Op1v); + if (I.hasNoSignedWrap() && + match(Op0, m_NSWSub(m_Value(), m_Value())) && + match(Op1, m_NSWSub(m_Value(), m_Value()))) + BO->setHasNoSignedWrap(); + return BO; + } + } // (X / Y) * Y = X - (X % Y) // (X / Y) * -Y = (X % Y) - X @@ -263,10 +337,22 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { // (1 << Y)*X --> X << Y { Value *Y; - if (match(Op0, m_Shl(m_One(), m_Value(Y)))) - return BinaryOperator::CreateShl(Op1, Y); - if (match(Op1, m_Shl(m_One(), m_Value(Y)))) - return BinaryOperator::CreateShl(Op0, Y); + BinaryOperator *BO = nullptr; + bool ShlNSW = false; + if (match(Op0, m_Shl(m_One(), m_Value(Y)))) { + BO = BinaryOperator::CreateShl(Op1, Y); + ShlNSW = cast<ShlOperator>(Op0)->hasNoSignedWrap(); + } else if (match(Op1, m_Shl(m_One(), m_Value(Y)))) { + BO = BinaryOperator::CreateShl(Op0, Y); + ShlNSW = cast<ShlOperator>(Op1)->hasNoSignedWrap(); + } + if (BO) { + if (I.hasNoUnsignedWrap()) + BO->setHasNoUnsignedWrap(); + if (I.hasNoSignedWrap() && ShlNSW) + BO->setHasNoSignedWrap(); + return BO; + } } // If one of the operands of the multiply is a cast from a boolean value, then @@ -277,9 +363,9 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { APInt Negative2(I.getType()->getPrimitiveSizeInBits(), (uint64_t)-2, true); Value *BoolCast = nullptr, *OtherOp = nullptr; - if (MaskedValueIsZero(Op0, Negative2)) + if (MaskedValueIsZero(Op0, Negative2, 0, &I)) BoolCast = Op0, OtherOp = Op1; - else if (MaskedValueIsZero(Op1, Negative2)) + else if (MaskedValueIsZero(Op1, Negative2, 0, &I)) BoolCast = Op1, OtherOp = Op0; if (BoolCast) { @@ -289,43 +375,47 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } } + if (!I.hasNoSignedWrap() && WillNotOverflowSignedMul(Op0, Op1, &I)) { + Changed = true; + I.setHasNoSignedWrap(true); + } + + if (!I.hasNoUnsignedWrap() && + computeOverflowForUnsignedMul(Op0, 
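The ((X << C2) * C1) == (X * (C1 << C2)) equivalence used above holds unconditionally in wrapping arithmetic; the delicate part of the patch is deciding when nuw/nsw may be reattached, not the value identity itself. Brute force on i8:

    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned x = 0; x < 256; ++x)
        for (unsigned c1 = 0; c1 < 256; ++c1)
          for (unsigned c2 = 0; c2 < 8; ++c2) {
            uint8_t lhs =
                static_cast<uint8_t>(static_cast<uint8_t>(x << c2) * c1);
            uint8_t rhs =
                static_cast<uint8_t>(x * static_cast<uint8_t>(c1 << c2));
            assert(lhs == rhs); // modular arithmetic: the forms always agree
          }
      return 0;
    }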
Op1, &I) == + OverflowResult::NeverOverflows) { + Changed = true; + I.setHasNoUnsignedWrap(true); + } + return Changed ? &I : nullptr; } -// -// Detect pattern: -// -// log2(Y*0.5) -// -// And check for corresponding fast math flags -// - +/// Detect pattern log2(Y * 0.5) with corresponding fast math flags. static void detectLog2OfHalf(Value *&Op, Value *&Y, IntrinsicInst *&Log2) { - - if (!Op->hasOneUse()) - return; - - IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op); - if (!II) - return; - if (II->getIntrinsicID() != Intrinsic::log2 || !II->hasUnsafeAlgebra()) - return; - Log2 = II; - - Value *OpLog2Of = II->getArgOperand(0); - if (!OpLog2Of->hasOneUse()) - return; - - Instruction *I = dyn_cast<Instruction>(OpLog2Of); - if (!I) - return; - if (I->getOpcode() != Instruction::FMul || !I->hasUnsafeAlgebra()) - return; - - if (match(I->getOperand(0), m_SpecificFP(0.5))) - Y = I->getOperand(1); - else if (match(I->getOperand(1), m_SpecificFP(0.5))) - Y = I->getOperand(0); + if (!Op->hasOneUse()) + return; + + IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op); + if (!II) + return; + if (II->getIntrinsicID() != Intrinsic::log2 || !II->hasUnsafeAlgebra()) + return; + Log2 = II; + + Value *OpLog2Of = II->getArgOperand(0); + if (!OpLog2Of->hasOneUse()) + return; + + Instruction *I = dyn_cast<Instruction>(OpLog2Of); + if (!I) + return; + if (I->getOpcode() != Instruction::FMul || !I->hasUnsafeAlgebra()) + return; + + if (match(I->getOperand(0), m_SpecificFP(0.5))) + Y = I->getOperand(1); + else if (match(I->getOperand(1), m_SpecificFP(0.5))) + Y = I->getOperand(0); } static bool isFiniteNonZeroFp(Constant *C) { @@ -440,7 +530,8 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { if (isa<Constant>(Op0)) std::swap(Op0, Op1); - if (Value *V = SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL)) + if (Value *V = + SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); bool AllowReassociate = I.hasUnsafeAlgebra(); @@ -510,10 +601,15 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { } } + // sqrt(X) * sqrt(X) -> X + if (AllowReassociate && (Op0 == Op1)) + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) + if (II->getIntrinsicID() == Intrinsic::sqrt) + return ReplaceInstUsesWith(I, II->getOperand(0)); // Under unsafe algebra do: // X * log2(0.5*Y) = X*log2(Y) - X - if (I.hasUnsafeAlgebra()) { + if (AllowReassociate) { Value *OpX = nullptr; Value *OpY = nullptr; IntrinsicInst *Log2; @@ -596,36 +692,6 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { } } - // B * (uitofp i1 C) -> select C, B, 0 - if (I.hasNoNaNs() && I.hasNoInfs() && I.hasNoSignedZeros()) { - Value *LHS = Op0, *RHS = Op1; - Value *B, *C; - if (!match(RHS, m_UIToFP(m_Value(C)))) - std::swap(LHS, RHS); - - if (match(RHS, m_UIToFP(m_Value(C))) && - C->getType()->getScalarType()->isIntegerTy(1)) { - B = LHS; - Value *Zero = ConstantFP::getNegativeZero(B->getType()); - return SelectInst::Create(C, B, Zero); - } - } - - // A * (1 - uitofp i1 C) -> select C, 0, A - if (I.hasNoNaNs() && I.hasNoInfs() && I.hasNoSignedZeros()) { - Value *LHS = Op0, *RHS = Op1; - Value *A, *C; - if (!match(RHS, m_FSub(m_FPOne(), m_UIToFP(m_Value(C))))) - std::swap(LHS, RHS); - - if (match(RHS, m_FSub(m_FPOne(), m_UIToFP(m_Value(C)))) && - C->getType()->getScalarType()->isIntegerTy(1)) { - A = LHS; - Value *Zero = ConstantFP::getNegativeZero(A->getType()); - return SelectInst::Create(C, Zero, A); - } - } - if (!isa<Constant>(Op1)) std::swap(Opnd0, Opnd1); else @@ -714,7 +780,7 @@ 
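detectLog2OfHalf (reindented above) feeds the rewrite X * log2(0.5 * Y) -> X * log2(Y) - X, exact over the reals and only approximate in floating point, hence the unsafe-algebra gating. A numeric spot check with a small relative tolerance:

    #include <cassert>
    #include <cmath>

    int main() {
      // log2(0.5 * y) == log2(y) - 1 in the reals, so
      // x * log2(0.5 * y) == x * log2(y) - x up to rounding.
      const double xs[] = {1.0, -3.0, 0.25, 7.5};
      const double ys[] = {1.0, 2.0, 10.0, 0.125, 1e6};
      for (double x : xs)
        for (double y : ys) {
          double before = x * std::log2(0.5 * y);
          double after = x * std::log2(y) - x;
          assert(std::fabs(before - after) <=
                 1e-9 * (1.0 + std::fabs(before)));
        }
      return 0;
    }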
Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); // The RHS is known non-zero. - if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this)) { + if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, &I)) { I.setOperand(1, V); return &I; } @@ -724,25 +790,83 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { if (isa<SelectInst>(Op1) && SimplifyDivRemOfSelect(I)) return &I; - if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { - // (X / C1) / C2 -> X / (C1*C2) - if (Instruction *LHS = dyn_cast<Instruction>(Op0)) - if (Instruction::BinaryOps(LHS->getOpcode()) == I.getOpcode()) - if (ConstantInt *LHSRHS = dyn_cast<ConstantInt>(LHS->getOperand(1))) { - if (MultiplyOverflows(RHS, LHSRHS, - I.getOpcode() == Instruction::SDiv)) - return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); - return BinaryOperator::Create(I.getOpcode(), LHS->getOperand(0), - ConstantExpr::getMul(RHS, LHSRHS)); + if (Instruction *LHS = dyn_cast<Instruction>(Op0)) { + const APInt *C2; + if (match(Op1, m_APInt(C2))) { + Value *X; + const APInt *C1; + bool IsSigned = I.getOpcode() == Instruction::SDiv; + + // (X / C1) / C2 -> X / (C1*C2) + if ((IsSigned && match(LHS, m_SDiv(m_Value(X), m_APInt(C1)))) || + (!IsSigned && match(LHS, m_UDiv(m_Value(X), m_APInt(C1))))) { + APInt Product(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); + if (!MultiplyOverflows(*C1, *C2, Product, IsSigned)) + return BinaryOperator::Create(I.getOpcode(), X, + ConstantInt::get(I.getType(), Product)); + } + + if ((IsSigned && match(LHS, m_NSWMul(m_Value(X), m_APInt(C1)))) || + (!IsSigned && match(LHS, m_NUWMul(m_Value(X), m_APInt(C1))))) { + APInt Quotient(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); + + // (X * C1) / C2 -> X / (C2 / C1) if C2 is a multiple of C1. + if (IsMultiple(*C2, *C1, Quotient, IsSigned)) { + BinaryOperator *BO = BinaryOperator::Create( + I.getOpcode(), X, ConstantInt::get(X->getType(), Quotient)); + BO->setIsExact(I.isExact()); + return BO; } - if (!RHS->isZero()) { // avoid X udiv 0 - if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) - if (Instruction *R = FoldOpIntoSelect(I, SI)) - return R; - if (isa<PHINode>(Op0)) - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; + // (X * C1) / C2 -> X * (C1 / C2) if C1 is a multiple of C2. + if (IsMultiple(*C1, *C2, Quotient, IsSigned)) { + BinaryOperator *BO = BinaryOperator::Create( + Instruction::Mul, X, ConstantInt::get(X->getType(), Quotient)); + BO->setHasNoUnsignedWrap( + !IsSigned && + cast<OverflowingBinaryOperator>(LHS)->hasNoUnsignedWrap()); + BO->setHasNoSignedWrap( + cast<OverflowingBinaryOperator>(LHS)->hasNoSignedWrap()); + return BO; + } + } + + if ((IsSigned && match(LHS, m_NSWShl(m_Value(X), m_APInt(C1))) && + *C1 != C1->getBitWidth() - 1) || + (!IsSigned && match(LHS, m_NUWShl(m_Value(X), m_APInt(C1))))) { + APInt Quotient(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); + APInt C1Shifted = APInt::getOneBitSet( + C1->getBitWidth(), static_cast<unsigned>(C1->getLimitedValue())); + + // (X << C1) / C2 -> X / (C2 >> C1) if C2 is a multiple of C1. + if (IsMultiple(*C2, C1Shifted, Quotient, IsSigned)) { + BinaryOperator *BO = BinaryOperator::Create( + I.getOpcode(), X, ConstantInt::get(X->getType(), Quotient)); + BO->setIsExact(I.isExact()); + return BO; + } + + // (X << C1) / C2 -> X * (C2 >> C1) if C1 is a multiple of C2. 
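The new APInt-based quotient folds above include (X * C1) / C2 --> X * (C1 / C2) when C1 is a multiple of C2 and the multiply carries the matching no-wrap flag. Filtering out non-nsw products makes the identity exact on i8:

    #include <cassert>

    int main() {
      for (int x = -128; x <= 127; ++x)
        for (int c1 = -128; c1 <= 127; ++c1)
          for (int c2 = -128; c2 <= 127; ++c2) {
            if (c2 == 0 || c1 % c2 != 0)
              continue; // the fold requires C1 to be a multiple of C2
            int product = x * c1;
            if (product < -128 || product > 127)
              continue; // the mul is not nsw on i8; the fold must not fire
            // (X * C1) sdiv C2 == X * (C1 / C2): the division is exact.
            assert(product / c2 == x * (c1 / c2));
          }
      return 0;
    }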
+ if (IsMultiple(C1Shifted, *C2, Quotient, IsSigned)) { + BinaryOperator *BO = BinaryOperator::Create( + Instruction::Mul, X, ConstantInt::get(X->getType(), Quotient)); + BO->setHasNoUnsignedWrap( + !IsSigned && + cast<OverflowingBinaryOperator>(LHS)->hasNoUnsignedWrap()); + BO->setHasNoSignedWrap( + cast<OverflowingBinaryOperator>(LHS)->hasNoSignedWrap()); + return BO; + } + } + + if (*C2 != 0) { // avoid X udiv 0 + if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) + if (Instruction *R = FoldOpIntoSelect(I, SI)) + return R; + if (isa<PHINode>(Op0)) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + } } } @@ -828,7 +952,8 @@ static Instruction *foldUDivPow2Cst(Value *Op0, Value *Op1, const APInt &C = cast<Constant>(Op1)->getUniqueInteger(); BinaryOperator *LShr = BinaryOperator::CreateLShr( Op0, ConstantInt::get(Op0->getType(), C.logBase2())); - if (I.isExact()) LShr->setIsExact(); + if (I.isExact()) + LShr->setIsExact(); return LShr; } @@ -856,7 +981,8 @@ static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I, if (ZExtInst *Z = dyn_cast<ZExtInst>(Op1)) N = IC.Builder->CreateZExt(N, Z->getDestTy()); BinaryOperator *LShr = BinaryOperator::CreateLShr(Op0, N); - if (I.isExact()) LShr->setIsExact(); + if (I.isExact()) + LShr->setIsExact(); return LShr; } @@ -893,10 +1019,10 @@ static size_t visitUDivOperand(Value *Op0, Value *Op1, const BinaryOperator &I, return 0; if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) - if (size_t LHSIdx = visitUDivOperand(Op0, SI->getOperand(1), I, Actions)) - if (visitUDivOperand(Op0, SI->getOperand(2), I, Actions)) { - Actions.push_back(UDivFoldAction((FoldUDivOperandCb)nullptr, Op1, - LHSIdx-1)); + if (size_t LHSIdx = + visitUDivOperand(Op0, SI->getOperand(1), I, Actions, Depth)) + if (visitUDivOperand(Op0, SI->getOperand(2), I, Actions, Depth)) { + Actions.push_back(UDivFoldAction(nullptr, Op1, LHSIdx - 1)); return Actions.size(); } @@ -909,7 +1035,7 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyUDivInst(Op0, Op1, DL)) + if (Value *V = SimplifyUDivInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // Handle the integer div common cases @@ -917,19 +1043,30 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { return Common; // (x lshr C1) udiv C2 --> x udiv (C2 << C1) - if (Constant *C2 = dyn_cast<Constant>(Op1)) { + { Value *X; - Constant *C1; - if (match(Op0, m_LShr(m_Value(X), m_Constant(C1)))) - return BinaryOperator::CreateUDiv(X, ConstantExpr::getShl(C2, C1)); + const APInt *C1, *C2; + if (match(Op0, m_LShr(m_Value(X), m_APInt(C1))) && + match(Op1, m_APInt(C2))) { + bool Overflow; + APInt C2ShlC1 = C2->ushl_ov(*C1, Overflow); + if (!Overflow) { + bool IsExact = I.isExact() && match(Op0, m_Exact(m_Value())); + BinaryOperator *BO = BinaryOperator::CreateUDiv( + X, ConstantInt::get(X->getType(), C2ShlC1)); + if (IsExact) + BO->setIsExact(); + return BO; + } + } } // (zext A) udiv (zext B) --> zext (A udiv B) if (ZExtInst *ZOp0 = dyn_cast<ZExtInst>(Op0)) if (Value *ZOp1 = dyn_castZExtVal(Op1, ZOp0->getSrcTy())) - return new ZExtInst(Builder->CreateUDiv(ZOp0->getOperand(0), ZOp1, "div", - I.isExact()), - I.getType()); + return new ZExtInst( + Builder->CreateUDiv(ZOp0->getOperand(0), ZOp1, "div", I.isExact()), + I.getType()); // (LHS udiv (select (select (...)))) -> (LHS >> (select (select (...)))) SmallVector<UDivFoldAction, 6> UDivActions; @@ -971,7 +1108,7 @@ Instruction 
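The same recipe handles the shift forms, and visitUDiv above now turns (x lshr C1) udiv C2 into x udiv (C2 << C1) only when the shifted divisor does not overflow (the ushl_ov check). This is the nested floor-division identity:

    #include <cassert>

    int main() {
      for (unsigned x = 0; x < 256; ++x)
        for (unsigned c1 = 0; c1 < 8; ++c1)
          for (unsigned c2 = 1; c2 < 256; ++c2) {
            unsigned shifted = c2 << c1;
            if (shifted > 255)
              continue; // C2 << C1 overflowed i8: ushl_ov reports it, no fold
            // (x lshr C1) udiv C2 == x udiv (C2 << C1)
            assert((x >> c1) / c2 == x / shifted);
          }
      return 0;
    }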
*InstCombiner::visitSDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifySDivInst(Op0, Op1, DL)) + if (Value *V = SimplifySDivInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // Handle the integer div common cases @@ -998,28 +1135,34 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { return new ZExtInst(Builder->CreateICmpEQ(Op0, Op1), I.getType()); // -X/C --> X/-C provided the negation doesn't overflow. - if (SubOperator *Sub = dyn_cast<SubOperator>(Op0)) - if (match(Sub->getOperand(0), m_Zero()) && Sub->hasNoSignedWrap()) - return BinaryOperator::CreateSDiv(Sub->getOperand(1), - ConstantExpr::getNeg(RHS)); + Value *X; + if (match(Op0, m_NSWSub(m_Zero(), m_Value(X)))) { + auto *BO = BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(RHS)); + BO->setIsExact(I.isExact()); + return BO; + } } // If the sign bits of both operands are zero (i.e. we can prove they are // unsigned inputs), turn this into a udiv. if (I.getType()->isIntegerTy()) { APInt Mask(APInt::getSignBit(I.getType()->getPrimitiveSizeInBits())); - if (MaskedValueIsZero(Op0, Mask)) { - if (MaskedValueIsZero(Op1, Mask)) { + if (MaskedValueIsZero(Op0, Mask, 0, &I)) { + if (MaskedValueIsZero(Op1, Mask, 0, &I)) { // X sdiv Y -> X udiv Y, iff X and Y don't have sign bit set - return BinaryOperator::CreateUDiv(Op0, Op1, I.getName()); + auto *BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName()); + BO->setIsExact(I.isExact()); + return BO; } - if (match(Op1, m_Shl(m_Power2(), m_Value()))) { + if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, AC, &I, DT)) { // X sdiv (1 << Y) -> X udiv (1 << Y) ( -> X u>> Y) // Safe because the only negative value (1 << Y) can take on is // INT_MIN, and X sdiv INT_MIN == X udiv INT_MIN == 0 if X doesn't have // the sign bit set. - return BinaryOperator::CreateUDiv(Op0, Op1, I.getName()); + auto *BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName()); + BO->setIsExact(I.isExact()); + return BO; } } } @@ -1034,8 +1177,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { /// If the conversion was successful, the simplified expression "X * 1/C" is /// returned; otherwise, NULL is returned. /// -static Instruction *CvtFDivConstToReciprocal(Value *Dividend, - Constant *Divisor, +static Instruction *CvtFDivConstToReciprocal(Value *Dividend, Constant *Divisor, bool AllowReciprocal) { if (!isa<ConstantFP>(Divisor)) // TODO: handle vectors. return nullptr; @@ -1064,7 +1206,7 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyFDivInst(Op0, Op1, DL)) + if (Value *V = SimplifyFDivInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); if (isa<Constant>(Op0)) @@ -1195,7 +1337,7 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); // The RHS is known non-zero. 
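visitSDiv above now emits exact udivs once both operands provably have clear sign bits; for such inputs signed and unsigned division coincide:

    #include <cassert>
    #include <cstdint>

    int main() {
      // When both operands are non-negative (sign bit known zero), sdiv and
      // udiv agree, so the canonicalization to udiv is lossless.
      for (int32_t a = 0; a <= 1000; ++a)
        for (int32_t b = 1; b <= 100; ++b)
          assert(a / b == static_cast<int32_t>(static_cast<uint32_t>(a) /
                                               static_cast<uint32_t>(b)));
      return 0;
    }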
- if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this)) { + if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, &I)) { I.setOperand(1, V); return &I; } @@ -1229,7 +1371,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyURemInst(Op0, Op1, DL)) + if (Value *V = SimplifyURemInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); if (Instruction *common = commonIRemTransforms(I)) @@ -1242,7 +1384,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { I.getType()); // X urem Y -> X and Y-1, where Y is a power of 2, - if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/true)) { + if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, AC, &I, DT)) { Constant *N1 = Constant::getAllOnesValue(I.getType()); Value *Add = Builder->CreateAdd(Op1, N1); return BinaryOperator::CreateAnd(Op0, Add); @@ -1264,28 +1406,29 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifySRemInst(Op0, Op1, DL)) + if (Value *V = SimplifySRemInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // Handle the integer rem common cases if (Instruction *Common = commonIRemTransforms(I)) return Common; - if (Value *RHSNeg = dyn_castNegVal(Op1)) - if (!isa<Constant>(RHSNeg) || - (isa<ConstantInt>(RHSNeg) && - cast<ConstantInt>(RHSNeg)->getValue().isStrictlyPositive())) { - // X % -Y -> X % Y + { + const APInt *Y; + // X % -Y -> X % Y + if (match(Op1, m_APInt(Y)) && Y->isNegative() && !Y->isMinSignedValue()) { Worklist.AddValue(I.getOperand(1)); - I.setOperand(1, RHSNeg); + I.setOperand(1, ConstantInt::get(I.getType(), -*Y)); return &I; } + } // If the sign bits of both operands are zero (i.e. we can prove they are // unsigned inputs), turn this into a urem. if (I.getType()->isIntegerTy()) { APInt Mask(APInt::getSignBit(I.getType()->getPrimitiveSizeInBits())); - if (MaskedValueIsZero(Op1, Mask) && MaskedValueIsZero(Op0, Mask)) { + if (MaskedValueIsZero(Op1, Mask, 0, &I) && + MaskedValueIsZero(Op0, Mask, 0, &I)) { // X srem Y -> X urem Y, iff X and Y don't have sign bit set return BinaryOperator::CreateURem(Op0, Op1, I.getName()); } @@ -1338,7 +1481,7 @@ Instruction *InstCombiner::visitFRem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyFRemInst(Op0, Op1, DL)) + if (Value *V = SimplifyFRemInst(Op0, Op1, DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // Handle cases involving: rem X, (select Cond, Y, Z) diff --git a/lib/Transforms/InstCombine/InstCombinePHI.cpp b/lib/Transforms/InstCombine/InstCombinePHI.cpp index 46f7b8a095c5..53831c8149ee 100644 --- a/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -506,12 +506,12 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { /// DeadPHICycle - Return true if this PHI node is only used by a PHI node cycle /// that is dead. static bool DeadPHICycle(PHINode *PN, - SmallPtrSet<PHINode*, 16> &PotentiallyDeadPHIs) { + SmallPtrSetImpl<PHINode*> &PotentiallyDeadPHIs) { if (PN->use_empty()) return true; if (!PN->hasOneUse()) return false; // Remember this node, and if we find the cycle, return. - if (!PotentiallyDeadPHIs.insert(PN)) + if (!PotentiallyDeadPHIs.insert(PN).second) return true; // Don't scan crazily complex things. @@ -528,9 +528,9 @@ static bool DeadPHICycle(PHINode *PN, /// NonPhiInVal. 
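Two of the rem folds touched above are classic identities: urem by a power of two is a mask, and srem ignores the divisor's sign (the isMinSignedValue guard only excludes the one divisor whose negation would wrap). C++11's % truncates toward zero exactly like LLVM's srem, so both can be checked directly:

    #include <cassert>
    #include <cstdint>

    int main() {
      // X urem (power-of-two Y) --> X & (Y - 1)
      for (uint32_t x = 0; x < 1024; ++x)
        for (uint32_t y = 1; y <= 512; y <<= 1)
          assert(x % y == (x & (y - 1)));

      // X srem -Y --> X srem Y: the remainder is insensitive to the
      // divisor's sign under truncating division.
      for (int x = -100; x <= 100; ++x)
        for (int y = 1; y <= 100; ++y)
          assert(x % -y == x % y);
      return 0;
    }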
This happens with mutually cyclic phi nodes like: /// z = some value; x = phi (y, z); y = phi (x, z) static bool PHIsEqualValue(PHINode *PN, Value *NonPhiInVal, - SmallPtrSet<PHINode*, 16> &ValueEqualPHIs) { + SmallPtrSetImpl<PHINode*> &ValueEqualPHIs) { // See if we already saw this PHI node. - if (!ValueEqualPHIs.insert(PN)) + if (!ValueEqualPHIs.insert(PN).second) return true; // Don't scan crazily complex things. @@ -654,7 +654,7 @@ Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { // If the user is a PHI, inspect its uses recursively. if (PHINode *UserPN = dyn_cast<PHINode>(UserI)) { - if (PHIsInspected.insert(UserPN)) + if (PHIsInspected.insert(UserPN).second) PHIsToSlice.push_back(UserPN); continue; } @@ -788,7 +788,7 @@ Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { // PHINode simplification // Instruction *InstCombiner::visitPHINode(PHINode &PN) { - if (Value *V = SimplifyInstruction(&PN, DL, TLI)) + if (Value *V = SimplifyInstruction(&PN, DL, TLI, DT, AC)) return ReplaceInstUsesWith(PN, V); // If all PHI operands are the same operation, pull them through the PHI, diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 06c9e290c6ea..bf3c33e45b8d 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -313,7 +313,8 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal, /// replaced with RepOp. static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, const DataLayout *TD, - const TargetLibraryInfo *TLI) { + const TargetLibraryInfo *TLI, + DominatorTree *DT, AssumptionCache *AC) { // Trivial replacement. if (V == Op) return RepOp; @@ -334,10 +335,10 @@ static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, if (CmpInst *C = dyn_cast<CmpInst>(I)) { if (C->getOperand(0) == Op) return SimplifyCmpInst(C->getPredicate(), RepOp, C->getOperand(1), TD, - TLI); + TLI, DT, AC); if (C->getOperand(1) == Op) return SimplifyCmpInst(C->getPredicate(), C->getOperand(0), RepOp, TD, - TLI); + TLI, DT, AC); } // TODO: We could hand off more cases to instsimplify here. @@ -387,15 +388,7 @@ static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, /// 1. The icmp predicate is inverted /// 2. The select operands are reversed /// 3. 
The magnitude of C2 and C1 are flipped -/// -/// This also tries to turn -/// --- Single bit tests: -/// if ((x & C) == 0) x |= C to x |= C -/// if ((x & C) != 0) x ^= C to x &= ~C -/// if ((x & C) == 0) x ^= C to x |= C -/// if ((x & C) != 0) x &= ~C to x &= ~C -/// if ((x & C) == 0) x &= ~C to nothing -static Value *foldSelectICmpAndOr(SelectInst &SI, Value *TrueVal, +static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal, Value *FalseVal, InstCombiner::BuilderTy *Builder) { const ICmpInst *IC = dyn_cast<ICmpInst>(SI.getCondition()); @@ -414,25 +407,6 @@ static Value *foldSelectICmpAndOr(SelectInst &SI, Value *TrueVal, return nullptr; const APInt *C2; - if (match(TrueVal, m_Specific(X))) { - // if ((X & C) != 0) X ^= C becomes X &= ~C - if (match(FalseVal, m_Xor(m_Specific(X), m_APInt(C2))) && C1 == C2) - return Builder->CreateAnd(X, ~(*C1)); - // if ((X & C) != 0) X &= ~C becomes X &= ~C - if (match(FalseVal, m_And(m_Specific(X), m_APInt(C2))) && *C1 == ~(*C2)) - return FalseVal; - } else if (match(FalseVal, m_Specific(X))) { - // if ((X & C) == 0) X ^= C becomes X |= C - if (match(TrueVal, m_Xor(m_Specific(X), m_APInt(C2))) && C1 == C2) - return Builder->CreateOr(X, *C1); - // if ((X & C) == 0) X &= ~C becomes nothing - if (match(TrueVal, m_And(m_Specific(X), m_APInt(C2))) && *C1 == ~(*C2)) - return X; - // if ((X & C) == 0) X |= C becomes X |= C - if (match(TrueVal, m_Or(m_Specific(X), m_APInt(C2))) && C1 == C2) - return TrueVal; - } - bool OrOnTrueVal = false; bool OrOnFalseVal = match(FalseVal, m_Or(m_Specific(TrueVal), m_Power2(C2))); if (!OrOnFalseVal) @@ -605,18 +579,26 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, // arms of the select. See if substituting this value into the arm and // simplifying the result yields the same value as the other arm. 
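The substitution trick described at the end of the comment above is easy to see on a concrete select; a loose model with hypothetical values (plain C++, not the pass itself):

    #include <cassert>

    int main() {
      // Model of the SimplifyWithOpReplaced idea: in
      //   s = (x == 7) ? x + 1 : y
      // substituting 7 for x inside the true arm gives the constant 8,
      // so the select behaves exactly like (x == 7) ? 8 : y.
      for (int x = -100; x <= 100; ++x)
        for (int y = -3; y <= 3; ++y) {
          int before = (x == 7) ? x + 1 : y;
          int after  = (x == 7) ? 8 : y;
          assert(before == after);
          // When the substituted arm simplifies to the other arm, the
          // select itself is dead: (x == y) ? x : y is always just y.
          assert(((x == y) ? x : y) == y);
        }
      return 0;
    }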
if (Pred == ICmpInst::ICMP_EQ) { - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI) == TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI) == TrueVal) + if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) == + TrueVal || + SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) == + TrueVal) return ReplaceInstUsesWith(SI, FalseVal); - if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI) == FalseVal || - SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI) == FalseVal) + if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) == + FalseVal || + SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) == + FalseVal) return ReplaceInstUsesWith(SI, FalseVal); } else if (Pred == ICmpInst::ICMP_NE) { - if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI) == FalseVal || - SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI) == FalseVal) + if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) == + FalseVal || + SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) == + FalseVal) return ReplaceInstUsesWith(SI, TrueVal); - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI) == TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI) == TrueVal) + if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) == + TrueVal || + SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) == + TrueVal) return ReplaceInstUsesWith(SI, TrueVal); } @@ -634,6 +616,52 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, } } + if (unsigned BitWidth = TrueVal->getType()->getScalarSizeInBits()) { + APInt MinSignedValue = APInt::getSignBit(BitWidth); + Value *X; + const APInt *Y, *C; + bool TrueWhenUnset; + bool IsBitTest = false; + if (ICmpInst::isEquality(Pred) && + match(CmpLHS, m_And(m_Value(X), m_Power2(Y))) && + match(CmpRHS, m_Zero())) { + IsBitTest = true; + TrueWhenUnset = Pred == ICmpInst::ICMP_EQ; + } else if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, m_Zero())) { + X = CmpLHS; + Y = &MinSignedValue; + IsBitTest = true; + TrueWhenUnset = false; + } else if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, m_AllOnes())) { + X = CmpLHS; + Y = &MinSignedValue; + IsBitTest = true; + TrueWhenUnset = true; + } + if (IsBitTest) { + Value *V = nullptr; + // (X & Y) == 0 ? X : X ^ Y --> X & ~Y + if (TrueWhenUnset && TrueVal == X && + match(FalseVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C) + V = Builder->CreateAnd(X, ~(*Y)); + // (X & Y) != 0 ? X ^ Y : X --> X & ~Y + else if (!TrueWhenUnset && FalseVal == X && + match(TrueVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C) + V = Builder->CreateAnd(X, ~(*Y)); + // (X & Y) == 0 ? X ^ Y : X --> X | Y + else if (TrueWhenUnset && FalseVal == X && + match(TrueVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C) + V = Builder->CreateOr(X, *Y); + // (X & Y) != 0 ? 
X : X ^ Y --> X | Y + else if (!TrueWhenUnset && TrueVal == X && + match(FalseVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C) + V = Builder->CreateOr(X, *Y); + + if (V) + return ReplaceInstUsesWith(SI, V); + } + } + if (Value *V = foldSelectICmpAndOr(SI, TrueVal, FalseVal, Builder)) return ReplaceInstUsesWith(SI, V); @@ -825,7 +853,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *TrueVal = SI.getTrueValue(); Value *FalseVal = SI.getFalseValue(); - if (Value *V = SimplifySelectInst(CondVal, TrueVal, FalseVal, DL)) + if (Value *V = + SimplifySelectInst(CondVal, TrueVal, FalseVal, DL, TLI, DT, AC)) return ReplaceInstUsesWith(SI, V); if (SI.getType()->isIntegerTy(1)) { @@ -917,8 +946,22 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { !CFPf->getValueAPF().isZero())) return ReplaceInstUsesWith(SI, TrueVal); } - // NOTE: if we wanted to, this is where to detect MIN/MAX + // Canonicalize to use ordered comparisons by swapping the select + // operands. + // + // e.g. + // (X ugt Y) ? X : Y -> (X ole Y) ? Y : X + if (FCI->hasOneUse() && FCmpInst::isUnordered(FCI->getPredicate())) { + FCmpInst::Predicate InvPred = FCI->getInversePredicate(); + Value *NewCond = Builder->CreateFCmp(InvPred, TrueVal, FalseVal, + FCI->getName() + ".inv"); + + return SelectInst::Create(NewCond, FalseVal, TrueVal, + SI.getName() + ".p"); + } + + // NOTE: if we wanted to, this is where to detect MIN/MAX } else if (FCI->getOperand(0) == FalseVal && FCI->getOperand(1) == TrueVal){ // Transform (X == Y) ? Y : X -> X if (FCI->getPredicate() == FCmpInst::FCMP_OEQ) { @@ -944,6 +987,21 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { !CFPf->getValueAPF().isZero())) return ReplaceInstUsesWith(SI, TrueVal); } + + // Canonicalize to use ordered comparisons by swapping the select + // operands. + // + // e.g. + // (X ugt Y) ? X : Y -> (X ole Y) ? X : Y + if (FCI->hasOneUse() && FCmpInst::isUnordered(FCI->getPredicate())) { + FCmpInst::Predicate InvPred = FCI->getInversePredicate(); + Value *NewCond = Builder->CreateFCmp(InvPred, FalseVal, TrueVal, + FCI->getName() + ".inv"); + + return SelectInst::Create(NewCond, FalseVal, TrueVal, + SI.getName() + ".p"); + } + // NOTE: if we wanted to, this is where to detect MIN/MAX } // NOTE: if we wanted to, this is where to detect ABS diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index 2f91c204dbd9..0a16e2592862 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -68,7 +68,7 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { /// this succeeds, the GetShiftedValue function will be called to produce the /// value. static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, - InstCombiner &IC) { + InstCombiner &IC, Instruction *CxtI) { // We can always evaluate constants shifted. if (isa<Constant>(V)) return true; @@ -111,8 +111,8 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, case Instruction::Or: case Instruction::Xor: // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted. 
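The four bit-test folds introduced above, plus the sign-bit spelling recognized via "slt 0" / "sgt -1", can be verified exhaustively on 8-bit values (illustrative C++, not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (int xi = 0; xi < 256; ++xi)
        for (int k = 0; k < 8; ++k) {
          uint8_t x = static_cast<uint8_t>(xi);
          uint8_t y = static_cast<uint8_t>(1u << k); // power-of-two test bit
          uint8_t xXorY = x ^ y;
          // (X & Y) == 0 ? X : X ^ Y  -->  X & ~Y
          assert((((x & y) == 0) ? x : xXorY) ==
                 (x & static_cast<uint8_t>(~y)));
          // (X & Y) != 0 ? X ^ Y : X  -->  X & ~Y
          assert((((x & y) != 0) ? xXorY : x) ==
                 (x & static_cast<uint8_t>(~y)));
          // (X & Y) == 0 ? X ^ Y : X  -->  X | Y
          assert((((x & y) == 0) ? xXorY : x) == (x | y));
          // (X & Y) != 0 ? X : X ^ Y  -->  X | Y
          assert((((x & y) != 0) ? x : xXorY) == (x | y));
        }
      // Sign-bit spelling: (x s< 0) tests the MinSignedValue bit, so
      // "x < 0 ? x ^ 0x80 : x" clears the sign bit, i.e. x & 0x7f.
      for (int xi = -128; xi < 128; ++xi) {
        uint8_t ux = static_cast<uint8_t>(static_cast<int8_t>(xi));
        uint8_t r = (xi < 0) ? static_cast<uint8_t>(ux ^ 0x80) : ux;
        assert(r == (ux & 0x7f));
      }
      return 0;
    }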
- return CanEvaluateShifted(I->getOperand(0), NumBits, isLeftShift, IC) && - CanEvaluateShifted(I->getOperand(1), NumBits, isLeftShift, IC); + return CanEvaluateShifted(I->getOperand(0), NumBits, isLeftShift, IC, I) && + CanEvaluateShifted(I->getOperand(1), NumBits, isLeftShift, IC, I); case Instruction::Shl: { // We can often fold the shift into shifts-by-a-constant. @@ -131,8 +131,9 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, // profitable unless we know the and'd out bits are already zero. if (CI->getZExtValue() > NumBits) { unsigned LowBits = TypeWidth - CI->getZExtValue(); - if (MaskedValueIsZero(I->getOperand(0), - APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits)) + if (IC.MaskedValueIsZero(I->getOperand(0), + APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits, + 0, CxtI)) return true; } @@ -155,8 +156,9 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, // profitable unless we know the and'd out bits are already zero. if (CI->getValue().ult(TypeWidth) && CI->getZExtValue() > NumBits) { unsigned LowBits = CI->getZExtValue() - NumBits; - if (MaskedValueIsZero(I->getOperand(0), - APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits)) + if (IC.MaskedValueIsZero(I->getOperand(0), + APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits, + 0, CxtI)) return true; } @@ -164,8 +166,9 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, } case Instruction::Select: { SelectInst *SI = cast<SelectInst>(I); - return CanEvaluateShifted(SI->getTrueValue(), NumBits, isLeftShift, IC) && - CanEvaluateShifted(SI->getFalseValue(), NumBits, isLeftShift, IC); + return CanEvaluateShifted(SI->getTrueValue(), NumBits, isLeftShift, + IC, SI) && + CanEvaluateShifted(SI->getFalseValue(), NumBits, isLeftShift, IC, SI); } case Instruction::PHI: { // We can change a phi if we can change all operands. Note that we never @@ -173,7 +176,8 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool isLeftShift, // instructions with a single use. PHINode *PN = cast<PHINode>(I); for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - if (!CanEvaluateShifted(PN->getIncomingValue(i), NumBits, isLeftShift,IC)) + if (!CanEvaluateShifted(PN->getIncomingValue(i), NumBits, isLeftShift, + IC, PN)) return false; return true; } @@ -329,7 +333,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, // See if we can propagate this shift into the input, this covers the trivial // cast of lshr(shl(x,c1),c2) as well as other more complex cases. if (I.getOpcode() != Instruction::AShr && - CanEvaluateShifted(Op0, COp1->getZExtValue(), isLeftShift, *this)) { + CanEvaluateShifted(Op0, COp1->getZExtValue(), isLeftShift, *this, &I)) { DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression" " to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n"); @@ -488,7 +492,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, } - // If the operand is an bitwise operator with a constant RHS, and the + // If the operand is a bitwise operator with a constant RHS, and the // shift is the only use, we can pull it out of the shift. 
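The MaskedValueIsZero guards above encode a precise condition; a brute-force check of the lshr case, together with the related shifted-out-bits-are-zero reasoning that visitShl's NUW inference in the next hunks relies on, over all 8-bit values (illustrative C++):

    #include <cassert>

    int main() {
      // CanEvaluateShifted's lshr case: (X >> C) << N equals X >> (C - N)
      // exactly when bits [C-N, C) of X are already zero -- the
      // MaskedValueIsZero(getLowBitsSet(NumBits) << LowBits) test.
      for (unsigned x = 0; x < 256; ++x)
        for (unsigned c = 1; c < 8; ++c)
          for (unsigned n = 1; n <= c; ++n) {
            unsigned mask = ((1u << n) - 1) << (c - n); // bits [C-N, C)
            if (x & mask)
              continue; // bits not known zero; the fold is unsafe
            assert(((x >> c) << n) == (x >> (c - n)));
          }
      // The NUW inference in visitShl: if the top ShAmt bits of the
      // operand are zero, nothing is shifted out, so the shift cannot
      // wrap and the round trip recovers the operand.
      for (unsigned x = 0; x < 256; ++x)
        for (unsigned s = 0; s < 8; ++s)
          if ((x & (0xffu << (8 - s)) & 0xffu) == 0)
            assert((((x << s) & 0xffu) >> s) == x);
      return 0;
    }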
if (ConstantInt *Op0C = dyn_cast<ConstantInt>(Op0BO->getOperand(1))) { bool isValid = true; // Valid only for And, Or, Xor @@ -689,9 +693,9 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1), - I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), - DL)) + if (Value *V = + SimplifyShlInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(), + I.hasNoUnsignedWrap(), DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); if (Instruction *V = commonShiftTransforms(I)) @@ -703,14 +707,15 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { // If the shifted-out value is known-zero, then this is a NUW shift. if (!I.hasNoUnsignedWrap() && MaskedValueIsZero(I.getOperand(0), - APInt::getHighBitsSet(Op1C->getBitWidth(), ShAmt))) { + APInt::getHighBitsSet(Op1C->getBitWidth(), ShAmt), + 0, &I)) { I.setHasNoUnsignedWrap(); return &I; } // If the shifted out value is all signbits, this is a NSW shift. if (!I.hasNoSignedWrap() && - ComputeNumSignBits(I.getOperand(0)) > ShAmt) { + ComputeNumSignBits(I.getOperand(0), 0, &I) > ShAmt) { I.setHasNoSignedWrap(); return &I; } @@ -730,8 +735,8 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), - I.isExact(), DL)) + if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), + DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); if (Instruction *R = commonShiftTransforms(I)) @@ -760,7 +765,8 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { // If the shifted-out value is known-zero, then this is an exact shift. if (!I.isExact() && - MaskedValueIsZero(Op0,APInt::getLowBitsSet(Op1C->getBitWidth(),ShAmt))){ + MaskedValueIsZero(Op0, APInt::getLowBitsSet(Op1C->getBitWidth(), ShAmt), + 0, &I)){ I.setIsExact(); return &I; } @@ -773,8 +779,8 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), - I.isExact(), DL)) + if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), + DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); if (Instruction *R = commonShiftTransforms(I)) @@ -804,7 +810,8 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { // If the shifted-out value is known-zero, then this is an exact shift. if (!I.isExact() && - MaskedValueIsZero(Op0,APInt::getLowBitsSet(Op1C->getBitWidth(),ShAmt))){ + MaskedValueIsZero(Op0,APInt::getLowBitsSet(Op1C->getBitWidth(),ShAmt), + 0, &I)){ I.setIsExact(); return &I; } @@ -812,7 +819,8 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { // See if we can turn a signed shr into an unsigned shr. 
if (MaskedValueIsZero(Op0, - APInt::getSignBit(I.getType()->getScalarSizeInBits()))) + APInt::getSignBit(I.getType()->getScalarSizeInBits()), + 0, &I)) return BinaryOperator::CreateLShr(Op0, Op1); return nullptr; diff --git a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 1b42d3d504a3..ad6983abf83d 100644 --- a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -43,6 +43,20 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo, // This instruction is producing bits that are not demanded. Shrink the RHS. Demanded &= OpC->getValue(); I->setOperand(OpNo, ConstantInt::get(OpC->getType(), Demanded)); + + // If either 'nsw' or 'nuw' is set and the constant is negative, + // removing *any* bits from the constant could make overflow occur. + // Remove 'nsw' and 'nuw' from the instruction in this case. + if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(I)) { + assert(OBO->getOpcode() == Instruction::Add); + if (OBO->hasNoSignedWrap() || OBO->hasNoUnsignedWrap()) { + if (OpC->getValue().isNegative()) { + cast<BinaryOperator>(OBO)->setHasNoSignedWrap(false); + cast<BinaryOperator>(OBO)->setHasNoUnsignedWrap(false); + } + } + } + return true; } @@ -57,7 +71,7 @@ bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) { APInt DemandedMask(APInt::getAllOnesValue(BitWidth)); Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, - KnownZero, KnownOne, 0); + KnownZero, KnownOne, 0, &Inst); if (!V) return false; if (V == &Inst) return true; ReplaceInstUsesWith(Inst, V); @@ -71,7 +85,8 @@ bool InstCombiner::SimplifyDemandedBits(Use &U, APInt DemandedMask, APInt &KnownZero, APInt &KnownOne, unsigned Depth) { Value *NewVal = SimplifyDemandedUseBits(U.get(), DemandedMask, - KnownZero, KnownOne, Depth); + KnownZero, KnownOne, Depth, + dyn_cast<Instruction>(U.getUser())); if (!NewVal) return false; U = NewVal; return true; @@ -101,7 +116,8 @@ bool InstCombiner::SimplifyDemandedBits(Use &U, APInt DemandedMask, /// in the context where the specified bits are demanded, but not for all users. Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt &KnownZero, APInt &KnownOne, - unsigned Depth) { + unsigned Depth, + Instruction *CxtI) { assert(V != nullptr && "Null pointer of Value???"); assert(Depth <= 6 && "Limit Search Depth"); uint32_t BitWidth = DemandedMask.getBitWidth(); @@ -144,7 +160,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, Instruction *I = dyn_cast<Instruction>(V); if (!I) { - computeKnownBits(V, KnownZero, KnownOne, Depth); + computeKnownBits(V, KnownZero, KnownOne, Depth, CxtI); return nullptr; // Only analyze instructions. } @@ -158,8 +174,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // this instruction has a simpler value in that context. if (I->getOpcode() == Instruction::And) { // If either the LHS or the RHS are Zero, the result is zero. - computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1); - computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1); + computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1, + CxtI); + computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1, + CxtI); // If all of the demanded bits are known 1 on one side, return the other. 
// These bits cannot contribute to the result of the 'and' in this @@ -180,8 +198,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // only bits from X or Y are demanded. // If either the LHS or the RHS are One, the result is One. - computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1); - computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1); + computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1, + CxtI); + computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1, + CxtI); // If all of the demanded bits are known zero on one side, return the // other. These bits cannot contribute to the result of the 'or' in this @@ -205,8 +225,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // We can simplify (X^Y) -> X or Y in the user's context if we know that // only bits from X or Y are demanded. - computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1); - computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1); + computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1, + CxtI); + computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1, + CxtI); // If all of the demanded bits are known zero on one side, return the // other. @@ -217,7 +239,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } // Compute the KnownZero/KnownOne bits to simplify things downstream. - computeKnownBits(I, KnownZero, KnownOne, Depth); + computeKnownBits(I, KnownZero, KnownOne, Depth, CxtI); return nullptr; } @@ -230,7 +252,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, switch (I->getOpcode()) { default: - computeKnownBits(I, KnownZero, KnownOne, Depth); + computeKnownBits(I, KnownZero, KnownOne, Depth, CxtI); break; case Instruction::And: // If either the LHS or the RHS are Zero, the result is zero. @@ -242,6 +264,12 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?"); assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?"); + // If the client is only demanding bits that we know, return the known + // constant. + if ((DemandedMask & ((RHSKnownZero | LHSKnownZero)| + (RHSKnownOne & LHSKnownOne))) == DemandedMask) + return Constant::getIntegerValue(VTy, RHSKnownOne & LHSKnownOne); + // If all of the demanded bits are known 1 on one side, return the other. // These bits cannot contribute to the result of the 'and'. if ((DemandedMask & ~LHSKnownZero & RHSKnownOne) == @@ -274,6 +302,12 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?"); assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?"); + // If the client is only demanding bits that we know, return the known + // constant. + if ((DemandedMask & ((RHSKnownZero & LHSKnownZero)| + (RHSKnownOne | LHSKnownOne))) == DemandedMask) + return Constant::getIntegerValue(VTy, RHSKnownOne | LHSKnownOne); + // If all of the demanded bits are known zero on one side, return the other. // These bits cannot contribute to the result of the 'or'. 
if ((DemandedMask & ~LHSKnownOne & RHSKnownZero) == @@ -310,6 +344,18 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?"); assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?"); + // Output known-0 bits are known if clear or set in both the LHS & RHS. + APInt IKnownZero = (RHSKnownZero & LHSKnownZero) | + (RHSKnownOne & LHSKnownOne); + // Output known-1 are known to be set if set in only one of the LHS, RHS. + APInt IKnownOne = (RHSKnownZero & LHSKnownOne) | + (RHSKnownOne & LHSKnownZero); + + // If the client is only demanding bits that we know, return the known + // constant. + if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask) + return Constant::getIntegerValue(VTy, IKnownOne); + // If all of the demanded bits are known zero on one side, return the other. // These bits cannot contribute to the result of the 'xor'. if ((DemandedMask & RHSKnownZero) == DemandedMask) @@ -581,7 +627,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // Otherwise just hand the sub off to computeKnownBits to fill in // the known zeros and ones. - computeKnownBits(V, KnownZero, KnownOne, Depth); + computeKnownBits(V, KnownZero, KnownOne, Depth, CxtI); // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known // zero. @@ -752,7 +798,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // remainder is zero. if (DemandedMask.isNegative() && KnownZero.isNonNegative()) { APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); - computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1); + computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1, + CxtI); // If it's known zero, our sign bit is also zero. 
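The xor known-bits formulas added above can be validated by brute force: model a partially known 8-bit value as a (value, mask-of-known-positions) pair and check that every bit the formulas claim is forced indeed matches the actual xor result (illustrative C++):

    #include <cassert>
    #include <cstdint>

    int main() {
      // For positions set in 'known', the bit of 'v' is known, so
      // KnownOne = v & known and KnownZero = ~v & known. Step sizes just
      // keep the search small; any strides work.
      for (unsigned v1 = 0; v1 < 256; ++v1)
        for (unsigned known1 = 0; known1 < 256; known1 += 37)
          for (unsigned v2 = 0; v2 < 256; v2 += 29)
            for (unsigned known2 = 0; known2 < 256; known2 += 41) {
              uint8_t LZ = ~v1 & known1, LO = v1 & known1;
              uint8_t RZ = ~v2 & known2, RO = v2 & known2;
              // The formulas from SimplifyDemandedUseBits' Xor case:
              uint8_t IKnownZero = (RZ & LZ) | (RO & LO);
              uint8_t IKnownOne  = (RZ & LO) | (RO & LZ);
              uint8_t x = v1 ^ v2;
              assert((x & IKnownOne) == IKnownOne); // claimed ones are one
              assert((x & IKnownZero) == 0);        // claimed zeros are zero
            }
      return 0;
    }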
if (LHSKnownZero.isNegative()) KnownZero.setBit(KnownZero.getBitWidth() - 1); @@ -814,7 +861,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return nullptr; } } - computeKnownBits(V, KnownZero, KnownOne, Depth); + computeKnownBits(V, KnownZero, KnownOne, Depth, CxtI); break; } diff --git a/lib/Transforms/InstCombine/InstCombineWorklist.h b/lib/Transforms/InstCombine/InstCombineWorklist.h index 1ab7db3a989f..8d857d0f8e00 100644 --- a/lib/Transforms/InstCombine/InstCombineWorklist.h +++ b/lib/Transforms/InstCombine/InstCombineWorklist.h @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#ifndef INSTCOMBINE_WORKLIST_H -#define INSTCOMBINE_WORKLIST_H +#ifndef LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINEWORKLIST_H +#define LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINEWORKLIST_H #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" diff --git a/lib/Transforms/InstCombine/InstructionCombining.cpp b/lib/Transforms/InstCombine/InstructionCombining.cpp index d3648e2d0505..a0c239a020c8 100644 --- a/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -39,12 +39,16 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringSwitch.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CFG.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CFG.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" @@ -68,11 +72,6 @@ STATISTIC(NumExpand, "Number of expansions"); STATISTIC(NumFactor , "Number of factorizations"); STATISTIC(NumReassoc , "Number of reassociations"); -static cl::opt<bool> UnsafeFPShrink("enable-double-float-shrink", cl::Hidden, - cl::init(false), - cl::desc("Enable unsafe double to float " - "shrinking for math lib calls")); - // Initialization Routines void llvm::initializeInstCombine(PassRegistry &Registry) { initializeInstCombinerPass(Registry); @@ -85,13 +84,18 @@ void LLVMInitializeInstCombine(LLVMPassRegistryRef R) { char InstCombiner::ID = 0; INITIALIZE_PASS_BEGIN(InstCombiner, "instcombine", "Combine redundant instructions", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_END(InstCombiner, "instcombine", "Combine redundant instructions", false, false) void InstCombiner::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesCFG(); + AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetLibraryInfo>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); } @@ -390,6 +394,25 @@ static bool RightDistributesOverLeft(Instruction::BinaryOps LOp, Instruction::BinaryOps ROp) { if (Instruction::isCommutative(ROp)) return LeftDistributesOverRight(ROp, LOp); + + switch (LOp) { + default: + return false; + // (X >> Z) & (Y >> Z) -> (X&Y) >> Z for all shifts. + // (X >> Z) | (Y >> Z) -> (X|Y) >> Z for all shifts. + // (X >> Z) ^ (Y >> Z) -> (X^Y) >> Z for all shifts. 
+ case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + switch (ROp) { + default: + return false; + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + return true; + } + } // TODO: It would be nice to handle division, aka "(X + Y)/Z = X/Z + Y/Z", // but this requires knowing that the addition does not overflow and other // such subtleties. @@ -411,26 +434,37 @@ static Value *getIdentityValue(Instruction::BinaryOps OpCode, Value *V) { } /// This function factors binary ops which can be combined using distributive -/// laws. This also factor SHL as MUL e.g. SHL(X, 2) ==> MUL(X, 4). +/// laws. This function tries to transform 'Op' based TopLevelOpcode to enable +/// factorization e.g for ADD(SHL(X , 2), MUL(X, 5)), When this function called +/// with TopLevelOpcode == Instruction::Add and Op = SHL(X, 2), transforms +/// SHL(X, 2) to MUL(X, 4) i.e. returns Instruction::Mul with LHS set to 'X' and +/// RHS to 4. static Instruction::BinaryOps -getBinOpsForFactorization(BinaryOperator *Op, Value *&LHS, Value *&RHS) { +getBinOpsForFactorization(Instruction::BinaryOps TopLevelOpcode, + BinaryOperator *Op, Value *&LHS, Value *&RHS) { if (!Op) return Instruction::BinaryOpsEnd; - if (Op->getOpcode() == Instruction::Shl) { - if (Constant *CST = dyn_cast<Constant>(Op->getOperand(1))) { - // The multiplier is really 1 << CST. - RHS = ConstantExpr::getShl(ConstantInt::get(Op->getType(), 1), CST); - LHS = Op->getOperand(0); - return Instruction::Mul; + LHS = Op->getOperand(0); + RHS = Op->getOperand(1); + + switch (TopLevelOpcode) { + default: + return Op->getOpcode(); + + case Instruction::Add: + case Instruction::Sub: + if (Op->getOpcode() == Instruction::Shl) { + if (Constant *CST = dyn_cast<Constant>(Op->getOperand(1))) { + // The multiplier is really 1 << CST. + RHS = ConstantExpr::getShl(ConstantInt::get(Op->getType(), 1), CST); + return Instruction::Mul; + } } + return Op->getOpcode(); } // TODO: We can add other conversions e.g. shr => div etc. - - LHS = Op->getOperand(0); - RHS = Op->getOperand(1); - return Op->getOpcode(); } /// This tries to simplify binary operations by factorizing out common terms @@ -529,8 +563,9 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) { // Factorization. Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr; - Instruction::BinaryOps LHSOpcode = getBinOpsForFactorization(Op0, A, B); - Instruction::BinaryOps RHSOpcode = getBinOpsForFactorization(Op1, C, D); + auto TopLevelOpcode = I.getOpcode(); + auto LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B); + auto RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D); // The instruction has the form "(A op' B) op (C op' D)". Try to factorize // a common term. @@ -552,7 +587,6 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) { return V; // Expansion. - Instruction::BinaryOps TopLevelOpcode = I.getOpcode(); if (Op0 && RightDistributesOverLeft(Op0->getOpcode(), TopLevelOpcode)) { // The instruction has the form "(A op' B) op C". See if expanding it out // to "(A op C) op' (B op C)" results in simplifications. @@ -765,13 +799,14 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) { // If the incoming non-constant value is in I's block, we will remove one // instruction, but insert another equivalent one, leading to infinite // instcombine. 
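Both new factorization hooks above are ordinary bit identities: same-amount shifts distribute over the bitwise operators, and a left shift by a constant is multiplication by a power of two. A standalone check (illustrative C++, not part of the patch):

    #include <cassert>

    int main() {
      // RightDistributesOverLeft's new cases: (X >> Z) op (Y >> Z) ==
      // (X op Y) >> Z for op in {&, |, ^}, and likewise for <<; checked
      // here on unsigned 8-bit values (the ashr case holds the same way
      // on signed values).
      for (unsigned x = 0; x < 256; ++x)
        for (unsigned y = 0; y < 256; y += 23)
          for (unsigned z = 0; z < 8; ++z) {
            assert(((x >> z) & (y >> z)) == ((x & y) >> z));
            assert(((x >> z) | (y >> z)) == ((x | y) >> z));
            assert(((x >> z) ^ (y >> z)) == ((x ^ y) >> z));
            assert((((x << z) & (y << z)) & 0xffu) ==
                   (((x & y) << z) & 0xffu));
          }
      // getBinOpsForFactorization's SHL-as-MUL view under add/sub:
      // X << C is X * (1 << C), so (X << 2) + (X * 5) factors as X * 9.
      for (unsigned x = 0; x < 256; ++x)
        assert((((x << 2) + x * 5) & 0xffu) == ((x * 9) & 0xffu));
      return 0;
    }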
- if (NonConstBB == I.getParent()) + if (isPotentiallyReachable(I.getParent(), NonConstBB, DT, + getAnalysisIfAvailable<LoopInfo>())) return nullptr; } // If there is exactly one non-constant value, we can insert a copy of the // operation in that block. However, if this is a critical edge, we would be - // inserting the computation one some other paths (e.g. inside a loop). Only + // inserting the computation on some other paths (e.g. inside a loop). Only // do this if the pred block is unconditionally branching into the phi block. if (NonConstBB != nullptr) { BranchInst *BI = dyn_cast<BranchInst>(NonConstBB->getTerminator()); @@ -1284,7 +1319,7 @@ Value *InstCombiner::SimplifyVectorOp(BinaryOperator &Inst) { Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { SmallVector<Value*, 8> Ops(GEP.op_begin(), GEP.op_end()); - if (Value *V = SimplifyGEPInst(Ops, DL)) + if (Value *V = SimplifyGEPInst(Ops, DL, TLI, DT, AC)) return ReplaceInstUsesWith(GEP, V); Value *PtrOp = GEP.getOperand(0); @@ -1478,19 +1513,50 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { GetElementPtrInst::Create(Src->getOperand(0), Indices, GEP.getName()); } - // Canonicalize (gep i8* X, -(ptrtoint Y)) to (sub (ptrtoint X), (ptrtoint Y)) - // The GEP pattern is emitted by the SCEV expander for certain kinds of - // pointer arithmetic. - if (DL && GEP.getNumIndices() == 1 && - match(GEP.getOperand(1), m_Neg(m_PtrToInt(m_Value())))) { + if (DL && GEP.getNumIndices() == 1) { unsigned AS = GEP.getPointerAddressSpace(); - if (GEP.getType() == Builder->getInt8PtrTy(AS) && - GEP.getOperand(1)->getType()->getScalarSizeInBits() == + if (GEP.getOperand(1)->getType()->getScalarSizeInBits() == DL->getPointerSizeInBits(AS)) { - Operator *Index = cast<Operator>(GEP.getOperand(1)); - Value *PtrToInt = Builder->CreatePtrToInt(PtrOp, Index->getType()); - Value *NewSub = Builder->CreateSub(PtrToInt, Index->getOperand(1)); - return CastInst::Create(Instruction::IntToPtr, NewSub, GEP.getType()); + Type *PtrTy = GEP.getPointerOperandType(); + Type *Ty = PtrTy->getPointerElementType(); + uint64_t TyAllocSize = DL->getTypeAllocSize(Ty); + + bool Matched = false; + uint64_t C; + Value *V = nullptr; + if (TyAllocSize == 1) { + V = GEP.getOperand(1); + Matched = true; + } else if (match(GEP.getOperand(1), + m_AShr(m_Value(V), m_ConstantInt(C)))) { + if (TyAllocSize == 1ULL << C) + Matched = true; + } else if (match(GEP.getOperand(1), + m_SDiv(m_Value(V), m_ConstantInt(C)))) { + if (TyAllocSize == C) + Matched = true; + } + + if (Matched) { + // Canonicalize (gep i8* X, -(ptrtoint Y)) + // to (inttoptr (sub (ptrtoint X), (ptrtoint Y))) + // The GEP pattern is emitted by the SCEV expander for certain kinds of + // pointer arithmetic. 
+ if (match(V, m_Neg(m_PtrToInt(m_Value())))) { + Operator *Index = cast<Operator>(V); + Value *PtrToInt = Builder->CreatePtrToInt(PtrOp, Index->getType()); + Value *NewSub = Builder->CreateSub(PtrToInt, Index->getOperand(1)); + return CastInst::Create(Instruction::IntToPtr, NewSub, GEP.getType()); + } + // Canonicalize (gep i8* X, (ptrtoint Y)-(ptrtoint X)) + // to (bitcast Y) + Value *Y; + if (match(V, m_Sub(m_PtrToInt(m_Value(Y)), + m_PtrToInt(m_Specific(GEP.getOperand(0)))))) { + return CastInst::CreatePointerBitCastOrAddrSpaceCast(Y, + GEP.getType()); + } + } } } @@ -1667,6 +1733,18 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (!DL) return nullptr; + // addrspacecast between types is canonicalized as a bitcast, then an + // addrspacecast. To take advantage of the below bitcast + struct GEP, look + // through the addrspacecast. + if (AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(PtrOp)) { + // X = bitcast A addrspace(1)* to B addrspace(1)* + // Y = addrspacecast A addrspace(1)* to B addrspace(2)* + // Z = gep Y, <...constant indices...> + // Into an addrspacecasted GEP of the struct. + if (BitCastInst *BC = dyn_cast<BitCastInst>(ASC->getOperand(0))) + PtrOp = BC; + } + /// See if we can simplify: /// X = bitcast A* to B* /// Y = gep X, <...constant indices...> @@ -1675,11 +1753,10 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (BitCastInst *BCI = dyn_cast<BitCastInst>(PtrOp)) { Value *Operand = BCI->getOperand(0); PointerType *OpType = cast<PointerType>(Operand->getType()); - unsigned OffsetBits = DL->getPointerTypeSizeInBits(OpType); + unsigned OffsetBits = DL->getPointerTypeSizeInBits(GEP.getType()); APInt Offset(OffsetBits, 0); if (!isa<BitCastInst>(Operand) && - GEP.accumulateConstantOffset(*DL, Offset) && - StrippedPtrTy->getAddressSpace() == GEP.getPointerAddressSpace()) { + GEP.accumulateConstantOffset(*DL, Offset)) { // If this GEP instruction doesn't move the pointer, just replace the GEP // with a bitcast of the real input to the dest type. @@ -1697,6 +1774,9 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { return &GEP; } } + + if (Operand->getType()->getPointerAddressSpace() != GEP.getAddressSpace()) + return new AddrSpaceCastInst(Operand, GEP.getType()); return new BitCastInst(Operand, GEP.getType()); } @@ -1712,6 +1792,9 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (NGEP->getType() == GEP.getType()) return ReplaceInstUsesWith(GEP, NGEP); NGEP->takeName(&GEP); + + if (NGEP->getType()->getPointerAddressSpace() != GEP.getAddressSpace()) + return new AddrSpaceCastInst(NGEP, GEP.getType()); return new BitCastInst(NGEP, GEP.getType()); } } @@ -1922,7 +2005,25 @@ Instruction *InstCombiner::visitFree(CallInst &FI) { return nullptr; } +Instruction *InstCombiner::visitReturnInst(ReturnInst &RI) { + if (RI.getNumOperands() == 0) // ret void + return nullptr; + + Value *ResultOp = RI.getOperand(0); + Type *VTy = ResultOp->getType(); + if (!VTy->isIntegerTy()) + return nullptr; + + // There might be assume intrinsics dominating this return that completely + // determine the value. If so, constant fold it. 
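The i8-GEP folds above are plain address arithmetic once gep is read as integer addition; a sketch over 64-bit integers with arbitrary hypothetical addresses (illustrative C++):

    #include <cassert>
    #include <cstdint>

    int main() {
      // Model "gep i8* X, Idx" as XI + Idx, with XI = ptrtoint X and
      // YI = ptrtoint Y (hypothetical values).
      uint64_t XI = 0x1000, YI = 0x2040;

      // gep i8* X, (ptrtoint Y)-(ptrtoint X)  ==>  bitcast Y:
      // stepping X by the byte distance to Y lands exactly on Y.
      assert(XI + (YI - XI) == YI);

      // gep i8* X, -(ptrtoint Y)  ==>  inttoptr(sub(ptrtoint X, ptrtoint Y))
      assert(XI + (0 - YI) == XI - YI);

      // The new matching also looks through scaled indices: for a pointee
      // of size 2^C (TyAllocSize == 1 << C), an index of (Delta ashr C)
      // advances Delta bytes, so the same folds apply.
      const unsigned C = 3;     // e.g. i64 elements, TyAllocSize == 8
      uint64_t Delta = YI - XI; // 0x1040, a multiple of 8 here
      assert(XI + ((Delta >> C) << C) == YI);
      return 0;
    }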
+ unsigned BitWidth = VTy->getPrimitiveSizeInBits(); + APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + computeKnownBits(ResultOp, KnownZero, KnownOne, 0, &RI); + if ((KnownZero|KnownOne).isAllOnesValue()) + RI.setOperand(0, Constant::getIntegerValue(VTy, KnownOne)); + return nullptr; +} Instruction *InstCombiner::visitBranchInst(BranchInst &BI) { // Change br (not X), label True, label False to: br X, label False, True @@ -1974,6 +2075,40 @@ Instruction *InstCombiner::visitBranchInst(BranchInst &BI) { Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) { Value *Cond = SI.getCondition(); + unsigned BitWidth = cast<IntegerType>(Cond->getType())->getBitWidth(); + APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + computeKnownBits(Cond, KnownZero, KnownOne); + unsigned LeadingKnownZeros = KnownZero.countLeadingOnes(); + unsigned LeadingKnownOnes = KnownOne.countLeadingOnes(); + + // Compute the number of leading bits we can ignore. + for (auto &C : SI.cases()) { + LeadingKnownZeros = std::min( + LeadingKnownZeros, C.getCaseValue()->getValue().countLeadingZeros()); + LeadingKnownOnes = std::min( + LeadingKnownOnes, C.getCaseValue()->getValue().countLeadingOnes()); + } + + unsigned NewWidth = BitWidth - std::max(LeadingKnownZeros, LeadingKnownOnes); + + // Truncate the condition operand if the new type is equal to or larger than + // the largest legal integer type. We need to be conservative here since + // x86 generates redundant zero-extenstion instructions if the operand is + // truncated to i8 or i16. + bool TruncCond = false; + if (DL && BitWidth > NewWidth && + NewWidth >= DL->getLargestLegalIntTypeSize()) { + TruncCond = true; + IntegerType *Ty = IntegerType::get(SI.getContext(), NewWidth); + Builder->SetInsertPoint(&SI); + Value *NewCond = Builder->CreateTrunc(SI.getCondition(), Ty, "trunc"); + SI.setCondition(NewCond); + + for (auto &C : SI.cases()) + static_cast<SwitchInst::CaseIt *>(&C)->setValue(ConstantInt::get( + SI.getContext(), C.getCaseValue()->getValue().trunc(NewWidth))); + } + if (Instruction *I = dyn_cast<Instruction>(Cond)) { if (I->getOpcode() == Instruction::Add) if (ConstantInt *AddRHS = dyn_cast<ConstantInt>(I->getOperand(1))) { @@ -1982,8 +2117,12 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) { for (SwitchInst::CaseIt i = SI.case_begin(), e = SI.case_end(); i != e; ++i) { ConstantInt* CaseVal = i.getCaseValue(); - Constant* NewCaseVal = ConstantExpr::getSub(cast<Constant>(CaseVal), - AddRHS); + Constant *LHS = CaseVal; + if (TruncCond) + LHS = LeadingKnownZeros + ? ConstantExpr::getZExt(CaseVal, Cond->getType()) + : ConstantExpr::getSExt(CaseVal, Cond->getType()); + Constant* NewCaseVal = ConstantExpr::getSub(LHS, AddRHS); assert(isa<ConstantInt>(NewCaseVal) && "Result of expression should be constant"); i.setValue(cast<ConstantInt>(NewCaseVal)); @@ -1993,7 +2132,8 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) { return &SI; } } - return nullptr; + + return TruncCond ? &SI : nullptr; } Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { @@ -2212,7 +2352,7 @@ Instruction *InstCombiner::visitLandingPadInst(LandingPadInst &LI) { // If we already saw this clause, there is no point in having a second // copy of it. - if (AlreadyCaught.insert(TypeInfo)) { + if (AlreadyCaught.insert(TypeInfo).second) { // This catch clause was not already seen. 
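The switch-narrowing logic above picks the new width from the leading bits shared by all case values (and by the condition's known bits); a small standalone model of that computation (illustrative C++, values arbitrary):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    // Count leading zero bits of a 32-bit value (portable helper).
    static unsigned countLeadingZeros32(uint32_t v) {
      unsigned n = 0;
      for (uint32_t bit = 1u << 31; bit && !(v & bit); bit >>= 1)
        ++n;
      return n;
    }

    int main() {
      // Mirroring visitSwitchInst: if every case value shares
      // LeadingKnownZeros leading zeros, comparisons can be done after
      // truncating to BitWidth - LeadingKnownZeros bits.
      const uint32_t Cases[] = {3, 9, 14};
      unsigned Leading = 32;
      for (uint32_t C : Cases)
        Leading = std::min(Leading, countLeadingZeros32(C));
      unsigned NewWidth = 32 - Leading;
      assert(NewWidth == 4); // 4 bits suffice for these cases
      // Truncated comparison agrees with the full-width one whenever the
      // condition's high bits are likewise known zero.
      uint32_t Mask = (1u << NewWidth) - 1;
      for (uint32_t Cond = 0; Cond < 16; ++Cond)
        for (uint32_t C : Cases)
          assert((Cond == C) == ((Cond & Mask) == (C & Mask)));
      return 0;
    }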
NewClauses.push_back(CatchClause); } else { @@ -2294,7 +2434,7 @@ Instruction *InstCombiner::visitLandingPadInst(LandingPadInst &LI) { continue; // There is no point in having multiple copies of the same typeinfo in // a filter, so only add it if we didn't already. - if (SeenInFilter.insert(TypeInfo)) + if (SeenInFilter.insert(TypeInfo).second) NewFilterElts.push_back(cast<Constant>(Elt)); } // A filter containing a catch-all cannot match anything by definition. @@ -2531,7 +2671,7 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) { /// whose condition is a known constant, we only visit the reachable successors. /// static bool AddReachableCodeToWorklist(BasicBlock *BB, - SmallPtrSet<BasicBlock*, 64> &Visited, + SmallPtrSetImpl<BasicBlock*> &Visited, InstCombiner &IC, const DataLayout *DL, const TargetLibraryInfo *TLI) { @@ -2546,7 +2686,8 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, BB = Worklist.pop_back_val(); // We have now visited this block! If we've already been here, ignore it. - if (!Visited.insert(BB)) continue; + if (!Visited.insert(BB).second) + continue; for (BasicBlock::iterator BBI = BB->begin(), E = BB->end(); BBI != E; ) { Instruction *Inst = BBI++; @@ -2807,13 +2948,13 @@ bool InstCombiner::DoOneIteration(Function &F, unsigned Iteration) { } namespace { -class InstCombinerLibCallSimplifier : public LibCallSimplifier { +class InstCombinerLibCallSimplifier final : public LibCallSimplifier { InstCombiner *IC; public: InstCombinerLibCallSimplifier(const DataLayout *DL, const TargetLibraryInfo *TLI, InstCombiner *IC) - : LibCallSimplifier(DL, TLI, UnsafeFPShrink) { + : LibCallSimplifier(DL, TLI) { this->IC = IC; } @@ -2829,18 +2970,20 @@ bool InstCombiner::runOnFunction(Function &F) { if (skipOptnoneFunction(F)) return false; + AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); DL = DLP ? &DLP->getDataLayout() : nullptr; + DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); TLI = &getAnalysis<TargetLibraryInfo>(); + // Minimizing size? MinimizeSize = F.getAttributes().hasAttribute(AttributeSet::FunctionIndex, Attribute::MinSize); /// Builder - This is an IRBuilder that automatically inserts new /// instructions into the worklist when they are created. 
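Many hunks in this change adapt call sites from "if (Set.insert(X))" to "if (Set.insert(X).second)", reflecting SmallPtrSet::insert now returning an iterator/bool pair. A minimal sketch of the new convention (requires an LLVM build to compile against):

    // Compile against an LLVM installation, e.g.:
    //   clang++ demo.cpp $(llvm-config --cxxflags --ldflags --libs support)
    #include "llvm/ADT/SmallPtrSet.h"
    #include <cassert>

    int main() {
      int A, B;
      llvm::SmallPtrSet<int *, 4> Visited;
      // insert() returns std::pair<iterator, bool>; .second reports
      // whether the element was newly inserted -- the dedup test the
      // rewritten call sites rely on.
      assert(Visited.insert(&A).second);  // first time: inserted
      assert(!Visited.insert(&A).second); // second time: already present
      assert(Visited.insert(&B).second);
      return 0;
    }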
- IRBuilder<true, TargetFolder, InstCombineIRInserter> - TheBuilder(F.getContext(), TargetFolder(DL), - InstCombineIRInserter(Worklist)); + IRBuilder<true, TargetFolder, InstCombineIRInserter> TheBuilder( + F.getContext(), TargetFolder(DL), InstCombineIRInserter(Worklist, AC)); Builder = &TheBuilder; InstCombinerLibCallSimplifier TheSimplifier(DL, TLI, this); diff --git a/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/lib/Transforms/Instrumentation/AddressSanitizer.cpp index 124ffe2f8f87..745c85a98e2f 100644 --- a/lib/Transforms/Instrumentation/AddressSanitizer.cpp +++ b/lib/Transforms/Instrumentation/AddressSanitizer.cpp @@ -27,6 +27,7 @@ #include "llvm/IR/CallSite.h" #include "llvm/IR/DIBuilder.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InlineAsm.h" @@ -36,10 +37,13 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" +#include "llvm/MC/MCSectionMachO.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/DataTypes.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Endian.h" +#include "llvm/Support/SwapByteOrder.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/ASanStackFrameLayout.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" @@ -59,9 +63,11 @@ static const uint64_t kIOSShadowOffset32 = 1ULL << 30; static const uint64_t kDefaultShadowOffset64 = 1ULL << 44; static const uint64_t kSmallX86_64ShadowOffset = 0x7FFF8000; // < 2G. static const uint64_t kPPC64_ShadowOffset64 = 1ULL << 41; -static const uint64_t kMIPS32_ShadowOffset32 = 0x0aaa8000; +static const uint64_t kMIPS32_ShadowOffset32 = 0x0aaa0000; +static const uint64_t kMIPS64_ShadowOffset64 = 1ULL << 36; static const uint64_t kFreeBSD_ShadowOffset32 = 1ULL << 30; static const uint64_t kFreeBSD_ShadowOffset64 = 1ULL << 46; +static const uint64_t kWindowsShadowOffset32 = 1ULL << 30; static const size_t kMinStackMallocSize = 1 << 6; // 64B static const size_t kMaxStackMallocSize = 1 << 16; // 64K @@ -70,7 +76,7 @@ static const uintptr_t kRetiredStackFrameMagic = 0x45E0360E; static const char *const kAsanModuleCtorName = "asan.module_ctor"; static const char *const kAsanModuleDtorName = "asan.module_dtor"; -static const int kAsanCtorAndDtorPriority = 1; +static const uint64_t kAsanCtorAndDtorPriority = 1; static const char *const kAsanReportErrorTemplate = "__asan_report_"; static const char *const kAsanReportLoadN = "__asan_report_load_n"; static const char *const kAsanReportStoreN = "__asan_report_store_n"; @@ -79,9 +85,7 @@ static const char *const kAsanUnregisterGlobalsName = "__asan_unregister_globals"; static const char *const kAsanPoisonGlobalsName = "__asan_before_dynamic_init"; static const char *const kAsanUnpoisonGlobalsName = "__asan_after_dynamic_init"; -static const char *const kAsanInitName = "__asan_init_v4"; -static const char *const kAsanCovModuleInitName = "__sanitizer_cov_module_init"; -static const char *const kAsanCovName = "__sanitizer_cov"; +static const char *const kAsanInitName = "__asan_init_v5"; static const char *const kAsanPtrCmp = "__sanitizer_ptr_cmp"; static const char *const kAsanPtrSub = "__sanitizer_ptr_sub"; static const char *const kAsanHandleNoReturnName = "__asan_handle_no_return"; @@ -89,6 +93,7 @@ static const int kMaxAsanStackMallocSizeClass = 10; static const char *const kAsanStackMallocNameTemplate = "__asan_stack_malloc_"; static const char *const 
kAsanStackFreeNameTemplate = "__asan_stack_free_"; static const char *const kAsanGenPrefix = "__asan_gen_"; +static const char *const kSanCovGenPrefix = "__sancov_gen_"; static const char *const kAsanPoisonStackMemoryName = "__asan_poison_stack_memory"; static const char *const kAsanUnpoisonStackMemoryName = @@ -104,6 +109,12 @@ static const int kAsanStackAfterReturnMagic = 0xf5; // Accesses sizes are powers of two: 1, 2, 4, 8, 16. static const size_t kNumberOfAccessSizes = 5; +static const unsigned kAllocaRzSize = 32; +static const unsigned kAsanAllocaLeftMagic = 0xcacacacaU; +static const unsigned kAsanAllocaRightMagic = 0xcbcbcbcbU; +static const unsigned kAsanAllocaPartialVal1 = 0xcbcbcb00U; +static const unsigned kAsanAllocaPartialVal2 = 0x000000cbU; + // Command-line flags. // This flag may need to be replaced with -f[no-]asan-reads. @@ -133,13 +144,6 @@ static cl::opt<bool> ClUseAfterReturn("asan-use-after-return", // This flag may need to be replaced with -f[no]asan-globals. static cl::opt<bool> ClGlobals("asan-globals", cl::desc("Handle global objects"), cl::Hidden, cl::init(true)); -static cl::opt<int> ClCoverage("asan-coverage", - cl::desc("ASan coverage. 0: none, 1: entry block, 2: all blocks"), - cl::Hidden, cl::init(false)); -static cl::opt<int> ClCoverageBlockThreshold("asan-coverage-block-threshold", - cl::desc("Add coverage instrumentation only to the entry block if there " - "are more than this number of blocks."), - cl::Hidden, cl::init(1500)); static cl::opt<bool> ClInitializers("asan-initialization-order", cl::desc("Handle C++ initializer order"), cl::Hidden, cl::init(true)); static cl::opt<bool> ClInvalidPointerPairs("asan-detect-invalid-pointer-pair", @@ -158,19 +162,8 @@ static cl::opt<std::string> ClMemoryAccessCallbackPrefix( "asan-memory-access-callback-prefix", cl::desc("Prefix for memory access callbacks"), cl::Hidden, cl::init("__asan_")); - -// This is an experimental feature that will allow to choose between -// instrumented and non-instrumented code at link-time. -// If this option is on, just before instrumenting a function we create its -// clone; if the function is not changed by asan the clone is deleted. -// If we end up with a clone, we put the instrumented function into a section -// called "ASAN" and the uninstrumented function into a section called "NOASAN". -// -// This is still a prototype, we need to figure out a way to keep two copies of -// a function so that the linker can easily choose one of them. -static cl::opt<bool> ClKeepUninstrumented("asan-keep-uninstrumented-functions", - cl::desc("Keep uninstrumented copies of functions"), - cl::Hidden, cl::init(false)); +static cl::opt<bool> ClInstrumentAllocas("asan-instrument-allocas", + cl::desc("instrument dynamic allocas"), cl::Hidden, cl::init(false)); // These flags allow to change the shadow mapping. // The shadow mapping looks like @@ -192,6 +185,11 @@ static cl::opt<bool> ClCheckLifetime("asan-check-lifetime", cl::desc("Use llvm.lifetime intrinsics to insert extra checks"), cl::Hidden, cl::init(false)); +static cl::opt<bool> ClDynamicAllocaStack( + "asan-stack-dynamic-alloca", + cl::desc("Use dynamic alloca to represent stack variables"), cl::Hidden, + cl::init(false)); + // Debug flags. 
static cl::opt<int> ClDebug("asan-debug", cl::desc("debug"), cl::Hidden, cl::init(0)); @@ -206,21 +204,44 @@ static cl::opt<int> ClDebugMax("asan-debug-max", cl::desc("Debug man inst"), STATISTIC(NumInstrumentedReads, "Number of instrumented reads"); STATISTIC(NumInstrumentedWrites, "Number of instrumented writes"); +STATISTIC(NumInstrumentedDynamicAllocas, + "Number of instrumented dynamic allocas"); STATISTIC(NumOptimizedAccessesToGlobalArray, "Number of optimized accesses to global arrays"); STATISTIC(NumOptimizedAccessesToGlobalVar, "Number of optimized accesses to global vars"); namespace { +/// Frontend-provided metadata for source location. +struct LocationMetadata { + StringRef Filename; + int LineNo; + int ColumnNo; + + LocationMetadata() : Filename(), LineNo(0), ColumnNo(0) {} + + bool empty() const { return Filename.empty(); } + + void parse(MDNode *MDN) { + assert(MDN->getNumOperands() == 3); + MDString *MDFilename = cast<MDString>(MDN->getOperand(0)); + Filename = MDFilename->getString(); + LineNo = + mdconst::extract<ConstantInt>(MDN->getOperand(1))->getLimitedValue(); + ColumnNo = + mdconst::extract<ConstantInt>(MDN->getOperand(2))->getLimitedValue(); + } +}; + /// Frontend-provided metadata for global variables. class GlobalsMetadata { public: struct Entry { Entry() - : SourceLoc(nullptr), Name(nullptr), IsDynInit(false), + : SourceLoc(), Name(), IsDynInit(false), IsBlacklisted(false) {} - GlobalVariable *SourceLoc; - GlobalVariable *Name; + LocationMetadata SourceLoc; + StringRef Name; bool IsDynInit; bool IsBlacklisted; }; @@ -236,27 +257,22 @@ class GlobalsMetadata { for (auto MDN : Globals->operands()) { // Metadata node contains the global and the fields of "Entry". assert(MDN->getNumOperands() == 5); - Value *V = MDN->getOperand(0); + auto *GV = mdconst::extract_or_null<GlobalVariable>(MDN->getOperand(0)); // The optimizer may optimize away a global entirely. - if (!V) + if (!GV) continue; - GlobalVariable *GV = cast<GlobalVariable>(V); // We can already have an entry for GV if it was merged with another // global. Entry &E = Entries[GV]; - if (Value *Loc = MDN->getOperand(1)) { - GlobalVariable *GVLoc = cast<GlobalVariable>(Loc); - E.SourceLoc = GVLoc; - addSourceLocationGlobal(GVLoc); - } - if (Value *Name = MDN->getOperand(2)) { - GlobalVariable *GVName = cast<GlobalVariable>(Name); - E.Name = GVName; - InstrumentationGlobals.insert(GVName); - } - ConstantInt *IsDynInit = cast<ConstantInt>(MDN->getOperand(3)); + if (auto *Loc = cast_or_null<MDNode>(MDN->getOperand(1))) + E.SourceLoc.parse(Loc); + if (auto *Name = cast_or_null<MDString>(MDN->getOperand(2))) + E.Name = Name->getString(); + ConstantInt *IsDynInit = + mdconst::extract<ConstantInt>(MDN->getOperand(3)); E.IsDynInit |= IsDynInit->isOne(); - ConstantInt *IsBlacklisted = cast<ConstantInt>(MDN->getOperand(4)); + ConstantInt *IsBlacklisted = + mdconst::extract<ConstantInt>(MDN->getOperand(4)); E.IsBlacklisted |= IsBlacklisted->isOne(); } } @@ -267,31 +283,9 @@ class GlobalsMetadata { return (Pos != Entries.end()) ? Pos->second : Entry(); } - /// Check if the global was generated by the instrumentation - /// (we don't want to instrument it again in this case). - bool isInstrumentationGlobal(GlobalVariable *G) const { - return InstrumentationGlobals.count(G); - } - private: bool inited_; DenseMap<GlobalVariable*, Entry> Entries; - // Globals generated by the frontend instrumentation. 
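For reference, a hedged sketch of building the three-operand source-location node that LocationMetadata::parse expects (an MDString filename plus two i32 constants); the API names follow the post-metadata-split headers this patch already uses, and the file/line/column values are hypothetical:

    #include "llvm/ADT/StringRef.h"
    #include "llvm/IR/Constants.h"
    #include "llvm/IR/LLVMContext.h"
    #include "llvm/IR/Metadata.h"
    #include "llvm/IR/Type.h"
    using namespace llvm;

    // Builds !{!"file.cc", i32 Line, i32 Col}, the shape parse() unpacks
    // with cast<MDString> and mdconst::extract<ConstantInt>.
    static MDNode *makeSourceLocMD(LLVMContext &C, StringRef File, int Line,
                                   int Col) {
      Type *Int32 = Type::getInt32Ty(C);
      Metadata *Ops[] = {
          MDString::get(C, File),
          ConstantAsMetadata::get(ConstantInt::get(Int32, Line)),
          ConstantAsMetadata::get(ConstantInt::get(Int32, Col))};
      return MDNode::get(C, Ops);
    }

    int main() {
      LLVMContext Ctx;
      MDNode *Loc = makeSourceLocMD(Ctx, "a.cc", 42, 7);
      (void)Loc;
      return 0;
    }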
- DenseSet<GlobalVariable*> InstrumentationGlobals; - - void addSourceLocationGlobal(GlobalVariable *SourceLocGV) { - // Source location global is a struct with layout: - // { - // filename, - // i32 line_number, - // i32 column_number, - // } - InstrumentationGlobals.insert(SourceLocGV); - ConstantStruct *Contents = - cast<ConstantStruct>(SourceLocGV->getInitializer()); - GlobalVariable *FilenameGV = cast<GlobalVariable>(Contents->getOperand(0)); - InstrumentationGlobals.insert(FilenameGV); - } }; /// This struct defines the shadow mapping using the rule: @@ -302,17 +296,19 @@ struct ShadowMapping { bool OrShadowOffset; }; -static ShadowMapping getShadowMapping(const Module &M, int LongSize) { - llvm::Triple TargetTriple(M.getTargetTriple()); +static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize) { bool IsAndroid = TargetTriple.getEnvironment() == llvm::Triple::Android; - bool IsIOS = TargetTriple.getOS() == llvm::Triple::IOS; - bool IsFreeBSD = TargetTriple.getOS() == llvm::Triple::FreeBSD; - bool IsLinux = TargetTriple.getOS() == llvm::Triple::Linux; + bool IsIOS = TargetTriple.isiOS(); + bool IsFreeBSD = TargetTriple.isOSFreeBSD(); + bool IsLinux = TargetTriple.isOSLinux(); bool IsPPC64 = TargetTriple.getArch() == llvm::Triple::ppc64 || TargetTriple.getArch() == llvm::Triple::ppc64le; bool IsX86_64 = TargetTriple.getArch() == llvm::Triple::x86_64; bool IsMIPS32 = TargetTriple.getArch() == llvm::Triple::mips || TargetTriple.getArch() == llvm::Triple::mipsel; + bool IsMIPS64 = TargetTriple.getArch() == llvm::Triple::mips64 || + TargetTriple.getArch() == llvm::Triple::mips64el; + bool IsWindows = TargetTriple.isOSWindows(); ShadowMapping Mapping; @@ -325,6 +321,8 @@ static ShadowMapping getShadowMapping(const Module &M, int LongSize) { Mapping.Offset = kFreeBSD_ShadowOffset32; else if (IsIOS) Mapping.Offset = kIOSShadowOffset32; + else if (IsWindows) + Mapping.Offset = kWindowsShadowOffset32; else Mapping.Offset = kDefaultShadowOffset32; } else { // LongSize == 64 @@ -334,6 +332,8 @@ static ShadowMapping getShadowMapping(const Module &M, int LongSize) { Mapping.Offset = kFreeBSD_ShadowOffset64; else if (IsLinux && IsX86_64) Mapping.Offset = kSmallX86_64ShadowOffset; + else if (IsMIPS64) + Mapping.Offset = kMIPS64_ShadowOffset64; else Mapping.Offset = kDefaultShadowOffset64; } @@ -359,10 +359,15 @@ static size_t RedzoneSizeForScale(int MappingScale) { /// AddressSanitizer: instrument the code in module to find memory bugs. 
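The mapping rule itself sits in context elided by the hunk above; assuming the usual ASan rule Shadow = (Mem >> Scale) + Offset with the default Scale of 3 (an assumption, not stated in this hunk) and the Linux x86-64 offset defined earlier in this file:

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    int main() {
      const uint64_t Offset = 0x7FFF8000;  // kSmallX86_64ShadowOffset
      const unsigned Scale = 3;            // 8 app bytes per shadow byte
      uint64_t Mem = 0x603000000010ULL;    // hypothetical heap address
      uint64_t Shadow = (Mem >> Scale) + Offset;
      std::printf("shadow(0x%llx) = 0x%llx\n", (unsigned long long)Mem,
                  (unsigned long long)Shadow);
      // Addresses in the same 8-byte granule share one shadow byte.
      assert(((Mem + 7) >> Scale) == (Mem >> Scale));
      return 0;
    }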
struct AddressSanitizer : public FunctionPass { - AddressSanitizer() : FunctionPass(ID) {} + AddressSanitizer() : FunctionPass(ID) { + initializeAddressSanitizerPass(*PassRegistry::getPassRegistry()); + } const char *getPassName() const override { return "AddressSanitizerFunctionPass"; } + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<DominatorTreeWrapperPass>(); + } void instrumentMop(Instruction *I, bool UseCalls); void instrumentPointerComparisonOrSubtraction(Instruction *I); void instrumentAddress(Instruction *OrigIns, Instruction *InsertBefore, @@ -380,23 +385,24 @@ struct AddressSanitizer : public FunctionPass { bool doInitialization(Module &M) override; static char ID; // Pass identification, replacement for typeid + DominatorTree &getDominatorTree() const { return *DT; } + private: void initializeCallbacks(Module &M); bool LooksLikeCodeInBug11395(Instruction *I); bool GlobalIsLinkerInitialized(GlobalVariable *G); - bool InjectCoverage(Function &F, const ArrayRef<BasicBlock*> AllBlocks); - void InjectCoverageAtBlock(Function &F, BasicBlock &BB); LLVMContext *C; const DataLayout *DL; + Triple TargetTriple; int LongSize; Type *IntptrTy; ShadowMapping Mapping; + DominatorTree *DT; Function *AsanCtorFunction; Function *AsanInitFunction; Function *AsanHandleNoReturnFunc; - Function *AsanCovFunction; Function *AsanPtrCmpFunction, *AsanPtrSubFunction; // This array is indexed by AccessIsWrite and log2(AccessSize). Function *AsanErrorCallback[2][kNumberOfAccessSizes]; @@ -435,12 +441,12 @@ class AddressSanitizerModule : public ModulePass { Type *IntptrTy; LLVMContext *C; const DataLayout *DL; + Triple TargetTriple; ShadowMapping Mapping; Function *AsanPoisonGlobals; Function *AsanUnpoisonGlobals; Function *AsanRegisterGlobals; Function *AsanUnregisterGlobals; - Function *AsanCovModuleInit; }; // Stack poisoning does not play well with exception handling. @@ -478,15 +484,36 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { }; SmallVector<AllocaPoisonCall, 8> AllocaPoisonCallVec; + // Stores left and right redzone shadow addresses for dynamic alloca + // and pointer to alloca instruction itself. + // LeftRzAddr is a shadow address for alloca left redzone. + // RightRzAddr is a shadow address for alloca right redzone. + struct DynamicAllocaCall { + AllocaInst *AI; + Value *LeftRzAddr; + Value *RightRzAddr; + bool Poison; + explicit DynamicAllocaCall(AllocaInst *AI, + Value *LeftRzAddr = nullptr, + Value *RightRzAddr = nullptr) + : AI(AI), LeftRzAddr(LeftRzAddr), RightRzAddr(RightRzAddr), Poison(true) + {} + }; + SmallVector<DynamicAllocaCall, 1> DynamicAllocaVec; + // Maps Value to an AllocaInst from which the Value is originated. 
  typedef DenseMap<Value*, AllocaInst*> AllocaForValueMapTy;
  AllocaForValueMapTy AllocaForValue;

+  bool HasNonEmptyInlineAsm;
+  std::unique_ptr<CallInst> EmptyInlineAsm;
+
  FunctionStackPoisoner(Function &F, AddressSanitizer &ASan)
-      : F(F), ASan(ASan), DIB(*F.getParent()), C(ASan.C),
-        IntptrTy(ASan.IntptrTy), IntptrPtrTy(PointerType::get(IntptrTy, 0)),
-        Mapping(ASan.Mapping),
-        StackAlignment(1 << Mapping.Scale) {}
+      : F(F), ASan(ASan), DIB(*F.getParent(), /*AllowUnresolved*/ false),
+        C(ASan.C), IntptrTy(ASan.IntptrTy),
+        IntptrPtrTy(PointerType::get(IntptrTy, 0)), Mapping(ASan.Mapping),
+        StackAlignment(1 << Mapping.Scale), HasNonEmptyInlineAsm(false),
+        EmptyInlineAsm(CallInst::Create(ASan.EmptyAsm)) {}

  bool runOnFunction() {
    if (!ClStack) return false;
@@ -494,7 +521,7 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> {
    for (BasicBlock *BB : depth_first(&F.getEntryBlock()))
      visit(*BB);

-    if (AllocaVec.empty()) return false;
+    if (AllocaVec.empty() && DynamicAllocaVec.empty()) return false;

    initializeCallbacks(*F.getParent());

@@ -506,7 +533,7 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> {
    return true;
  }

-  // Finds all static Alloca instructions and puts
+  // Finds all Alloca instructions and puts
  // poisoned red zones around all of them.
  // Then unpoison everything back before the function returns.
  void poisonStack();
@@ -517,12 +544,64 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> {
    RetVec.push_back(&RI);
  }

+  // Unpoison dynamic alloca redzones.
+  void unpoisonDynamicAlloca(DynamicAllocaCall &AllocaCall) {
+    if (!AllocaCall.Poison)
+      return;
+    for (auto Ret : RetVec) {
+      IRBuilder<> IRBRet(Ret);
+      PointerType *Int32PtrTy = PointerType::getUnqual(IRBRet.getInt32Ty());
+      Value *Zero = Constant::getNullValue(IRBRet.getInt32Ty());
+      Value *PartialRzAddr = IRBRet.CreateSub(AllocaCall.RightRzAddr,
+                                              ConstantInt::get(IntptrTy, 4));
+      IRBRet.CreateStore(Zero, IRBRet.CreateIntToPtr(AllocaCall.LeftRzAddr,
+                                                     Int32PtrTy));
+      IRBRet.CreateStore(Zero, IRBRet.CreateIntToPtr(PartialRzAddr,
+                                                     Int32PtrTy));
+      IRBRet.CreateStore(Zero, IRBRet.CreateIntToPtr(AllocaCall.RightRzAddr,
+                                                     Int32PtrTy));
+    }
+  }
+
+  // Right shift for BigEndian and left shift for LittleEndian.
+  Value *shiftAllocaMagic(Value *Val, IRBuilder<> &IRB, Value *Shift) {
+    return ASan.DL->isLittleEndian() ? IRB.CreateShl(Val, Shift)
+                                     : IRB.CreateLShr(Val, Shift);
+  }
+
+  // Compute PartialRzMagic for a dynamic alloca call. Since we don't know the
+  // size of the requested memory until runtime, we have to compute it
+  // dynamically. If PartialSize is 0, PartialRzMagic will contain
+  // kAsanAllocaRightMagic; otherwise it will contain the value used to poison
+  // the partial redzone for the alloca call.
+  Value *computePartialRzMagic(Value *PartialSize, IRBuilder<> &IRB);
+
+  // Deploy and poison redzones around a dynamic alloca call. To do this, we
+  // replace the call with another one with changed parameters and replace all
+  // uses of the old alloca with the new address, so
+  //     addr = alloca type, old_size, align
+  // is replaced by
+  //     new_size = old_size * sizeof(type) + additional_size
+  //     tmp = alloca i8, new_size, max(align, 32)
+  //     addr = tmp + 32 (the first 32 bytes are for the left redzone).
+  // additional_size is added so that the new allocation holds not only the
+  // requested memory, but also the left, partial and right redzones.
+  // After that, we poison the redzones:
+  // (1) Left redzone with kAsanAllocaLeftMagic.
+ // (2) Partial redzone with the value, computed in runtime by + // computePartialRzMagic function. + // (3) Right redzone with kAsanAllocaRightMagic. + void handleDynamicAllocaCall(DynamicAllocaCall &AllocaCall); + /// \brief Collect Alloca instructions we want (and can) handle. void visitAllocaInst(AllocaInst &AI) { if (!isInterestingAlloca(AI)) return; StackAlignment = std::max(StackAlignment, AI.getAlignment()); - AllocaVec.push_back(&AI); + if (isDynamicAlloca(AI)) + DynamicAllocaVec.push_back(DynamicAllocaCall(&AI)); + else + AllocaVec.push_back(&AI); } /// \brief Collect lifetime intrinsic calls to check for use-after-scope @@ -551,13 +630,29 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { AllocaPoisonCallVec.push_back(APC); } + void visitCallInst(CallInst &CI) { + HasNonEmptyInlineAsm |= + CI.isInlineAsm() && !CI.isIdenticalTo(EmptyInlineAsm.get()); + } + // ---------------------- Helpers. void initializeCallbacks(Module &M); + bool doesDominateAllExits(const Instruction *I) const { + for (auto Ret : RetVec) { + if (!ASan.getDominatorTree().dominates(I, Ret)) + return false; + } + return true; + } + + bool isDynamicAlloca(AllocaInst &AI) const { + return AI.isArrayAllocation() || !AI.isStaticAlloca(); + } + // Check if we want (and can) handle this alloca. bool isInterestingAlloca(AllocaInst &AI) const { - return (!AI.isArrayAllocation() && AI.isStaticAlloca() && - AI.getAllocatedType()->isSized() && + return (AI.getAllocatedType()->isSized() && // alloca() may be called with 0 size, ignore it. getAllocaSizeInBytes(&AI) > 0); } @@ -569,18 +664,26 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { } /// Finds alloca where the value comes from. AllocaInst *findAllocaForValue(Value *V); - void poisonRedZones(const ArrayRef<uint8_t> ShadowBytes, IRBuilder<> &IRB, + void poisonRedZones(ArrayRef<uint8_t> ShadowBytes, IRBuilder<> &IRB, Value *ShadowBase, bool DoPoison); void poisonAlloca(Value *V, uint64_t Size, IRBuilder<> &IRB, bool DoPoison); void SetShadowToStackAfterReturnInlined(IRBuilder<> &IRB, Value *ShadowBase, int Size); + Value *createAllocaForLayout(IRBuilder<> &IRB, const ASanStackFrameLayout &L, + bool Dynamic); + PHINode *createPHI(IRBuilder<> &IRB, Value *Cond, Value *ValueIfTrue, + Instruction *ThenTerm, Value *ValueIfFalse); }; } // namespace char AddressSanitizer::ID = 0; -INITIALIZE_PASS(AddressSanitizer, "asan", +INITIALIZE_PASS_BEGIN(AddressSanitizer, "asan", + "AddressSanitizer: detects use-after-free and out-of-bounds bugs.", + false, false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_END(AddressSanitizer, "asan", "AddressSanitizer: detects use-after-free and out-of-bounds bugs.", false, false) FunctionPass *llvm::createAddressSanitizerFunctionPass() { @@ -616,8 +719,25 @@ static GlobalVariable *createPrivateGlobalForString( return GV; } +/// \brief Create a global describing a source location. 
+static GlobalVariable *createPrivateGlobalForSourceLoc(Module &M, + LocationMetadata MD) { + Constant *LocData[] = { + createPrivateGlobalForString(M, MD.Filename, true), + ConstantInt::get(Type::getInt32Ty(M.getContext()), MD.LineNo), + ConstantInt::get(Type::getInt32Ty(M.getContext()), MD.ColumnNo), + }; + auto LocStruct = ConstantStruct::getAnon(LocData); + auto GV = new GlobalVariable(M, LocStruct->getType(), true, + GlobalValue::PrivateLinkage, LocStruct, + kAsanGenPrefix); + GV->setUnnamedAddr(true); + return GV; +} + static bool GlobalWasGeneratedByAsan(GlobalVariable *G) { - return G->getName().find(kAsanGenPrefix) == 0; + return G->getName().find(kAsanGenPrefix) == 0 || + G->getName().find(kSanCovGenPrefix) == 0; } Value *AddressSanitizer::memToShadow(Value *Shadow, IRBuilder<> &IRB) { @@ -652,7 +772,7 @@ void AddressSanitizer::instrumentMemIntrinsic(MemIntrinsic *MI) { } // If I is an interesting memory access, return the PointerOperand -// and set IsWrite/Alignment. Otherwise return NULL. +// and set IsWrite/Alignment. Otherwise return nullptr. static Value *isInterestingMemoryAccess(Instruction *I, bool *IsWrite, unsigned *Alignment) { // Skip memory accesses inserted by another instrumentation. @@ -861,8 +981,11 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns, TerminatorInst *CrashTerm = nullptr; if (ClAlwaysSlowPath || (TypeSize < 8 * Granularity)) { + // We use branch weights for the slow path check, to indicate that the slow + // path is rarely taken. This seems to be the case for SPEC benchmarks. TerminatorInst *CheckTerm = - SplitBlockAndInsertIfThen(Cmp, InsertBefore, false); + SplitBlockAndInsertIfThen(Cmp, InsertBefore, false, + MDBuilder(*C).createBranchWeights(1, 100000)); assert(dyn_cast<BranchInst>(CheckTerm)->isUnconditional()); BasicBlock *NextBB = CheckTerm->getSuccessor(0); IRB.SetInsertPoint(CheckTerm); @@ -907,10 +1030,12 @@ void AddressSanitizerModule::createInitializerPoisonCalls( ConstantStruct *CS = cast<ConstantStruct>(OP); // Must have a function or null ptr. - // (CS->getOperand(0) is the init priority.) if (Function* F = dyn_cast<Function>(CS->getOperand(1))) { - if (F->getName() != kAsanModuleCtorName) - poisonOneInitializer(*F, ModuleName); + if (F->getName() == kAsanModuleCtorName) continue; + ConstantInt *Priority = dyn_cast<ConstantInt>(CS->getOperand(0)); + // Don't instrument CTORs that will run before asan.module_ctor. + if (Priority->getLimitedValue() <= kAsanCtorAndDtorPriority) continue; + poisonOneInitializer(*F, ModuleName); } } } @@ -920,7 +1045,6 @@ bool AddressSanitizerModule::ShouldInstrumentGlobal(GlobalVariable *G) { DEBUG(dbgs() << "GLOBAL: " << *G << "\n"); if (GlobalsMD.get(G).IsBlacklisted) return false; - if (GlobalsMD.isInstrumentationGlobal(G)) return false; if (!Ty->isSized()) return false; if (!G->hasInitializer()) return false; if (GlobalWasGeneratedByAsan(G)) return false; // Our own global. @@ -941,43 +1065,48 @@ bool AddressSanitizerModule::ShouldInstrumentGlobal(GlobalVariable *G) { // For now, just ignore this Global if the alignment is large. if (G->getAlignment() > MinRedzoneSizeForGlobal()) return false; - // Ignore all the globals with the names starting with "\01L_OBJC_". - // Many of those are put into the .cstring section. The linker compresses - // that section by removing the spare \0s after the string terminator, so - // our redzones get broken. 
- if ((G->getName().find("\01L_OBJC_") == 0) || - (G->getName().find("\01l_OBJC_") == 0)) { - DEBUG(dbgs() << "Ignoring \\01L_OBJC_* global: " << *G << "\n"); - return false; - } - if (G->hasSection()) { StringRef Section(G->getSection()); - // Ignore the globals from the __OBJC section. The ObjC runtime assumes - // those conform to /usr/lib/objc/runtime.h, so we can't add redzones to - // them. - if (Section.startswith("__OBJC,") || - Section.startswith("__DATA, __objc_")) { - DEBUG(dbgs() << "Ignoring ObjC runtime global: " << *G << "\n"); - return false; - } - // See http://code.google.com/p/address-sanitizer/issues/detail?id=32 - // Constant CFString instances are compiled in the following way: - // -- the string buffer is emitted into - // __TEXT,__cstring,cstring_literals - // -- the constant NSConstantString structure referencing that buffer - // is placed into __DATA,__cfstring - // Therefore there's no point in placing redzones into __DATA,__cfstring. - // Moreover, it causes the linker to crash on OS X 10.7 - if (Section.startswith("__DATA,__cfstring")) { - DEBUG(dbgs() << "Ignoring CFString: " << *G << "\n"); - return false; - } - // The linker merges the contents of cstring_literals and removes the - // trailing zeroes. - if (Section.startswith("__TEXT,__cstring,cstring_literals")) { - DEBUG(dbgs() << "Ignoring a cstring literal: " << *G << "\n"); - return false; + + if (TargetTriple.isOSBinFormatMachO()) { + StringRef ParsedSegment, ParsedSection; + unsigned TAA = 0, StubSize = 0; + bool TAAParsed; + std::string ErrorCode = + MCSectionMachO::ParseSectionSpecifier(Section, ParsedSegment, + ParsedSection, TAA, TAAParsed, + StubSize); + if (!ErrorCode.empty()) { + report_fatal_error("Invalid section specifier '" + ParsedSection + + "': " + ErrorCode + "."); + } + + // Ignore the globals from the __OBJC section. The ObjC runtime assumes + // those conform to /usr/lib/objc/runtime.h, so we can't add redzones to + // them. + if (ParsedSegment == "__OBJC" || + (ParsedSegment == "__DATA" && ParsedSection.startswith("__objc_"))) { + DEBUG(dbgs() << "Ignoring ObjC runtime global: " << *G << "\n"); + return false; + } + // See http://code.google.com/p/address-sanitizer/issues/detail?id=32 + // Constant CFString instances are compiled in the following way: + // -- the string buffer is emitted into + // __TEXT,__cstring,cstring_literals + // -- the constant NSConstantString structure referencing that buffer + // is placed into __DATA,__cfstring + // Therefore there's no point in placing redzones into __DATA,__cfstring. + // Moreover, it causes the linker to crash on OS X 10.7 + if (ParsedSegment == "__DATA" && ParsedSection == "__cfstring") { + DEBUG(dbgs() << "Ignoring CFString: " << *G << "\n"); + return false; + } + // The linker merges the contents of cstring_literals and removes the + // trailing zeroes. + if (ParsedSegment == "__TEXT" && (TAA & MachO::S_CSTRING_LITERALS)) { + DEBUG(dbgs() << "Ignoring a cstring literal: " << *G << "\n"); + return false; + } } // Callbacks put into the CRT initializer/terminator sections @@ -1000,24 +1129,20 @@ void AddressSanitizerModule::initializeCallbacks(Module &M) { IRBuilder<> IRB(*C); // Declare our poisoning and unpoisoning functions. 
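Aside: the hunk below, like many in this patch, swaps the terminator of the
variadic getOrInsertFunction overload from NULL to nullptr. The terminator is
read back through varargs as a pointer, so it must be pointer-typed; a bare
NULL can expand to an integral 0 and misbehave on targets where int and void*
differ in size. A minimal sketch of the declaration pattern, using only this
file's existing includes (illustrative; checkInterfaceFunction is the file's
local helper that rejects a pre-existing conflicting declaration, approximated
here by a plain cast):

    // Sketch: declaring an ASan runtime hook with a nullptr terminator.
    static Function *declarePoisonHook(Module &M, Type *IntptrTy) {
      LLVMContext &Ctx = M.getContext();
      Constant *C = M.getOrInsertFunction("__asan_poison_globals",
                                          Type::getVoidTy(Ctx), IntptrTy,
                                          nullptr); // pointer-typed sentinel
      Function *F = cast<Function>(C); // stand-in for checkInterfaceFunction
      F->setLinkage(Function::ExternalLinkage);
      return F;
    }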
AsanPoisonGlobals = checkInterfaceFunction(M.getOrInsertFunction( - kAsanPoisonGlobalsName, IRB.getVoidTy(), IntptrTy, NULL)); + kAsanPoisonGlobalsName, IRB.getVoidTy(), IntptrTy, nullptr)); AsanPoisonGlobals->setLinkage(Function::ExternalLinkage); AsanUnpoisonGlobals = checkInterfaceFunction(M.getOrInsertFunction( - kAsanUnpoisonGlobalsName, IRB.getVoidTy(), NULL)); + kAsanUnpoisonGlobalsName, IRB.getVoidTy(), nullptr)); AsanUnpoisonGlobals->setLinkage(Function::ExternalLinkage); // Declare functions that register/unregister globals. AsanRegisterGlobals = checkInterfaceFunction(M.getOrInsertFunction( kAsanRegisterGlobalsName, IRB.getVoidTy(), - IntptrTy, IntptrTy, NULL)); + IntptrTy, IntptrTy, nullptr)); AsanRegisterGlobals->setLinkage(Function::ExternalLinkage); AsanUnregisterGlobals = checkInterfaceFunction(M.getOrInsertFunction( kAsanUnregisterGlobalsName, - IRB.getVoidTy(), IntptrTy, IntptrTy, NULL)); + IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); AsanUnregisterGlobals->setLinkage(Function::ExternalLinkage); - AsanCovModuleInit = checkInterfaceFunction(M.getOrInsertFunction( - kAsanCovModuleInitName, - IRB.getVoidTy(), IntptrTy, NULL)); - AsanCovModuleInit->setLinkage(Function::ExternalLinkage); } // This function replaces all global variables with new variables that have @@ -1047,7 +1172,7 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) { // We initialize an array of such structures and pass it to a run-time call. StructType *GlobalStructTy = StructType::get(IntptrTy, IntptrTy, IntptrTy, IntptrTy, IntptrTy, - IntptrTy, IntptrTy, NULL); + IntptrTy, IntptrTy, nullptr); SmallVector<Constant *, 16> Initializers(n); bool HasDynamicallyInitializedGlobals = false; @@ -1062,11 +1187,11 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) { GlobalVariable *G = GlobalsToChange[i]; auto MD = GlobalsMD.get(G); - // Create string holding the global name unless it was provided by - // the metadata. - GlobalVariable *Name = - MD.Name ? MD.Name : createPrivateGlobalForString(M, G->getName(), - /*AllowMerging*/ true); + // Create string holding the global name (use global name from metadata + // if it's available, otherwise just write the name of global variable). + GlobalVariable *Name = createPrivateGlobalForString( + M, MD.Name.empty() ? G->getName() : MD.Name, + /*AllowMerging*/ true); PointerType *PtrTy = cast<PointerType>(G->getType()); Type *Ty = PtrTy->getElementType(); @@ -1084,10 +1209,10 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) { assert(((RightRedzoneSize + SizeInBytes) % MinRZ) == 0); Type *RightRedZoneTy = ArrayType::get(IRB.getInt8Ty(), RightRedzoneSize); - StructType *NewTy = StructType::get(Ty, RightRedZoneTy, NULL); + StructType *NewTy = StructType::get(Ty, RightRedZoneTy, nullptr); Constant *NewInitializer = ConstantStruct::get( NewTy, G->getInitializer(), - Constant::getNullValue(RightRedZoneTy), NULL); + Constant::getNullValue(RightRedZoneTy), nullptr); // Create a new global variable with enough space for a redzone. 
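Aside: before the enlarged global is created below, it helps to make the size
computation above concrete. A simplified sketch of the right-redzone rule
(illustrative; the in-tree computation additionally scales the redzone up for
large globals before this rounding):

    #include <cstdint>
    // Sketch: choose a right redzone so (size + redzone) % MinRZ == 0.
    static uint64_t rightRedzoneSize(uint64_t SizeInBytes, uint64_t MinRZ) {
      uint64_t RZ = MinRZ;                    // base redzone for small globals
      if (SizeInBytes % MinRZ)
        RZ += MinRZ - (SizeInBytes % MinRZ);  // round the total up to MinRZ
      return RZ;
    }

For example, a 20-byte global with MinRZ == 32 gets a 44-byte redzone and is
rewritten as the two-element struct { original type, [44 x i8] }: 64 bytes in
total, which is exactly what the assert above checks.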
GlobalValue::LinkageTypes Linkage = G->getLinkage(); @@ -1108,16 +1233,21 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) { NewGlobal->takeName(G); G->eraseFromParent(); + Constant *SourceLoc; + if (!MD.SourceLoc.empty()) { + auto SourceLocGlobal = createPrivateGlobalForSourceLoc(M, MD.SourceLoc); + SourceLoc = ConstantExpr::getPointerCast(SourceLocGlobal, IntptrTy); + } else { + SourceLoc = ConstantInt::get(IntptrTy, 0); + } + Initializers[i] = ConstantStruct::get( GlobalStructTy, ConstantExpr::getPointerCast(NewGlobal, IntptrTy), ConstantInt::get(IntptrTy, SizeInBytes), ConstantInt::get(IntptrTy, SizeInBytes + RightRedzoneSize), ConstantExpr::getPointerCast(Name, IntptrTy), ConstantExpr::getPointerCast(ModuleName, IntptrTy), - ConstantInt::get(IntptrTy, MD.IsDynInit), - MD.SourceLoc ? ConstantExpr::getPointerCast(MD.SourceLoc, IntptrTy) - : ConstantInt::get(IntptrTy, 0), - NULL); + ConstantInt::get(IntptrTy, MD.IsDynInit), SourceLoc, nullptr); if (ClInitializers && MD.IsDynInit) HasDynamicallyInitializedGlobals = true; @@ -1161,7 +1291,8 @@ bool AddressSanitizerModule::runOnModule(Module &M) { C = &(M.getContext()); int LongSize = DL->getPointerSizeInBits(); IntptrTy = Type::getIntNTy(*C, LongSize); - Mapping = getShadowMapping(M, LongSize); + TargetTriple = Triple(M.getTargetTriple()); + Mapping = getShadowMapping(TargetTriple, LongSize); initializeCallbacks(M); bool Changed = false; @@ -1170,13 +1301,6 @@ bool AddressSanitizerModule::runOnModule(Module &M) { assert(CtorFunc); IRBuilder<> IRB(CtorFunc->getEntryBlock().getTerminator()); - if (ClCoverage > 0) { - Function *CovFunc = M.getFunction(kAsanCovName); - int nCov = CovFunc ? CovFunc->getNumUses() : 0; - IRB.CreateCall(AsanCovModuleInit, ConstantInt::get(IntptrTy, nCov)); - Changed = true; - } - if (ClGlobals) Changed |= InstrumentGlobals(IRB, M); @@ -1195,43 +1319,42 @@ void AddressSanitizer::initializeCallbacks(Module &M) { AsanErrorCallback[AccessIsWrite][AccessSizeIndex] = checkInterfaceFunction( M.getOrInsertFunction(kAsanReportErrorTemplate + Suffix, - IRB.getVoidTy(), IntptrTy, NULL)); + IRB.getVoidTy(), IntptrTy, nullptr)); AsanMemoryAccessCallback[AccessIsWrite][AccessSizeIndex] = checkInterfaceFunction( M.getOrInsertFunction(ClMemoryAccessCallbackPrefix + Suffix, - IRB.getVoidTy(), IntptrTy, NULL)); + IRB.getVoidTy(), IntptrTy, nullptr)); } } AsanErrorCallbackSized[0] = checkInterfaceFunction(M.getOrInsertFunction( - kAsanReportLoadN, IRB.getVoidTy(), IntptrTy, IntptrTy, NULL)); + kAsanReportLoadN, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); AsanErrorCallbackSized[1] = checkInterfaceFunction(M.getOrInsertFunction( - kAsanReportStoreN, IRB.getVoidTy(), IntptrTy, IntptrTy, NULL)); + kAsanReportStoreN, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); AsanMemoryAccessCallbackSized[0] = checkInterfaceFunction( M.getOrInsertFunction(ClMemoryAccessCallbackPrefix + "loadN", - IRB.getVoidTy(), IntptrTy, IntptrTy, NULL)); + IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); AsanMemoryAccessCallbackSized[1] = checkInterfaceFunction( M.getOrInsertFunction(ClMemoryAccessCallbackPrefix + "storeN", - IRB.getVoidTy(), IntptrTy, IntptrTy, NULL)); + IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); AsanMemmove = checkInterfaceFunction(M.getOrInsertFunction( ClMemoryAccessCallbackPrefix + "memmove", IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy, NULL)); + IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy, nullptr)); AsanMemcpy = checkInterfaceFunction(M.getOrInsertFunction( 
ClMemoryAccessCallbackPrefix + "memcpy", IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy, NULL)); + IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy, nullptr)); AsanMemset = checkInterfaceFunction(M.getOrInsertFunction( ClMemoryAccessCallbackPrefix + "memset", IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy, NULL)); + IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy, nullptr)); AsanHandleNoReturnFunc = checkInterfaceFunction( - M.getOrInsertFunction(kAsanHandleNoReturnName, IRB.getVoidTy(), NULL)); - AsanCovFunction = checkInterfaceFunction(M.getOrInsertFunction( - kAsanCovName, IRB.getVoidTy(), NULL)); + M.getOrInsertFunction(kAsanHandleNoReturnName, IRB.getVoidTy(), nullptr)); + AsanPtrCmpFunction = checkInterfaceFunction(M.getOrInsertFunction( - kAsanPtrCmp, IRB.getVoidTy(), IntptrTy, IntptrTy, NULL)); + kAsanPtrCmp, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); AsanPtrSubFunction = checkInterfaceFunction(M.getOrInsertFunction( - kAsanPtrSub, IRB.getVoidTy(), IntptrTy, IntptrTy, NULL)); + kAsanPtrSub, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); // We insert an empty inline asm after __asan_report* to avoid callback merge. EmptyAsm = InlineAsm::get(FunctionType::get(IRB.getVoidTy(), false), StringRef(""), StringRef(""), @@ -1251,6 +1374,7 @@ bool AddressSanitizer::doInitialization(Module &M) { C = &(M.getContext()); LongSize = DL->getPointerSizeInBits(); IntptrTy = Type::getIntNTy(*C, LongSize); + TargetTriple = Triple(M.getTargetTriple()); AsanCtorFunction = Function::Create( FunctionType::get(Type::getVoidTy(*C), false), @@ -1259,11 +1383,11 @@ bool AddressSanitizer::doInitialization(Module &M) { // call __asan_init in the module ctor. IRBuilder<> IRB(ReturnInst::Create(*C, AsanCtorBB)); AsanInitFunction = checkInterfaceFunction( - M.getOrInsertFunction(kAsanInitName, IRB.getVoidTy(), NULL)); + M.getOrInsertFunction(kAsanInitName, IRB.getVoidTy(), nullptr)); AsanInitFunction->setLinkage(Function::ExternalLinkage); IRB.CreateCall(AsanInitFunction); - Mapping = getShadowMapping(M, LongSize); + Mapping = getShadowMapping(TargetTriple, LongSize); appendToGlobalCtors(M, AsanCtorFunction, kAsanCtorAndDtorPriority); return true; @@ -1285,80 +1409,14 @@ bool AddressSanitizer::maybeInsertAsanInitAtFunctionEntry(Function &F) { return false; } -void AddressSanitizer::InjectCoverageAtBlock(Function &F, BasicBlock &BB) { - BasicBlock::iterator IP = BB.getFirstInsertionPt(), BE = BB.end(); - // Skip static allocas at the top of the entry block so they don't become - // dynamic when we split the block. If we used our optimized stack layout, - // then there will only be one alloca and it will come first. - for (; IP != BE; ++IP) { - AllocaInst *AI = dyn_cast<AllocaInst>(IP); - if (!AI || !AI->isStaticAlloca()) - break; - } - - DebugLoc EntryLoc = IP->getDebugLoc().getFnDebugLoc(*C); - IRBuilder<> IRB(IP); - IRB.SetCurrentDebugLocation(EntryLoc); - Type *Int8Ty = IRB.getInt8Ty(); - GlobalVariable *Guard = new GlobalVariable( - *F.getParent(), Int8Ty, false, GlobalValue::PrivateLinkage, - Constant::getNullValue(Int8Ty), "__asan_gen_cov_" + F.getName()); - LoadInst *Load = IRB.CreateLoad(Guard); - Load->setAtomic(Monotonic); - Load->setAlignment(1); - Value *Cmp = IRB.CreateICmpEQ(Constant::getNullValue(Int8Ty), Load); - Instruction *Ins = SplitBlockAndInsertIfThen( - Cmp, IP, false, MDBuilder(*C).createBranchWeights(1, 100000)); - IRB.SetInsertPoint(Ins); - IRB.SetCurrentDebugLocation(EntryLoc); - // We pass &F to __sanitizer_cov. 
We could avoid this and rely on - // GET_CALLER_PC, but having the PC of the first instruction is just nice. - IRB.CreateCall(AsanCovFunction); - StoreInst *Store = IRB.CreateStore(ConstantInt::get(Int8Ty, 1), Guard); - Store->setAtomic(Monotonic); - Store->setAlignment(1); -} - -// Poor man's coverage that works with ASan. -// We create a Guard boolean variable with the same linkage -// as the function and inject this code into the entry block (-asan-coverage=1) -// or all blocks (-asan-coverage=2): -// if (*Guard) { -// __sanitizer_cov(&F); -// *Guard = 1; -// } -// The accesses to Guard are atomic. The rest of the logic is -// in __sanitizer_cov (it's fine to call it more than once). -// -// This coverage implementation provides very limited data: -// it only tells if a given function (block) was ever executed. -// No counters, no per-edge data. -// But for many use cases this is what we need and the added slowdown -// is negligible. This simple implementation will probably be obsoleted -// by the upcoming Clang-based coverage implementation. -// By having it here and now we hope to -// a) get the functionality to users earlier and -// b) collect usage statistics to help improve Clang coverage design. -bool AddressSanitizer::InjectCoverage(Function &F, - const ArrayRef<BasicBlock *> AllBlocks) { - if (!ClCoverage) return false; - - if (ClCoverage == 1 || - (unsigned)ClCoverageBlockThreshold < AllBlocks.size()) { - InjectCoverageAtBlock(F, F.getEntryBlock()); - } else { - for (auto BB : AllBlocks) - InjectCoverageAtBlock(F, *BB); - } - return true; -} - bool AddressSanitizer::runOnFunction(Function &F) { if (&F == AsanCtorFunction) return false; if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage) return false; DEBUG(dbgs() << "ASAN instrumenting:\n" << F << "\n"); initializeCallbacks(*F.getParent()); + DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + // If needed, insert __asan_init before checking for SanitizeAddress attr. maybeInsertAsanInitAtFunctionEntry(F); @@ -1389,7 +1447,7 @@ bool AddressSanitizer::runOnFunction(Function &F) { if (Value *Addr = isInterestingMemoryAccess(&Inst, &IsWrite, &Alignment)) { if (ClOpt && ClOptSameTemp) { - if (!TempsToInstrument.insert(Addr)) + if (!TempsToInstrument.insert(Addr).second) continue; // We've seen this temp in the current BB. } } else if (ClInvalidPointerPairs && @@ -1417,17 +1475,6 @@ bool AddressSanitizer::runOnFunction(Function &F) { } } - Function *UninstrumentedDuplicate = nullptr; - bool LikelyToInstrument = - !NoReturnCalls.empty() || !ToInstrument.empty() || (NumAllocas > 0); - if (ClKeepUninstrumented && LikelyToInstrument) { - ValueToValueMapTy VMap; - UninstrumentedDuplicate = CloneFunction(&F, VMap, false); - UninstrumentedDuplicate->removeFnAttr(Attribute::SanitizeAddress); - UninstrumentedDuplicate->setName("NOASAN_" + F.getName()); - F.getParent()->getFunctionList().push_back(UninstrumentedDuplicate); - } - bool UseCalls = false; if (ClInstrumentationWithCallsThreshold >= 0 && ToInstrument.size() > (unsigned)ClInstrumentationWithCallsThreshold) @@ -1463,25 +1510,8 @@ bool AddressSanitizer::runOnFunction(Function &F) { bool res = NumInstrumented > 0 || ChangedStack || !NoReturnCalls.empty(); - if (InjectCoverage(F, AllBlocks)) - res = true; - DEBUG(dbgs() << "ASAN done instrumenting: " << res << " " << F << "\n"); - if (ClKeepUninstrumented) { - if (!res) { - // No instrumentation is done, no need for the duplicate. 
- if (UninstrumentedDuplicate) - UninstrumentedDuplicate->eraseFromParent(); - } else { - // The function was instrumented. We must have the duplicate. - assert(UninstrumentedDuplicate); - UninstrumentedDuplicate->setSection("NOASAN"); - assert(!F.hasSection()); - F.setSection("ASAN"); - } - } - return res; } @@ -1501,21 +1531,22 @@ void FunctionStackPoisoner::initializeCallbacks(Module &M) { IRBuilder<> IRB(*C); for (int i = 0; i <= kMaxAsanStackMallocSizeClass; i++) { std::string Suffix = itostr(i); - AsanStackMallocFunc[i] = checkInterfaceFunction( - M.getOrInsertFunction(kAsanStackMallocNameTemplate + Suffix, IntptrTy, - IntptrTy, IntptrTy, NULL)); - AsanStackFreeFunc[i] = checkInterfaceFunction(M.getOrInsertFunction( - kAsanStackFreeNameTemplate + Suffix, IRB.getVoidTy(), IntptrTy, - IntptrTy, IntptrTy, NULL)); + AsanStackMallocFunc[i] = checkInterfaceFunction(M.getOrInsertFunction( + kAsanStackMallocNameTemplate + Suffix, IntptrTy, IntptrTy, nullptr)); + AsanStackFreeFunc[i] = checkInterfaceFunction( + M.getOrInsertFunction(kAsanStackFreeNameTemplate + Suffix, + IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); } - AsanPoisonStackMemoryFunc = checkInterfaceFunction(M.getOrInsertFunction( - kAsanPoisonStackMemoryName, IRB.getVoidTy(), IntptrTy, IntptrTy, NULL)); - AsanUnpoisonStackMemoryFunc = checkInterfaceFunction(M.getOrInsertFunction( - kAsanUnpoisonStackMemoryName, IRB.getVoidTy(), IntptrTy, IntptrTy, NULL)); + AsanPoisonStackMemoryFunc = checkInterfaceFunction( + M.getOrInsertFunction(kAsanPoisonStackMemoryName, IRB.getVoidTy(), + IntptrTy, IntptrTy, nullptr)); + AsanUnpoisonStackMemoryFunc = checkInterfaceFunction( + M.getOrInsertFunction(kAsanUnpoisonStackMemoryName, IRB.getVoidTy(), + IntptrTy, IntptrTy, nullptr)); } void -FunctionStackPoisoner::poisonRedZones(const ArrayRef<uint8_t> ShadowBytes, +FunctionStackPoisoner::poisonRedZones(ArrayRef<uint8_t> ShadowBytes, IRBuilder<> &IRB, Value *ShadowBase, bool DoPoison) { size_t n = ShadowBytes.size(); @@ -1576,11 +1607,49 @@ static DebugLoc getFunctionEntryDebugLocation(Function &F) { return DebugLoc(); } +PHINode *FunctionStackPoisoner::createPHI(IRBuilder<> &IRB, Value *Cond, + Value *ValueIfTrue, + Instruction *ThenTerm, + Value *ValueIfFalse) { + PHINode *PHI = IRB.CreatePHI(IntptrTy, 2); + BasicBlock *CondBlock = cast<Instruction>(Cond)->getParent(); + PHI->addIncoming(ValueIfFalse, CondBlock); + BasicBlock *ThenBlock = ThenTerm->getParent(); + PHI->addIncoming(ValueIfTrue, ThenBlock); + return PHI; +} + +Value *FunctionStackPoisoner::createAllocaForLayout( + IRBuilder<> &IRB, const ASanStackFrameLayout &L, bool Dynamic) { + AllocaInst *Alloca; + if (Dynamic) { + Alloca = IRB.CreateAlloca(IRB.getInt8Ty(), + ConstantInt::get(IRB.getInt64Ty(), L.FrameSize), + "MyAlloca"); + } else { + Alloca = IRB.CreateAlloca(ArrayType::get(IRB.getInt8Ty(), L.FrameSize), + nullptr, "MyAlloca"); + assert(Alloca->isStaticAlloca()); + } + assert((ClRealignStack & (ClRealignStack - 1)) == 0); + size_t FrameAlignment = std::max(L.FrameAlignment, (size_t)ClRealignStack); + Alloca->setAlignment(FrameAlignment); + return IRB.CreatePointerCast(Alloca, IntptrTy); +} + void FunctionStackPoisoner::poisonStack() { + assert(AllocaVec.size() > 0 || DynamicAllocaVec.size() > 0); + + if (ClInstrumentAllocas) + // Handle dynamic allocas. 
+ for (auto &AllocaCall : DynamicAllocaVec) + handleDynamicAllocaCall(AllocaCall); + + if (AllocaVec.size() == 0) return; + int StackMallocIdx = -1; DebugLoc EntryDebugLocation = getFunctionEntryDebugLocation(F); - assert(AllocaVec.size() > 0); Instruction *InsBefore = AllocaVec[0]; IRBuilder<> IRB(InsBefore); IRB.SetCurrentDebugLocation(EntryDebugLocation); @@ -1602,42 +1671,56 @@ void FunctionStackPoisoner::poisonStack() { uint64_t LocalStackSize = L.FrameSize; bool DoStackMalloc = ClUseAfterReturn && LocalStackSize <= kMaxStackMallocSize; + // Don't do dynamic alloca in presence of inline asm: too often it + // makes assumptions on which registers are available. + bool DoDynamicAlloca = ClDynamicAllocaStack && !HasNonEmptyInlineAsm; - Type *ByteArrayTy = ArrayType::get(IRB.getInt8Ty(), LocalStackSize); - AllocaInst *MyAlloca = - new AllocaInst(ByteArrayTy, "MyAlloca", InsBefore); - MyAlloca->setDebugLoc(EntryDebugLocation); - assert((ClRealignStack & (ClRealignStack - 1)) == 0); - size_t FrameAlignment = std::max(L.FrameAlignment, (size_t)ClRealignStack); - MyAlloca->setAlignment(FrameAlignment); - assert(MyAlloca->isStaticAlloca()); - Value *OrigStackBase = IRB.CreatePointerCast(MyAlloca, IntptrTy); - Value *LocalStackBase = OrigStackBase; + Value *StaticAlloca = + DoDynamicAlloca ? nullptr : createAllocaForLayout(IRB, L, false); + + Value *FakeStack; + Value *LocalStackBase; if (DoStackMalloc) { - // LocalStackBase = OrigStackBase - // if (__asan_option_detect_stack_use_after_return) - // LocalStackBase = __asan_stack_malloc_N(LocalStackBase, OrigStackBase); - StackMallocIdx = StackMallocSizeClass(LocalStackSize); - assert(StackMallocIdx <= kMaxAsanStackMallocSizeClass); + // void *FakeStack = __asan_option_detect_stack_use_after_return + // ? __asan_stack_malloc_N(LocalStackSize) + // : nullptr; + // void *LocalStackBase = (FakeStack) ? FakeStack : alloca(LocalStackSize); Constant *OptionDetectUAR = F.getParent()->getOrInsertGlobal( kAsanOptionDetectUAR, IRB.getInt32Ty()); - Value *Cmp = IRB.CreateICmpNE(IRB.CreateLoad(OptionDetectUAR), - Constant::getNullValue(IRB.getInt32Ty())); - Instruction *Term = SplitBlockAndInsertIfThen(Cmp, InsBefore, false); - BasicBlock *CmpBlock = cast<Instruction>(Cmp)->getParent(); + Value *UARIsEnabled = + IRB.CreateICmpNE(IRB.CreateLoad(OptionDetectUAR), + Constant::getNullValue(IRB.getInt32Ty())); + Instruction *Term = + SplitBlockAndInsertIfThen(UARIsEnabled, InsBefore, false); IRBuilder<> IRBIf(Term); IRBIf.SetCurrentDebugLocation(EntryDebugLocation); - LocalStackBase = IRBIf.CreateCall2( - AsanStackMallocFunc[StackMallocIdx], - ConstantInt::get(IntptrTy, LocalStackSize), OrigStackBase); - BasicBlock *SetBlock = cast<Instruction>(LocalStackBase)->getParent(); + StackMallocIdx = StackMallocSizeClass(LocalStackSize); + assert(StackMallocIdx <= kMaxAsanStackMallocSizeClass); + Value *FakeStackValue = + IRBIf.CreateCall(AsanStackMallocFunc[StackMallocIdx], + ConstantInt::get(IntptrTy, LocalStackSize)); + IRB.SetInsertPoint(InsBefore); + IRB.SetCurrentDebugLocation(EntryDebugLocation); + FakeStack = createPHI(IRB, UARIsEnabled, FakeStackValue, Term, + ConstantInt::get(IntptrTy, 0)); + + Value *NoFakeStack = + IRB.CreateICmpEQ(FakeStack, Constant::getNullValue(IntptrTy)); + Term = SplitBlockAndInsertIfThen(NoFakeStack, InsBefore, false); + IRBIf.SetInsertPoint(Term); + IRBIf.SetCurrentDebugLocation(EntryDebugLocation); + Value *AllocaValue = + DoDynamicAlloca ? 
createAllocaForLayout(IRBIf, L, true) : StaticAlloca; IRB.SetInsertPoint(InsBefore); IRB.SetCurrentDebugLocation(EntryDebugLocation); - PHINode *Phi = IRB.CreatePHI(IntptrTy, 2); - Phi->addIncoming(OrigStackBase, CmpBlock); - Phi->addIncoming(LocalStackBase, SetBlock); - LocalStackBase = Phi; + LocalStackBase = createPHI(IRB, NoFakeStack, AllocaValue, Term, FakeStack); + } else { + // void *FakeStack = nullptr; + // void *LocalStackBase = alloca(LocalStackSize); + FakeStack = ConstantInt::get(IntptrTy, 0); + LocalStackBase = + DoDynamicAlloca ? createAllocaForLayout(IRB, L, true) : StaticAlloca; } // Insert poison calls for lifetime intrinsics for alloca. @@ -1694,17 +1777,18 @@ void FunctionStackPoisoner::poisonStack() { BasePlus0); if (DoStackMalloc) { assert(StackMallocIdx >= 0); - // if LocalStackBase != OrigStackBase: + // if FakeStack != 0 // LocalStackBase == FakeStack // // In use-after-return mode, poison the whole stack frame. // if StackMallocIdx <= 4 // // For small sizes inline the whole thing: // memset(ShadowBase, kAsanStackAfterReturnMagic, ShadowSize); - // **SavedFlagPtr(LocalStackBase) = 0 + // **SavedFlagPtr(FakeStack) = 0 // else - // __asan_stack_free_N(LocalStackBase, OrigStackBase) + // __asan_stack_free_N(FakeStack, LocalStackSize) // else // <This is not a fake stack; unpoison the redzones> - Value *Cmp = IRBRet.CreateICmpNE(LocalStackBase, OrigStackBase); + Value *Cmp = + IRBRet.CreateICmpNE(FakeStack, Constant::getNullValue(IntptrTy)); TerminatorInst *ThenTerm, *ElseTerm; SplitBlockAndInsertIfThenElse(Cmp, Ret, &ThenTerm, &ElseTerm); @@ -1714,7 +1798,7 @@ void FunctionStackPoisoner::poisonStack() { SetShadowToStackAfterReturnInlined(IRBPoison, ShadowBase, ClassSize >> Mapping.Scale); Value *SavedFlagPtrPtr = IRBPoison.CreateAdd( - LocalStackBase, + FakeStack, ConstantInt::get(IntptrTy, ClassSize - ASan.LongSize / 8)); Value *SavedFlagPtr = IRBPoison.CreateLoad( IRBPoison.CreateIntToPtr(SavedFlagPtrPtr, IntptrPtrTy)); @@ -1723,9 +1807,8 @@ void FunctionStackPoisoner::poisonStack() { IRBPoison.CreateIntToPtr(SavedFlagPtr, IRBPoison.getInt8PtrTy())); } else { // For larger frames call __asan_stack_free_*. - IRBPoison.CreateCall3(AsanStackFreeFunc[StackMallocIdx], LocalStackBase, - ConstantInt::get(IntptrTy, LocalStackSize), - OrigStackBase); + IRBPoison.CreateCall2(AsanStackFreeFunc[StackMallocIdx], FakeStack, + ConstantInt::get(IntptrTy, LocalStackSize)); } IRBuilder<> IRBElse(ElseTerm); @@ -1733,13 +1816,17 @@ void FunctionStackPoisoner::poisonStack() { } else if (HavePoisonedAllocas) { // If we poisoned some allocas in llvm.lifetime analysis, // unpoison whole stack frame now. - assert(LocalStackBase == OrigStackBase); poisonAlloca(LocalStackBase, LocalStackSize, IRBRet, false); } else { poisonRedZones(L.ShadowBytes, IRBRet, ShadowBase, false); } } + if (ClInstrumentAllocas) + // Unpoison dynamic allocas. + for (auto &AllocaCall : DynamicAllocaVec) + unpoisonDynamicAlloca(AllocaCall); + // We are done. Remove the old unused alloca instructions. for (auto AI : AllocaVec) AI->eraseFromParent(); @@ -1795,3 +1882,140 @@ AllocaInst *FunctionStackPoisoner::findAllocaForValue(Value *V) { AllocaForValue[V] = Res; return Res; } + +// Compute PartialRzMagic for dynamic alloca call. PartialRzMagic is +// constructed from two separate 32-bit numbers: PartialRzMagic = Val1 | Val2. 
+// (1) Val1 is responsible for forming the base value for PartialRzMagic,
+//     containing only 00 for fully addressable and 0xcb for fully poisoned
+//     bytes of each 8-byte chunk of user memory, respectively.
+// (2) Val2 forms the value for marking the first poisoned byte in shadow
+//     memory with the appropriate value (0x01 - 0x07, or 0xcb if Padding % 8 == 0).
+
+// Shift = Padding & ~7; // the number of bits we need to shift to access the
+//                       // first chunk in shadow memory with nonzero bytes.
+// Example:
+// Padding = 21                          Padding = 16
+// Shadow:  |00|00|05|cb|                Shadow:  |00|00|cb|cb|
+//                 ^                                     ^
+//                 |                                     |
+// Shift = 21 & ~7 = 16                  Shift = 16 & ~7 = 16
+//
+// Val1 = 0xcbcbcbcb << Shift;
+// PartialBits = Padding ? Padding & 7 : 0xcb;
+// Val2 = PartialBits << Shift;
+// Result = Val1 | Val2;
+Value *FunctionStackPoisoner::computePartialRzMagic(Value *PartialSize,
+                                                    IRBuilder<> &IRB) {
+  PartialSize = IRB.CreateIntCast(PartialSize, IRB.getInt32Ty(), false);
+  Value *Shift = IRB.CreateAnd(PartialSize, IRB.getInt32(~7));
+  unsigned Val1Int = kAsanAllocaPartialVal1;
+  unsigned Val2Int = kAsanAllocaPartialVal2;
+  if (!ASan.DL->isLittleEndian()) {
+    Val1Int = sys::getSwappedBytes(Val1Int);
+    Val2Int = sys::getSwappedBytes(Val2Int);
+  }
+  Value *Val1 = shiftAllocaMagic(IRB.getInt32(Val1Int), IRB, Shift);
+  Value *PartialBits = IRB.CreateAnd(PartialSize, IRB.getInt32(7));
+  // For BigEndian get 0x000000YZ -> 0xYZ000000.
+  if (ASan.DL->isBigEndian())
+    PartialBits = IRB.CreateShl(PartialBits, IRB.getInt32(24));
+  Value *Val2 = IRB.getInt32(Val2Int);
+  Value *Cond =
+      IRB.CreateICmpNE(PartialBits, Constant::getNullValue(IRB.getInt32Ty()));
+  Val2 = IRB.CreateSelect(Cond, shiftAllocaMagic(PartialBits, IRB, Shift),
+                          shiftAllocaMagic(Val2, IRB, Shift));
+  return IRB.CreateOr(Val1, Val2);
+}
+
+void FunctionStackPoisoner::handleDynamicAllocaCall(
+    DynamicAllocaCall &AllocaCall) {
+  AllocaInst *AI = AllocaCall.AI;
+  if (!doesDominateAllExits(AI)) {
+    // We do not yet handle complex allocas.
+    AllocaCall.Poison = false;
+    return;
+  }
+
+  IRBuilder<> IRB(AI);
+
+  PointerType *Int32PtrTy = PointerType::getUnqual(IRB.getInt32Ty());
+  const unsigned Align = std::max(kAllocaRzSize, AI->getAlignment());
+  const uint64_t AllocaRedzoneMask = kAllocaRzSize - 1;
+
+  Value *Zero = Constant::getNullValue(IntptrTy);
+  Value *AllocaRzSize = ConstantInt::get(IntptrTy, kAllocaRzSize);
+  Value *AllocaRzMask = ConstantInt::get(IntptrTy, AllocaRedzoneMask);
+  Value *NotAllocaRzMask = ConstantInt::get(IntptrTy, ~AllocaRedzoneMask);
+
+  // We need to extend the alloca with additional memory to place the
+  // redzones. The alloca allocates ArraySize elements of ElementSize bytes
+  // each, so the allocated size in bytes (OldSize) is computed as
+  // ArraySize * ElementSize.
+  unsigned ElementSize = ASan.DL->getTypeAllocSize(AI->getAllocatedType());
+  Value *OldSize = IRB.CreateMul(AI->getArraySize(),
+                                 ConstantInt::get(IntptrTy, ElementSize));
+
+  // PartialSize = OldSize % 32
+  Value *PartialSize = IRB.CreateAnd(OldSize, AllocaRzMask);
+
+  // Misalign = kAllocaRzSize - PartialSize;
+  Value *Misalign = IRB.CreateSub(AllocaRzSize, PartialSize);
+
+  // PartialPadding = Misalign != kAllocaRzSize ? Misalign : 0;
+  Value *Cond = IRB.CreateICmpNE(Misalign, AllocaRzSize);
+  Value *PartialPadding = IRB.CreateSelect(Cond, Misalign, Zero);
+
+  // AdditionalChunkSize = Align + PartialPadding + kAllocaRzSize
+  // Align is added to locate the left redzone, PartialPadding for the possible
+  // partial redzone, and kAllocaRzSize for the right redzone, respectively.
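Aside: a worked instance of this arithmetic, assuming kAllocaRzSize == 32 (as
the comments above imply, hence Align >= 32) and an alloca of 21 i8 elements
(ElementSize == 1):

    OldSize             = 21 * 1 (ElementSize)   = 21
    PartialSize         = 21 & 31                = 21
    Misalign            = 32 - 21                = 11
    PartialPadding      = (11 != 32) ? 11 : 0    = 11
    AdditionalChunkSize = 32 (Align) + 11 + 32   = 75
    NewSize             = 21 + 75                = 96  (= 3 * 32)

The frame becomes [left RZ 32][user 21][partial pad 11][right RZ 32], and the
pointer handed back to the user is NewAlloca + Align, just past the left
redzone. PartialSize == 21 is exactly the Padding = 21 case pictured above
computePartialRzMagic: the partial shadow chunk is poisoned with 0x05
(21 & 7) followed by 0xcb.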
+  Value *AdditionalChunkSize = IRB.CreateAdd(
+      ConstantInt::get(IntptrTy, Align + kAllocaRzSize), PartialPadding);
+
+  Value *NewSize = IRB.CreateAdd(OldSize, AdditionalChunkSize);
+
+  // Insert a new alloca with the computed NewSize and Align.
+  AllocaInst *NewAlloca = IRB.CreateAlloca(IRB.getInt8Ty(), NewSize);
+  NewAlloca->setAlignment(Align);
+
+  // NewAddress = Address + Align
+  Value *NewAddress = IRB.CreateAdd(IRB.CreatePtrToInt(NewAlloca, IntptrTy),
+                                    ConstantInt::get(IntptrTy, Align));
+
+  Value *NewAddressPtr = IRB.CreateIntToPtr(NewAddress, AI->getType());
+
+  // LeftRzAddress = NewAddress - kAllocaRzSize
+  Value *LeftRzAddress = IRB.CreateSub(NewAddress, AllocaRzSize);
+
+  // Poisoning left redzone.
+  AllocaCall.LeftRzAddr = ASan.memToShadow(LeftRzAddress, IRB);
+  IRB.CreateStore(ConstantInt::get(IRB.getInt32Ty(), kAsanAllocaLeftMagic),
+                  IRB.CreateIntToPtr(AllocaCall.LeftRzAddr, Int32PtrTy));
+
+  // PartialRzAligned = PartialRzAddr & ~AllocaRzMask
+  Value *PartialRzAddr = IRB.CreateAdd(NewAddress, OldSize);
+  Value *PartialRzAligned = IRB.CreateAnd(PartialRzAddr, NotAllocaRzMask);
+
+  // Poisoning partial redzone.
+  Value *PartialRzMagic = computePartialRzMagic(PartialSize, IRB);
+  Value *PartialRzShadowAddr = ASan.memToShadow(PartialRzAligned, IRB);
+  IRB.CreateStore(PartialRzMagic,
+                  IRB.CreateIntToPtr(PartialRzShadowAddr, Int32PtrTy));
+
+  // RightRzAddress =
+  //     (PartialRzAddr + AllocaRzMask) & ~AllocaRzMask
+  Value *RightRzAddress = IRB.CreateAnd(
+      IRB.CreateAdd(PartialRzAddr, AllocaRzMask), NotAllocaRzMask);
+
+  // Poisoning right redzone.
+  AllocaCall.RightRzAddr = ASan.memToShadow(RightRzAddress, IRB);
+  IRB.CreateStore(ConstantInt::get(IRB.getInt32Ty(), kAsanAllocaRightMagic),
+                  IRB.CreateIntToPtr(AllocaCall.RightRzAddr, Int32PtrTy));
+
+  // Replace all uses of the old alloca's address with NewAddressPtr.
+  AI->replaceAllUsesWith(NewAddressPtr);
+
+  // We are done. Erase the old alloca; the left, partial and right redzone
+  // shadow addresses are kept for unpoisoning at the function exits.
+  AI->eraseFromParent();
+  NumInstrumentedDynamicAllocas++;
+}
diff --git a/lib/Transforms/Instrumentation/CMakeLists.txt b/lib/Transforms/Instrumentation/CMakeLists.txt
index 35635934b81c..92e1091aa3b1 100644
--- a/lib/Transforms/Instrumentation/CMakeLists.txt
+++ b/lib/Transforms/Instrumentation/CMakeLists.txt
@@ -2,10 +2,11 @@ add_llvm_library(LLVMInstrumentation
  AddressSanitizer.cpp
  BoundsChecking.cpp
  DataFlowSanitizer.cpp
-  DebugIR.cpp
  GCOVProfiling.cpp
  MemorySanitizer.cpp
  Instrumentation.cpp
+  InstrProfiling.cpp
+  SanitizerCoverage.cpp
  ThreadSanitizer.cpp
  )

diff --git a/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
index 35057cdd47e9..8f24476f03c1 100644
--- a/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
+++ b/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
@@ -49,8 +49,10 @@
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/Triple.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Dominators.h"
+#include "llvm/IR/DebugInfo.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InlineAsm.h"
 #include "llvm/IR/InstVisitor.h"
@@ -139,11 +141,11 @@ class DFSanABIList {
  std::unique_ptr<SpecialCaseList> SCL;

 public:
-  DFSanABIList(SpecialCaseList *SCL) : SCL(SCL) {}
+  DFSanABIList(std::unique_ptr<SpecialCaseList> SCL) : SCL(std::move(SCL)) {}

  /// Returns whether either this function or its source file is listed in
  /// the given category.
- bool isIn(const Function &F, const StringRef Category) const { + bool isIn(const Function &F, StringRef Category) const { return isIn(*F.getParent(), Category) || SCL->inSection("fun", F.getName(), Category); } @@ -152,7 +154,7 @@ class DFSanABIList { /// /// If GA aliases a function, the alias's name is matched as a function name /// would be. Similarly, aliases of globals are matched like globals. - bool isIn(const GlobalAlias &GA, const StringRef Category) const { + bool isIn(const GlobalAlias &GA, StringRef Category) const { if (isIn(*GA.getParent(), Category)) return true; @@ -164,7 +166,7 @@ class DFSanABIList { } /// Returns whether this module is listed in the given category. - bool isIn(const Module &M, const StringRef Category) const { + bool isIn(const Module &M, StringRef Category) const { return SCL->inSection("src", M.getModuleIdentifier(), Category); } }; @@ -233,15 +235,19 @@ class DataFlowSanitizer : public ModulePass { FunctionType *DFSanUnimplementedFnTy; FunctionType *DFSanSetLabelFnTy; FunctionType *DFSanNonzeroLabelFnTy; + FunctionType *DFSanVarargWrapperFnTy; Constant *DFSanUnionFn; + Constant *DFSanCheckedUnionFn; Constant *DFSanUnionLoadFn; Constant *DFSanUnimplementedFn; Constant *DFSanSetLabelFn; Constant *DFSanNonzeroLabelFn; + Constant *DFSanVarargWrapperFn; MDNode *ColdCallWeights; DFSanABIList ABIList; DenseMap<Value *, Function *> UnwrappedFnMap; AttributeSet ReadOnlyNoneAttrs; + DenseMap<const Function *, DISubprogram> FunctionDIs; Value *getShadowAddress(Value *Addr, Instruction *Pos); bool isInstrumented(const Function *F); @@ -279,7 +285,8 @@ struct DFSanFunction { DenseMap<AllocaInst *, AllocaInst *> AllocaShadowMap; std::vector<std::pair<PHINode *, PHINode *> > PHIFixups; DenseSet<Instruction *> SkipInsts; - DenseSet<Value *> NonZeroChecks; + std::vector<Value *> NonZeroChecks; + bool AvoidNewBlocks; struct CachedCombinedShadow { BasicBlock *Block; @@ -294,6 +301,9 @@ struct DFSanFunction { IsNativeABI(IsNativeABI), ArgTLSPtr(nullptr), RetvalTLSPtr(nullptr), LabelReturnAlloca(nullptr) { DT.recalculate(*F); + // FIXME: Need to track down the register allocator issue which causes poor + // performance in pathological cases with large numbers of basic blocks. 
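Aside: the threshold set just below avoids creating a new basic block per
shadow union in very large functions. With it set, combineShadows (reworked
further down in this patch) emits one out-of-line call to the new dfsan_union
entry point instead of the usual inline fast path. With the default 16-bit
DFSan labels the two shapes are roughly (IR shown is illustrative; the
dfsan_union entry point is presumably expected to handle the equal-labels
case itself, which is why no branch is emitted):

    ; AvoidNewBlocks: a single call, no CFG growth.
    %u = call zeroext i16 @dfsan_union(i16 zeroext %a, i16 zeroext %b)

    ; Default: inline equality fast path that splits the block each time.
    %ne = icmp ne i16 %a, %b
    br i1 %ne, label %slow, label %join    ; %slow calls @__dfsan_union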
+ AvoidNewBlocks = F->size() > 1000; } Value *getArgTLSPtr(); Value *getArgTLS(unsigned Index, Instruction *Pos); @@ -382,7 +392,6 @@ FunctionType *DataFlowSanitizer::getTrampolineFunctionType(FunctionType *T) { } FunctionType *DataFlowSanitizer::getCustomFunctionType(FunctionType *T) { - assert(!T->isVarArg()); llvm::SmallVector<Type *, 4> ArgTypes; for (FunctionType::param_iterator i = T->param_begin(), e = T->param_end(); i != e; ++i) { @@ -397,13 +406,20 @@ FunctionType *DataFlowSanitizer::getCustomFunctionType(FunctionType *T) { } for (unsigned i = 0, e = T->getNumParams(); i != e; ++i) ArgTypes.push_back(ShadowTy); + if (T->isVarArg()) + ArgTypes.push_back(ShadowPtrTy); Type *RetType = T->getReturnType(); if (!RetType->isVoidTy()) ArgTypes.push_back(ShadowPtrTy); - return FunctionType::get(T->getReturnType(), ArgTypes, false); + return FunctionType::get(T->getReturnType(), ArgTypes, T->isVarArg()); } bool DataFlowSanitizer::doInitialization(Module &M) { + llvm::Triple TargetTriple(M.getTargetTriple()); + bool IsX86_64 = TargetTriple.getArch() == llvm::Triple::x86_64; + bool IsMIPS64 = TargetTriple.getArch() == llvm::Triple::mips64 || + TargetTriple.getArch() == llvm::Triple::mips64el; + DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); if (!DLP) report_fatal_error("data layout missing"); @@ -415,8 +431,13 @@ bool DataFlowSanitizer::doInitialization(Module &M) { ShadowPtrTy = PointerType::getUnqual(ShadowTy); IntptrTy = DL->getIntPtrType(*Ctx); ZeroShadow = ConstantInt::getSigned(ShadowTy, 0); - ShadowPtrMask = ConstantInt::getSigned(IntptrTy, ~0x700000000000LL); ShadowPtrMul = ConstantInt::getSigned(IntptrTy, ShadowWidth / 8); + if (IsX86_64) + ShadowPtrMask = ConstantInt::getSigned(IntptrTy, ~0x700000000000LL); + else if (IsMIPS64) + ShadowPtrMask = ConstantInt::getSigned(IntptrTy, ~0xF000000000LL); + else + report_fatal_error("unsupported triple"); Type *DFSanUnionArgs[2] = { ShadowTy, ShadowTy }; DFSanUnionFnTy = @@ -430,7 +451,9 @@ bool DataFlowSanitizer::doInitialization(Module &M) { DFSanSetLabelFnTy = FunctionType::get(Type::getVoidTy(*Ctx), DFSanSetLabelArgs, /*isVarArg=*/false); DFSanNonzeroLabelFnTy = FunctionType::get( - Type::getVoidTy(*Ctx), ArrayRef<Type *>(), /*isVarArg=*/false); + Type::getVoidTy(*Ctx), None, /*isVarArg=*/false); + DFSanVarargWrapperFnTy = FunctionType::get( + Type::getVoidTy(*Ctx), Type::getInt8PtrTy(*Ctx), /*isVarArg=*/false); if (GetArgTLSPtr) { Type *ArgTLSTy = ArrayType::get(ShadowTy, 64); @@ -510,15 +533,26 @@ DataFlowSanitizer::buildWrapperFunction(Function *F, StringRef NewFName, AttributeSet::ReturnIndex)); BasicBlock *BB = BasicBlock::Create(*Ctx, "entry", NewF); - std::vector<Value *> Args; - unsigned n = FT->getNumParams(); - for (Function::arg_iterator ai = NewF->arg_begin(); n != 0; ++ai, --n) - Args.push_back(&*ai); - CallInst *CI = CallInst::Create(F, Args, "", BB); - if (FT->getReturnType()->isVoidTy()) - ReturnInst::Create(*Ctx, BB); - else - ReturnInst::Create(*Ctx, CI, BB); + if (F->isVarArg()) { + NewF->removeAttributes( + AttributeSet::FunctionIndex, + AttributeSet().addAttribute(*Ctx, AttributeSet::FunctionIndex, + "split-stack")); + CallInst::Create(DFSanVarargWrapperFn, + IRBuilder<>(BB).CreateGlobalStringPtr(F->getName()), "", + BB); + new UnreachableInst(*Ctx, BB); + } else { + std::vector<Value *> Args; + unsigned n = FT->getNumParams(); + for (Function::arg_iterator ai = NewF->arg_begin(); n != 0; ++ai, --n) + Args.push_back(&*ai); + CallInst *CI = CallInst::Create(F, Args, "", BB); + if 
(FT->getReturnType()->isVoidTy()) + ReturnInst::Create(*Ctx, BB); + else + ReturnInst::Create(*Ctx, CI, BB); + } return NewF; } @@ -563,6 +597,8 @@ bool DataFlowSanitizer::runOnModule(Module &M) { if (ABIList.isIn(M, "skip")) return false; + FunctionDIs = makeSubprogramMap(M); + if (!GetArgTLSPtr) { Type *ArgTLSTy = ArrayType::get(ShadowTy, 64); ArgTLS = Mod->getOrInsertGlobal("__dfsan_arg_tls", ArgTLSTy); @@ -577,6 +613,15 @@ bool DataFlowSanitizer::runOnModule(Module &M) { DFSanUnionFn = Mod->getOrInsertFunction("__dfsan_union", DFSanUnionFnTy); if (Function *F = dyn_cast<Function>(DFSanUnionFn)) { + F->addAttribute(AttributeSet::FunctionIndex, Attribute::NoUnwind); + F->addAttribute(AttributeSet::FunctionIndex, Attribute::ReadNone); + F->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt); + F->addAttribute(1, Attribute::ZExt); + F->addAttribute(2, Attribute::ZExt); + } + DFSanCheckedUnionFn = Mod->getOrInsertFunction("dfsan_union", DFSanUnionFnTy); + if (Function *F = dyn_cast<Function>(DFSanCheckedUnionFn)) { + F->addAttribute(AttributeSet::FunctionIndex, Attribute::NoUnwind); F->addAttribute(AttributeSet::FunctionIndex, Attribute::ReadNone); F->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt); F->addAttribute(1, Attribute::ZExt); @@ -585,6 +630,7 @@ bool DataFlowSanitizer::runOnModule(Module &M) { DFSanUnionLoadFn = Mod->getOrInsertFunction("__dfsan_union_load", DFSanUnionLoadFnTy); if (Function *F = dyn_cast<Function>(DFSanUnionLoadFn)) { + F->addAttribute(AttributeSet::FunctionIndex, Attribute::NoUnwind); F->addAttribute(AttributeSet::FunctionIndex, Attribute::ReadOnly); F->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt); } @@ -597,16 +643,20 @@ bool DataFlowSanitizer::runOnModule(Module &M) { } DFSanNonzeroLabelFn = Mod->getOrInsertFunction("__dfsan_nonzero_label", DFSanNonzeroLabelFnTy); + DFSanVarargWrapperFn = Mod->getOrInsertFunction("__dfsan_vararg_wrapper", + DFSanVarargWrapperFnTy); std::vector<Function *> FnsToInstrument; llvm::SmallPtrSet<Function *, 2> FnsWithNativeABI; for (Module::iterator i = M.begin(), e = M.end(); i != e; ++i) { if (!i->isIntrinsic() && i != DFSanUnionFn && + i != DFSanCheckedUnionFn && i != DFSanUnionLoadFn && i != DFSanUnimplementedFn && i != DFSanSetLabelFn && - i != DFSanNonzeroLabelFn) + i != DFSanNonzeroLabelFn && + i != DFSanVarargWrapperFn) FnsToInstrument.push_back(&*i); } @@ -688,11 +738,6 @@ bool DataFlowSanitizer::runOnModule(Module &M) { } else { addGlobalNamePrefix(&F); } - // Hopefully, nobody will try to indirectly call a vararg - // function... yet. - } else if (FT->isVarArg()) { - UnwrappedFnMap[&F] = &F; - *i = nullptr; } else if (!IsZeroArgsVoidRet || getWrapperKind(&F) == WK_Custom) { // Build a wrapper function for F. The wrapper simply calls F, and is // added to FnsToInstrument so that any instrumentation according to its @@ -709,6 +754,12 @@ bool DataFlowSanitizer::runOnModule(Module &M) { Value *WrappedFnCst = ConstantExpr::getBitCast(NewF, PointerType::getUnqual(FT)); F.replaceAllUsesWith(WrappedFnCst); + + // Patch the pointer to LLVM function in debug info descriptor. + auto DI = FunctionDIs.find(&F); + if (DI != FunctionDIs.end()) + DI->second.replaceFunction(&F); + UnwrappedFnMap[WrappedFnCst] = &F; *i = NewF; @@ -728,6 +779,11 @@ bool DataFlowSanitizer::runOnModule(Module &M) { i = FnsToInstrument.begin() + N; e = FnsToInstrument.begin() + Count; } + // Hopefully, nobody will try to indirectly call a vararg + // function... yet. 
+ } else if (FT->isVarArg()) { + UnwrappedFnMap[&F] = &F; + *i = nullptr; } } @@ -786,18 +842,16 @@ bool DataFlowSanitizer::runOnModule(Module &M) { // yet). To make our life easier, do this work in a pass after the main // instrumentation. if (ClDebugNonzeroLabels) { - for (DenseSet<Value *>::iterator i = DFSF.NonZeroChecks.begin(), - e = DFSF.NonZeroChecks.end(); - i != e; ++i) { + for (Value *V : DFSF.NonZeroChecks) { Instruction *Pos; - if (Instruction *I = dyn_cast<Instruction>(*i)) + if (Instruction *I = dyn_cast<Instruction>(V)) Pos = I->getNextNode(); else Pos = DFSF.F->getEntryBlock().begin(); while (isa<PHINode>(Pos) || isa<AllocaInst>(Pos)) Pos = Pos->getNextNode(); IRBuilder<> IRB(Pos); - Value *Ne = IRB.CreateICmpNE(*i, DFSF.DFS.ZeroShadow); + Value *Ne = IRB.CreateICmpNE(V, DFSF.DFS.ZeroShadow); BranchInst *BI = cast<BranchInst>(SplitBlockAndInsertIfThen( Ne, Pos, /*Unreachable=*/false, ColdCallWeights)); IRBuilder<> ThenIRB(BI); @@ -862,7 +916,7 @@ Value *DFSanFunction::getShadow(Value *V) { break; } } - NonZeroChecks.insert(Shadow); + NonZeroChecks.push_back(Shadow); } else { Shadow = DFS.ZeroShadow; } @@ -922,23 +976,33 @@ Value *DFSanFunction::combineShadows(Value *V1, Value *V2, Instruction *Pos) { return CCS.Shadow; IRBuilder<> IRB(Pos); - BasicBlock *Head = Pos->getParent(); - Value *Ne = IRB.CreateICmpNE(V1, V2); - BranchInst *BI = cast<BranchInst>(SplitBlockAndInsertIfThen( - Ne, Pos, /*Unreachable=*/false, DFS.ColdCallWeights, &DT)); - IRBuilder<> ThenIRB(BI); - CallInst *Call = ThenIRB.CreateCall2(DFS.DFSanUnionFn, V1, V2); - Call->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt); - Call->addAttribute(1, Attribute::ZExt); - Call->addAttribute(2, Attribute::ZExt); - - BasicBlock *Tail = BI->getSuccessor(0); - PHINode *Phi = PHINode::Create(DFS.ShadowTy, 2, "", Tail->begin()); - Phi->addIncoming(Call, Call->getParent()); - Phi->addIncoming(V1, Head); - - CCS.Block = Tail; - CCS.Shadow = Phi; + if (AvoidNewBlocks) { + CallInst *Call = IRB.CreateCall2(DFS.DFSanCheckedUnionFn, V1, V2); + Call->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt); + Call->addAttribute(1, Attribute::ZExt); + Call->addAttribute(2, Attribute::ZExt); + + CCS.Block = Pos->getParent(); + CCS.Shadow = Call; + } else { + BasicBlock *Head = Pos->getParent(); + Value *Ne = IRB.CreateICmpNE(V1, V2); + BranchInst *BI = cast<BranchInst>(SplitBlockAndInsertIfThen( + Ne, Pos, /*Unreachable=*/false, DFS.ColdCallWeights, &DT)); + IRBuilder<> ThenIRB(BI); + CallInst *Call = ThenIRB.CreateCall2(DFS.DFSanUnionFn, V1, V2); + Call->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt); + Call->addAttribute(1, Attribute::ZExt); + Call->addAttribute(2, Attribute::ZExt); + + BasicBlock *Tail = BI->getSuccessor(0); + PHINode *Phi = PHINode::Create(DFS.ShadowTy, 2, "", Tail->begin()); + Phi->addIncoming(Call, Call->getParent()); + Phi->addIncoming(V1, Head); + + CCS.Block = Tail; + CCS.Shadow = Phi; + } std::set<Value *> UnionElems; if (V1Elems != ShadowElements.end()) { @@ -951,9 +1015,9 @@ Value *DFSanFunction::combineShadows(Value *V1, Value *V2, Instruction *Pos) { } else { UnionElems.insert(V2); } - ShadowElements[Phi] = std::move(UnionElems); + ShadowElements[CCS.Shadow] = std::move(UnionElems); - return Phi; + return CCS.Shadow; } // A convenience function which folds the shadows of each of the operands @@ -1022,7 +1086,7 @@ Value *DFSanFunction::loadShadow(Value *Addr, uint64_t Size, uint64_t Align, IRB.CreateAlignedLoad(ShadowAddr1, ShadowAlign), Pos); } } - if (Size % (64 / 
DFS.ShadowWidth) == 0) { + if (!AvoidNewBlocks && Size % (64 / DFS.ShadowWidth) == 0) { // Fast path for the common case where each byte has identical shadow: load // shadow 64 bits at a time, fall out to a __dfsan_union_load call if any // shadow is non-equal. @@ -1092,6 +1156,11 @@ Value *DFSanFunction::loadShadow(Value *Addr, uint64_t Size, uint64_t Align, void DFSanVisitor::visitLoadInst(LoadInst &LI) { uint64_t Size = DFSF.DFS.DL->getTypeStoreSize(LI.getType()); + if (Size == 0) { + DFSF.setShadow(&LI, DFSF.DFS.ZeroShadow); + return; + } + uint64_t Align; if (ClPreserveAlignment) { Align = LI.getAlignment(); @@ -1107,7 +1176,7 @@ void DFSanVisitor::visitLoadInst(LoadInst &LI) { Shadow = DFSF.combineShadows(Shadow, PtrShadow, &LI); } if (Shadow != DFSF.DFS.ZeroShadow) - DFSF.NonZeroChecks.insert(Shadow); + DFSF.NonZeroChecks.push_back(Shadow); DFSF.setShadow(&LI, Shadow); } @@ -1166,6 +1235,9 @@ void DFSanFunction::storeShadow(Value *Addr, uint64_t Size, uint64_t Align, void DFSanVisitor::visitStoreInst(StoreInst &SI) { uint64_t Size = DFSF.DFS.DL->getTypeStoreSize(SI.getValueOperand()->getType()); + if (Size == 0) + return; + uint64_t Align; if (ClPreserveAlignment) { Align = SI.getAlignment(); @@ -1320,6 +1392,15 @@ void DFSanVisitor::visitCallSite(CallSite CS) { return; } + // Calls to this function are synthesized in wrappers, and we shouldn't + // instrument them. + if (F == DFSF.DFS.DFSanVarargWrapperFn) + return; + + assert(!(cast<FunctionType>( + CS.getCalledValue()->getType()->getPointerElementType())->isVarArg() && + dyn_cast<InvokeInst>(CS.getInstruction()))); + IRBuilder<> IRB(CS.getInstruction()); DenseMap<Value *, Function *>::iterator i = @@ -1391,6 +1472,20 @@ void DFSanVisitor::visitCallSite(CallSite CS) { for (unsigned n = FT->getNumParams(); n != 0; ++i, --n) Args.push_back(DFSF.getShadow(*i)); + if (FT->isVarArg()) { + auto LabelVAAlloca = + new AllocaInst(ArrayType::get(DFSF.DFS.ShadowTy, + CS.arg_size() - FT->getNumParams()), + "labelva", DFSF.F->getEntryBlock().begin()); + + for (unsigned n = 0; i != CS.arg_end(); ++i, ++n) { + auto LabelVAPtr = IRB.CreateStructGEP(LabelVAAlloca, n); + IRB.CreateStore(DFSF.getShadow(*i), LabelVAPtr); + } + + Args.push_back(IRB.CreateStructGEP(LabelVAAlloca, 0)); + } + if (!FT->getReturnType()->isVoidTy()) { if (!DFSF.LabelReturnAlloca) { DFSF.LabelReturnAlloca = @@ -1400,6 +1495,9 @@ void DFSanVisitor::visitCallSite(CallSite CS) { Args.push_back(DFSF.LabelReturnAlloca); } + for (i = CS.arg_begin() + FT->getNumParams(); i != CS.arg_end(); ++i) + Args.push_back(*i); + CallInst *CustomCI = IRB.CreateCall(CustomF, Args); CustomCI->setCallingConv(CI->getCallingConv()); CustomCI->setAttributes(CI->getAttributes()); @@ -1446,7 +1544,7 @@ void DFSanVisitor::visitCallSite(CallSite CS) { LoadInst *LI = NextIRB.CreateLoad(DFSF.getRetvalTLS()); DFSF.SkipInsts.insert(LI); DFSF.setShadow(CS.getInstruction(), LI); - DFSF.NonZeroChecks.insert(LI); + DFSF.NonZeroChecks.push_back(LI); } } @@ -1500,7 +1598,7 @@ void DFSanVisitor::visitCallSite(CallSite CS) { ExtractValueInst::Create(NewCS.getInstruction(), 1, "", Next); DFSF.SkipInsts.insert(ExShadow); DFSF.setShadow(ExVal, ExShadow); - DFSF.NonZeroChecks.insert(ExShadow); + DFSF.NonZeroChecks.push_back(ExShadow); CS.getInstruction()->replaceAllUsesWith(ExVal); } diff --git a/lib/Transforms/Instrumentation/DebugIR.cpp b/lib/Transforms/Instrumentation/DebugIR.cpp deleted file mode 100644 index f2f1738808be..000000000000 --- a/lib/Transforms/Instrumentation/DebugIR.cpp +++ /dev/null @@ -1,617 +0,0 @@ 
-//===--- DebugIR.cpp - Transform debug metadata to allow debugging IR -----===// -// -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------------===// -// -// A Module transform pass that emits a succinct version of the IR and replaces -// the source file metadata to allow debuggers to step through the IR. -// -// FIXME: instead of replacing debug metadata, this pass should allow for -// additional metadata to be used to point capable debuggers to the IR file -// without destroying the mapping to the original source file. -// -//===----------------------------------------------------------------------===// - -#include "llvm/IR/ValueMap.h" -#include "DebugIR.h" -#include "llvm/IR/AssemblyAnnotationWriter.h" -#include "llvm/IR/DIBuilder.h" -#include "llvm/IR/DataLayout.h" -#include "llvm/IR/DebugInfo.h" -#include "llvm/IR/InstVisitor.h" -#include "llvm/IR/Instruction.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/FileSystem.h" -#include "llvm/Support/FormattedStream.h" -#include "llvm/Support/Path.h" -#include "llvm/Support/ToolOutputFile.h" -#include "llvm/Transforms/Instrumentation.h" -#include "llvm/Transforms/Utils/Cloning.h" -#include <string> - -#define STR_HELPER(x) #x -#define STR(x) STR_HELPER(x) - -using namespace llvm; - -#define DEBUG_TYPE "debug-ir" - -namespace { - -/// Builds a map of Value* to line numbers on which the Value appears in a -/// textual representation of the IR by plugging into the AssemblyWriter by -/// masquerading as an AssemblyAnnotationWriter. -class ValueToLineMap : public AssemblyAnnotationWriter { - ValueMap<const Value *, unsigned int> Lines; - typedef ValueMap<const Value *, unsigned int>::const_iterator LineIter; - - void addEntry(const Value *V, formatted_raw_ostream &Out) { - Out.flush(); - Lines.insert(std::make_pair(V, Out.getLine() + 1)); - } - -public: - - /// Prints Module to a null buffer in order to build the map of Value pointers - /// to line numbers. - ValueToLineMap(const Module *M) { - raw_null_ostream ThrowAway; - M->print(ThrowAway, this); - } - - // This function is called after an Instruction, GlobalValue, or GlobalAlias - // is printed. - void printInfoComment(const Value &V, formatted_raw_ostream &Out) override { - addEntry(&V, Out); - } - - void emitFunctionAnnot(const Function *F, - formatted_raw_ostream &Out) override { - addEntry(F, Out); - } - - /// If V appears on a line in the textual IR representation, sets Line to the - /// line number and returns true, otherwise returns false. - bool getLine(const Value *V, unsigned int &Line) const { - LineIter i = Lines.find(V); - if (i != Lines.end()) { - Line = i->second; - return true; - } - return false; - } -}; - -/// Removes debug intrisncs like llvm.dbg.declare and llvm.dbg.value. 
-class DebugIntrinsicsRemover : public InstVisitor<DebugIntrinsicsRemover> { - void remove(Instruction &I) { I.eraseFromParent(); } - -public: - static void process(Module &M) { - DebugIntrinsicsRemover Remover; - Remover.visit(&M); - } - void visitDbgDeclareInst(DbgDeclareInst &I) { remove(I); } - void visitDbgValueInst(DbgValueInst &I) { remove(I); } - void visitDbgInfoIntrinsic(DbgInfoIntrinsic &I) { remove(I); } -}; - -/// Removes debug metadata (!dbg) nodes from all instructions, and optionally -/// metadata named "llvm.dbg.cu" if RemoveNamedInfo is true. -class DebugMetadataRemover : public InstVisitor<DebugMetadataRemover> { - bool RemoveNamedInfo; - -public: - static void process(Module &M, bool RemoveNamedInfo = true) { - DebugMetadataRemover Remover(RemoveNamedInfo); - Remover.run(&M); - } - - DebugMetadataRemover(bool RemoveNamedInfo) - : RemoveNamedInfo(RemoveNamedInfo) {} - - void visitInstruction(Instruction &I) { - if (I.getMetadata(LLVMContext::MD_dbg)) - I.setMetadata(LLVMContext::MD_dbg, nullptr); - } - - void run(Module *M) { - // Remove debug metadata attached to instructions - visit(M); - - if (RemoveNamedInfo) { - // Remove CU named metadata (and all children nodes) - NamedMDNode *Node = M->getNamedMetadata("llvm.dbg.cu"); - if (Node) - M->eraseNamedMetadata(Node); - } - } -}; - -/// Updates debug metadata in a Module: -/// - changes Filename/Directory to values provided on construction -/// - adds/updates line number (DebugLoc) entries associated with each -/// instruction to reflect the instruction's location in an LLVM IR file -class DIUpdater : public InstVisitor<DIUpdater> { - /// Builder of debug information - DIBuilder Builder; - - /// Helper for type attributes/sizes/etc - DataLayout Layout; - - /// Map of Value* to line numbers - const ValueToLineMap LineTable; - - /// Map of Value* (in original Module) to Value* (in optional cloned Module) - const ValueToValueMapTy *VMap; - - /// Directory of debug metadata - DebugInfoFinder Finder; - - /// Source filename and directory - StringRef Filename; - StringRef Directory; - - // CU nodes needed when creating DI subprograms - MDNode *FileNode; - MDNode *LexicalBlockFileNode; - const MDNode *CUNode; - - ValueMap<const Function *, MDNode *> SubprogramDescriptors; - DenseMap<const Type *, MDNode *> TypeDescriptors; - -public: - DIUpdater(Module &M, StringRef Filename = StringRef(), - StringRef Directory = StringRef(), const Module *DisplayM = nullptr, - const ValueToValueMapTy *VMap = nullptr) - : Builder(M), Layout(&M), LineTable(DisplayM ? DisplayM : &M), VMap(VMap), - Finder(), Filename(Filename), Directory(Directory), FileNode(nullptr), - LexicalBlockFileNode(nullptr), CUNode(nullptr) { - Finder.processModule(M); - visit(&M); - } - - ~DIUpdater() { Builder.finalize(); } - - void visitModule(Module &M) { - if (Finder.compile_unit_count() > 1) - report_fatal_error("DebugIR pass supports only a signle compile unit per " - "Module."); - createCompileUnit(Finder.compile_unit_count() == 1 ? 
- (MDNode*)*Finder.compile_units().begin() : nullptr); - } - - void visitFunction(Function &F) { - if (F.isDeclaration() || findDISubprogram(&F)) - return; - - StringRef MangledName = F.getName(); - DICompositeType Sig = createFunctionSignature(&F); - - // find line of function declaration - unsigned Line = 0; - if (!findLine(&F, Line)) { - DEBUG(dbgs() << "WARNING: No line for Function " << F.getName().str() - << "\n"); - return; - } - - Instruction *FirstInst = F.begin()->begin(); - unsigned ScopeLine = 0; - if (!findLine(FirstInst, ScopeLine)) { - DEBUG(dbgs() << "WARNING: No line for 1st Instruction in Function " - << F.getName().str() << "\n"); - return; - } - - bool Local = F.hasInternalLinkage(); - bool IsDefinition = !F.isDeclaration(); - bool IsOptimized = false; - - int FuncFlags = llvm::DIDescriptor::FlagPrototyped; - assert(CUNode && FileNode); - DISubprogram Sub = Builder.createFunction( - DICompileUnit(CUNode), F.getName(), MangledName, DIFile(FileNode), Line, - Sig, Local, IsDefinition, ScopeLine, FuncFlags, IsOptimized, &F); - assert(Sub.isSubprogram()); - DEBUG(dbgs() << "create subprogram mdnode " << *Sub << ": " - << "\n"); - - SubprogramDescriptors.insert(std::make_pair(&F, Sub)); - } - - void visitInstruction(Instruction &I) { - DebugLoc Loc(I.getDebugLoc()); - - /// If a ValueToValueMap is provided, use it to get the real instruction as - /// the line table was generated on a clone of the module on which we are - /// operating. - Value *RealInst = nullptr; - if (VMap) - RealInst = VMap->lookup(&I); - - if (!RealInst) - RealInst = &I; - - unsigned Col = 0; // FIXME: support columns - unsigned Line; - if (!LineTable.getLine(RealInst, Line)) { - // Instruction has no line, it may have been removed (in the module that - // will be passed to the debugger) so there is nothing to do here. - DEBUG(dbgs() << "WARNING: no LineTable entry for instruction " << RealInst - << "\n"); - DEBUG(RealInst->dump()); - return; - } - - DebugLoc NewLoc; - if (!Loc.isUnknown()) - // I had a previous debug location: re-use the DebugLoc - NewLoc = DebugLoc::get(Line, Col, Loc.getScope(RealInst->getContext()), - Loc.getInlinedAt(RealInst->getContext())); - else if (MDNode *scope = findScope(&I)) - NewLoc = DebugLoc::get(Line, Col, scope, nullptr); - else { - DEBUG(dbgs() << "WARNING: no valid scope for instruction " << &I - << ". no DebugLoc will be present." - << "\n"); - return; - } - - addDebugLocation(I, NewLoc); - } - -private: - - void createCompileUnit(MDNode *CUToReplace) { - std::string Flags; - bool IsOptimized = false; - StringRef Producer; - unsigned RuntimeVersion(0); - StringRef SplitName; - - if (CUToReplace) { - // save fields from existing CU to re-use in the new CU - DICompileUnit ExistingCU(CUToReplace); - Producer = ExistingCU.getProducer(); - IsOptimized = ExistingCU.isOptimized(); - Flags = ExistingCU.getFlags(); - RuntimeVersion = ExistingCU.getRunTimeVersion(); - SplitName = ExistingCU.getSplitDebugFilename(); - } else { - Producer = - "LLVM Version " STR(LLVM_VERSION_MAJOR) "." 
STR(LLVM_VERSION_MINOR); - } - - CUNode = - Builder.createCompileUnit(dwarf::DW_LANG_C99, Filename, Directory, - Producer, IsOptimized, Flags, RuntimeVersion); - - if (CUToReplace) - CUToReplace->replaceAllUsesWith(const_cast<MDNode *>(CUNode)); - - DICompileUnit CU(CUNode); - FileNode = Builder.createFile(Filename, Directory); - LexicalBlockFileNode = Builder.createLexicalBlockFile(CU, DIFile(FileNode)); - } - - /// Returns the MDNode* that represents the DI scope to associate with I - MDNode *findScope(const Instruction *I) { - const Function *F = I->getParent()->getParent(); - if (MDNode *ret = findDISubprogram(F)) - return ret; - - DEBUG(dbgs() << "WARNING: Using fallback lexical block file scope " - << LexicalBlockFileNode << " as scope for instruction " << I - << "\n"); - return LexicalBlockFileNode; - } - - /// Returns the MDNode* that is the descriptor for F - MDNode *findDISubprogram(const Function *F) { - typedef ValueMap<const Function *, MDNode *>::const_iterator FuncNodeIter; - FuncNodeIter i = SubprogramDescriptors.find(F); - if (i != SubprogramDescriptors.end()) - return i->second; - - DEBUG(dbgs() << "searching for DI scope node for Function " << F - << " in a list of " << Finder.subprogram_count() - << " subprogram nodes" - << "\n"); - - for (DISubprogram S : Finder.subprograms()) { - if (S.getFunction() == F) { - DEBUG(dbgs() << "Found DISubprogram " << S << " for function " - << S.getFunction() << "\n"); - return S; - } - } - DEBUG(dbgs() << "unable to find DISubprogram node for function " - << F->getName().str() << "\n"); - return nullptr; - } - - /// Sets Line to the line number on which V appears and returns true. If a - /// line location for V is not found, returns false. - bool findLine(const Value *V, unsigned &Line) { - if (LineTable.getLine(V, Line)) - return true; - - if (VMap) { - Value *mapped = VMap->lookup(V); - if (mapped && LineTable.getLine(mapped, Line)) - return true; - } - return false; - } - - std::string getTypeName(Type *T) { - std::string TypeName; - raw_string_ostream TypeStream(TypeName); - if (T) - T->print(TypeStream); - else - TypeStream << "Printing <null> Type"; - TypeStream.flush(); - return TypeName; - } - - /// Returns the MDNode that represents type T if it is already created, or 0 - /// if it is not. - MDNode *getType(const Type *T) { - typedef DenseMap<const Type *, MDNode *>::const_iterator TypeNodeIter; - TypeNodeIter i = TypeDescriptors.find(T); - if (i != TypeDescriptors.end()) - return i->second; - return nullptr; - } - - /// Returns a DebugInfo type from an LLVM type T. - DIDerivedType getOrCreateType(Type *T) { - MDNode *N = getType(T); - if (N) - return DIDerivedType(N); - else if (T->isVoidTy()) - return DIDerivedType(nullptr); - else if (T->isStructTy()) { - N = Builder.createStructType( - DIScope(LexicalBlockFileNode), T->getStructName(), DIFile(FileNode), - 0, Layout.getTypeSizeInBits(T), Layout.getABITypeAlignment(T), 0, - DIType(nullptr), DIArray(nullptr)); // filled in later - - // N is added to the map (early) so that element search below can find it, - // so as to avoid infinite recursion for structs that contain pointers to - // their own type. 
- TypeDescriptors[T] = N; - DICompositeType StructDescriptor(N); - - SmallVector<Value *, 4> Elements; - for (unsigned i = 0; i < T->getStructNumElements(); ++i) - Elements.push_back(getOrCreateType(T->getStructElementType(i))); - - // set struct elements - StructDescriptor.setTypeArray(Builder.getOrCreateArray(Elements)); - } else if (T->isPointerTy()) { - Type *PointeeTy = T->getPointerElementType(); - if (!(N = getType(PointeeTy))) - N = Builder.createPointerType( - getOrCreateType(PointeeTy), Layout.getPointerTypeSizeInBits(T), - Layout.getPrefTypeAlignment(T), getTypeName(T)); - } else if (T->isArrayTy()) { - SmallVector<Value *, 1> Subrange; - Subrange.push_back( - Builder.getOrCreateSubrange(0, T->getArrayNumElements() - 1)); - - N = Builder.createArrayType(Layout.getTypeSizeInBits(T), - Layout.getPrefTypeAlignment(T), - getOrCreateType(T->getArrayElementType()), - Builder.getOrCreateArray(Subrange)); - } else { - int encoding = llvm::dwarf::DW_ATE_signed; - if (T->isIntegerTy()) - encoding = llvm::dwarf::DW_ATE_unsigned; - else if (T->isFloatingPointTy()) - encoding = llvm::dwarf::DW_ATE_float; - - N = Builder.createBasicType(getTypeName(T), T->getPrimitiveSizeInBits(), - 0, encoding); - } - TypeDescriptors[T] = N; - return DIDerivedType(N); - } - - /// Returns a DebugInfo type that represents a function signature for Func. - DICompositeType createFunctionSignature(const Function *Func) { - SmallVector<Value *, 4> Params; - DIDerivedType ReturnType(getOrCreateType(Func->getReturnType())); - Params.push_back(ReturnType); - - const Function::ArgumentListType &Args(Func->getArgumentList()); - for (Function::ArgumentListType::const_iterator i = Args.begin(), - e = Args.end(); - i != e; ++i) { - Type *T(i->getType()); - Params.push_back(getOrCreateType(T)); - } - - DIArray ParamArray = Builder.getOrCreateArray(Params); - return Builder.createSubroutineType(DIFile(FileNode), ParamArray); - } - - /// Associates Instruction I with debug location Loc. - void addDebugLocation(Instruction &I, DebugLoc Loc) { - MDNode *MD = Loc.getAsMDNode(I.getContext()); - I.setMetadata(LLVMContext::MD_dbg, MD); - } -}; - -/// Sets Filename/Directory from the Module identifier and returns true, or -/// false if source information is not present. -bool getSourceInfoFromModule(const Module &M, std::string &Directory, - std::string &Filename) { - std::string PathStr(M.getModuleIdentifier()); - if (PathStr.length() == 0 || PathStr == "<stdin>") - return false; - - Filename = sys::path::filename(PathStr); - SmallVector<char, 16> Path(PathStr.begin(), PathStr.end()); - sys::path::remove_filename(Path); - Directory = StringRef(Path.data(), Path.size()); - return true; -} - -// Sets Filename/Directory from debug information in M and returns true, or -// false if no debug information available, or cannot be parsed. 
-bool getSourceInfoFromDI(const Module &M, std::string &Directory, - std::string &Filename) { - NamedMDNode *CUNode = M.getNamedMetadata("llvm.dbg.cu"); - if (!CUNode || CUNode->getNumOperands() == 0) - return false; - - DICompileUnit CU(CUNode->getOperand(0)); - if (!CU.Verify()) - return false; - - Filename = CU.getFilename(); - Directory = CU.getDirectory(); - return true; -} - -} // anonymous namespace - -namespace llvm { - -bool DebugIR::getSourceInfo(const Module &M) { - ParsedPath = getSourceInfoFromDI(M, Directory, Filename) || - getSourceInfoFromModule(M, Directory, Filename); - return ParsedPath; -} - -bool DebugIR::updateExtension(StringRef NewExtension) { - size_t dot = Filename.find_last_of("."); - if (dot == std::string::npos) - return false; - - Filename.erase(dot); - Filename += NewExtension.str(); - return true; -} - -void DebugIR::generateFilename(std::unique_ptr<int> &fd) { - SmallVector<char, 16> PathVec; - fd.reset(new int); - sys::fs::createTemporaryFile("debug-ir", "ll", *fd, PathVec); - StringRef Path(PathVec.data(), PathVec.size()); - Filename = sys::path::filename(Path); - sys::path::remove_filename(PathVec); - Directory = StringRef(PathVec.data(), PathVec.size()); - - GeneratedPath = true; -} - -std::string DebugIR::getPath() { - SmallVector<char, 16> Path; - sys::path::append(Path, Directory, Filename); - Path.resize(Filename.size() + Directory.size() + 2); - Path[Filename.size() + Directory.size() + 1] = '\0'; - return std::string(Path.data()); -} - -void DebugIR::writeDebugBitcode(const Module *M, int *fd) { - std::unique_ptr<raw_fd_ostream> Out; - std::string error; - - if (!fd) { - std::string Path = getPath(); - Out.reset(new raw_fd_ostream(Path.c_str(), error, sys::fs::F_Text)); - DEBUG(dbgs() << "WRITING debug bitcode from Module " << M << " to file " - << Path << "\n"); - } else { - DEBUG(dbgs() << "WRITING debug bitcode from Module " << M << " to fd " - << *fd << "\n"); - Out.reset(new raw_fd_ostream(*fd, true)); - } - - M->print(*Out, nullptr); - Out->close(); -} - -void DebugIR::createDebugInfo(Module &M, std::unique_ptr<Module> &DisplayM) { - if (M.getFunctionList().size() == 0) - // no functions -- no debug info needed - return; - - std::unique_ptr<ValueToValueMapTy> VMap; - - if (WriteSourceToDisk && (HideDebugIntrinsics || HideDebugMetadata)) { - VMap.reset(new ValueToValueMapTy); - DisplayM.reset(CloneModule(&M, *VMap)); - - if (HideDebugIntrinsics) - DebugIntrinsicsRemover::process(*DisplayM); - - if (HideDebugMetadata) - DebugMetadataRemover::process(*DisplayM); - } - - DIUpdater R(M, Filename, Directory, DisplayM.get(), VMap.get()); -} - -bool DebugIR::isMissingPath() { return Filename.empty() || Directory.empty(); } - -bool DebugIR::runOnModule(Module &M) { - std::unique_ptr<int> fd; - - if (isMissingPath() && !getSourceInfo(M)) { - if (!WriteSourceToDisk) - report_fatal_error("DebugIR unable to determine file name in input. " - "Ensure Module contains an identifier, a valid " - "DICompileUnit, or construct DebugIR with " - "non-empty Filename/Directory parameters."); - else - generateFilename(fd); - } - - if (!GeneratedPath && WriteSourceToDisk) - updateExtension(".debug-ll"); - - // Clear line numbers. Keep debug info (if any) if we were able to read the - // file name from the DICompileUnit descriptor. - DebugMetadataRemover::process(M, !ParsedPath); - - std::unique_ptr<Module> DisplayM; - createDebugInfo(M, DisplayM); - if (WriteSourceToDisk) { - Module *OutputM = DisplayM.get() ? 
DisplayM.get() : &M; - writeDebugBitcode(OutputM, fd.get()); - } - - DEBUG(M.dump()); - return true; -} - -bool DebugIR::runOnModule(Module &M, std::string &Path) { - bool result = runOnModule(M); - Path = getPath(); - return result; -} - -} // llvm namespace - -char DebugIR::ID = 0; -INITIALIZE_PASS(DebugIR, "debug-ir", "Enable debugging IR", false, false) - -ModulePass *llvm::createDebugIRPass(bool HideDebugIntrinsics, - bool HideDebugMetadata, StringRef Directory, - StringRef Filename) { - return new DebugIR(HideDebugIntrinsics, HideDebugMetadata, Directory, - Filename); -} - -ModulePass *llvm::createDebugIRPass() { return new DebugIR(); } diff --git a/lib/Transforms/Instrumentation/DebugIR.h b/lib/Transforms/Instrumentation/DebugIR.h deleted file mode 100644 index 02831eda2d9f..000000000000 --- a/lib/Transforms/Instrumentation/DebugIR.h +++ /dev/null @@ -1,98 +0,0 @@ -//===- llvm/Transforms/Instrumentation/DebugIR.h - Interface ----*- C++ -*-===// -// -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------------===// -// -// This file defines the interface of the DebugIR pass. For most users, -// including Instrumentation.h and calling createDebugIRPass() is sufficient and -// there is no need to include this file. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_TRANSFORMS_INSTRUMENTATION_DEBUGIR_H -#define LLVM_TRANSFORMS_INSTRUMENTATION_DEBUGIR_H - -#include "llvm/Pass.h" - -namespace llvm { - -class DebugIR : public llvm::ModulePass { - /// If true, write a source file to disk. - bool WriteSourceToDisk; - - /// Hide certain (non-essential) debug information (only relevant if - /// createSource is true. - bool HideDebugIntrinsics; - bool HideDebugMetadata; - - /// The location of the source file. - std::string Directory; - std::string Filename; - - /// True if a temporary file name was generated. - bool GeneratedPath; - - /// True if the file name was read from the Module. - bool ParsedPath; - -public: - static char ID; - - const char *getPassName() const override { return "DebugIR"; } - - /// Generate a file on disk to be displayed in a debugger. If Filename and - /// Directory are empty, a temporary path will be generated. - DebugIR(bool HideDebugIntrinsics, bool HideDebugMetadata, - llvm::StringRef Directory, llvm::StringRef Filename) - : ModulePass(ID), WriteSourceToDisk(true), - HideDebugIntrinsics(HideDebugIntrinsics), - HideDebugMetadata(HideDebugMetadata), Directory(Directory), - Filename(Filename), GeneratedPath(false), ParsedPath(false) {} - - /// Modify input in-place; do not generate additional files, and do not hide - /// any debug intrinsics/metadata that might be present. - DebugIR() - : ModulePass(ID), WriteSourceToDisk(false), HideDebugIntrinsics(false), - HideDebugMetadata(false), GeneratedPath(false), ParsedPath(false) {} - - /// Run pass on M and set Path to the source file path in the output module. - bool runOnModule(llvm::Module &M, std::string &Path); - bool runOnModule(llvm::Module &M) override; - -private: - - /// Returns the concatenated Directory + Filename, without error checking - std::string getPath(); - - /// Attempts to read source information from debug information in M, and if - /// that fails, from M's identifier. Returns true on success, false otherwise. 
- bool getSourceInfo(const llvm::Module &M); - - /// Replace the extension of Filename with NewExtension, and return true if - /// successful. Return false if extension could not be found or Filename is - /// empty. - bool updateExtension(llvm::StringRef NewExtension); - - /// Generate a temporary filename and open an fd - void generateFilename(std::unique_ptr<int> &fd); - - /// Creates DWARF CU/Subroutine metadata - void createDebugInfo(llvm::Module &M, - std::unique_ptr<llvm::Module> &DisplayM); - - /// Returns true if either Directory or Filename is missing, false otherwise. - bool isMissingPath(); - - /// Write M to disk, optionally passing in an fd to an open file which is - /// closed by this function after writing. If no fd is specified, a new file - /// is opened, written, and closed. - void writeDebugBitcode(const llvm::Module *M, int *fd = nullptr); -}; - -} // llvm namespace - -#endif // LLVM_TRANSFORMS_INSTRUMENTATION_DEBUGIR_H diff --git a/lib/Transforms/Instrumentation/GCOVProfiling.cpp b/lib/Transforms/Instrumentation/GCOVProfiling.cpp index cfeb62eb1f9f..cb965fb9a225 100644 --- a/lib/Transforms/Instrumentation/GCOVProfiling.cpp +++ b/lib/Transforms/Instrumentation/GCOVProfiling.cpp @@ -285,6 +285,14 @@ namespace { DeleteContainerSeconds(LinesByFile); } + GCOVBlock(const GCOVBlock &RHS) : GCOVRecord(RHS), Number(RHS.Number) { + // Only allow copy before edges and lines have been added. After that, + // there are inter-block pointers (eg: edges) that won't take kindly to + // blocks being copied or moved around. + assert(LinesByFile.empty()); + assert(OutEdges.empty()); + } + private: friend class GCOVFunction; @@ -303,18 +311,22 @@ namespace { // object users can construct, the blocks and lines will be rooted here. class GCOVFunction : public GCOVRecord { public: - GCOVFunction(DISubprogram SP, raw_ostream *os, uint32_t Ident, - bool UseCfgChecksum) : - SP(SP), Ident(Ident), UseCfgChecksum(UseCfgChecksum), CfgChecksum(0) { + GCOVFunction(DISubprogram SP, raw_ostream *os, uint32_t Ident, + bool UseCfgChecksum) + : SP(SP), Ident(Ident), UseCfgChecksum(UseCfgChecksum), CfgChecksum(0), + ReturnBlock(1, os) { this->os = os; Function *F = SP.getFunction(); DEBUG(dbgs() << "Function: " << getFunctionName(SP) << "\n"); + uint32_t i = 0; - for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) { - Blocks[BB] = new GCOVBlock(i++, os); + for (auto &BB : *F) { + // Skip index 1 (0, 2, 3, 4, ...) because that's assigned to the + // ReturnBlock. 
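        // (For illustration: with ReturnBlock constructed as number 1 above, a
        // function whose blocks arrive as BB0, BB1, BB2 gets the numbering
        // BB0 -> 0, ReturnBlock -> 1, BB1 -> 2, BB2 -> 3; only the first block
        // keeps its raw loop index, and every later one is shifted up by one.)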
+ bool first = i == 0; + Blocks.insert(std::make_pair(&BB, GCOVBlock(i++ + !first, os))); } - ReturnBlock = new GCOVBlock(i++, os); std::string FunctionNameAndLine; raw_string_ostream FNLOS(FunctionNameAndLine); @@ -323,17 +335,12 @@ namespace { FuncChecksum = hash_value(FunctionNameAndLine); } - ~GCOVFunction() { - DeleteContainerSeconds(Blocks); - delete ReturnBlock; - } - GCOVBlock &getBlock(BasicBlock *BB) { - return *Blocks[BB]; + return Blocks.find(BB)->second; } GCOVBlock &getReturnBlock() { - return *ReturnBlock; + return ReturnBlock; } std::string getEdgeDestinations() { @@ -341,7 +348,7 @@ namespace { raw_string_ostream EDOS(EdgeDestinations); Function *F = Blocks.begin()->first->getParent(); for (Function::iterator I = F->begin(), E = F->end(); I != E; ++I) { - GCOVBlock &Block = *Blocks[I]; + GCOVBlock &Block = getBlock(I); for (int i = 0, e = Block.OutEdges.size(); i != e; ++i) EDOS << Block.OutEdges[i]->Number; } @@ -383,7 +390,7 @@ namespace { if (Blocks.empty()) return; Function *F = Blocks.begin()->first->getParent(); for (Function::iterator I = F->begin(), E = F->end(); I != E; ++I) { - GCOVBlock &Block = *Blocks[I]; + GCOVBlock &Block = getBlock(I); if (Block.OutEdges.empty()) continue; writeBytes(EdgeTag, 4); @@ -399,7 +406,7 @@ namespace { // Emit lines for each block. for (Function::iterator I = F->begin(), E = F->end(); I != E; ++I) { - Blocks[I]->writeOut(); + getBlock(I).writeOut(); } } @@ -409,8 +416,8 @@ namespace { uint32_t FuncChecksum; bool UseCfgChecksum; uint32_t CfgChecksum; - DenseMap<BasicBlock *, GCOVBlock *> Blocks; - GCOVBlock *ReturnBlock; + DenseMap<BasicBlock *, GCOVBlock> Blocks; + GCOVBlock ReturnBlock; }; } @@ -480,12 +487,12 @@ void GCOVProfiler::emitProfileNotes() { // LTO, we'll generate the same .gcno files. 
DICompileUnit CU(CU_Nodes->getOperand(i)); - std::string ErrorInfo; - raw_fd_ostream out(mangleName(CU, "gcno").c_str(), ErrorInfo, - sys::fs::F_None); + std::error_code EC; + raw_fd_ostream out(mangleName(CU, "gcno"), EC, sys::fs::F_None); std::string EdgeDestinations; DIArray SPs = CU.getSubprograms(); + unsigned FunctionIdent = 0; for (unsigned i = 0, e = SPs.getNumElements(); i != e; ++i) { DISubprogram SP(SPs.getElement(i)); assert((!SP || SP.isSubprogram()) && @@ -505,8 +512,8 @@ void GCOVProfiler::emitProfileNotes() { ++It; EntryBlock.splitBasicBlock(It); - Funcs.push_back( - make_unique<GCOVFunction>(SP, &out, i, Options.UseCfgChecksum)); + Funcs.push_back(make_unique<GCOVFunction>(SP, &out, FunctionIdent++, + Options.UseCfgChecksum)); GCOVFunction &Func = *Funcs.back(); for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) { @@ -738,11 +745,11 @@ GlobalVariable *GCOVProfiler::buildEdgeLookupTable( Edge += Successors; } - ArrayRef<Constant*> V(&EdgeTable[0], TableSize); GlobalVariable *EdgeTableGV = new GlobalVariable( *M, EdgeTableTy, true, GlobalValue::InternalLinkage, - ConstantArray::get(EdgeTableTy, V), + ConstantArray::get(EdgeTableTy, + makeArrayRef(&EdgeTable[0],TableSize)), "__llvm_gcda_edge_table"); EdgeTableGV->setUnnamedAddr(true); return EdgeTableGV; diff --git a/lib/Transforms/Instrumentation/InstrProfiling.cpp b/lib/Transforms/Instrumentation/InstrProfiling.cpp new file mode 100644 index 000000000000..5f73b89e8551 --- /dev/null +++ b/lib/Transforms/Instrumentation/InstrProfiling.cpp @@ -0,0 +1,309 @@ +//===-- InstrProfiling.cpp - Frontend instrumentation based profiling -----===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass lowers instrprof_increment intrinsics emitted by a frontend for +// profiling. It also builds the data structures and initialization code needed +// for updating execution counts and emitting the profile at runtime. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Instrumentation.h" + +#include "llvm/ADT/Triple.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Module.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" + +using namespace llvm; + +#define DEBUG_TYPE "instrprof" + +namespace { + +class InstrProfiling : public ModulePass { +public: + static char ID; + + InstrProfiling() : ModulePass(ID) {} + + InstrProfiling(const InstrProfOptions &Options) + : ModulePass(ID), Options(Options) {} + + const char *getPassName() const override { + return "Frontend instrumentation-based coverage lowering"; + } + + bool runOnModule(Module &M) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + } + +private: + InstrProfOptions Options; + Module *M; + DenseMap<GlobalVariable *, GlobalVariable *> RegionCounters; + std::vector<Value *> UsedVars; + + bool isMachO() const { + return Triple(M->getTargetTriple()).isOSBinFormatMachO(); + } + + /// Get the section name for the counter variables. + StringRef getCountersSection() const { + return isMachO() ? "__DATA,__llvm_prf_cnts" : "__llvm_prf_cnts"; + } + + /// Get the section name for the name variables. + StringRef getNameSection() const { + return isMachO() ? 
"__DATA,__llvm_prf_names" : "__llvm_prf_names"; + } + + /// Get the section name for the profile data variables. + StringRef getDataSection() const { + return isMachO() ? "__DATA,__llvm_prf_data" : "__llvm_prf_data"; + } + + /// Replace instrprof_increment with an increment of the appropriate value. + void lowerIncrement(InstrProfIncrementInst *Inc); + + /// Get the region counters for an increment, creating them if necessary. + /// + /// If the counter array doesn't yet exist, the profile data variables + /// referring to them will also be created. + GlobalVariable *getOrCreateRegionCounters(InstrProfIncrementInst *Inc); + + /// Emit runtime registration functions for each profile data variable. + void emitRegistration(); + + /// Emit the necessary plumbing to pull in the runtime initialization. + void emitRuntimeHook(); + + /// Add uses of our data variables and runtime hook. + void emitUses(); + + /// Create a static initializer for our data, on platforms that need it. + void emitInitialization(); +}; + +} // anonymous namespace + +char InstrProfiling::ID = 0; +INITIALIZE_PASS(InstrProfiling, "instrprof", + "Frontend instrumentation-based coverage lowering.", false, + false) + +ModulePass *llvm::createInstrProfilingPass(const InstrProfOptions &Options) { + return new InstrProfiling(Options); +} + +bool InstrProfiling::runOnModule(Module &M) { + bool MadeChange = false; + + this->M = &M; + RegionCounters.clear(); + UsedVars.clear(); + + for (Function &F : M) + for (BasicBlock &BB : F) + for (auto I = BB.begin(), E = BB.end(); I != E;) + if (auto *Inc = dyn_cast<InstrProfIncrementInst>(I++)) { + lowerIncrement(Inc); + MadeChange = true; + } + if (!MadeChange) + return false; + + emitRegistration(); + emitRuntimeHook(); + emitUses(); + emitInitialization(); + return true; +} + +void InstrProfiling::lowerIncrement(InstrProfIncrementInst *Inc) { + GlobalVariable *Counters = getOrCreateRegionCounters(Inc); + + IRBuilder<> Builder(Inc->getParent(), *Inc); + uint64_t Index = Inc->getIndex()->getZExtValue(); + llvm::Value *Addr = Builder.CreateConstInBoundsGEP2_64(Counters, 0, Index); + llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount"); + Count = Builder.CreateAdd(Count, Builder.getInt64(1)); + Inc->replaceAllUsesWith(Builder.CreateStore(Count, Addr)); + Inc->eraseFromParent(); +} + +/// Get the name of a profiling variable for a particular function. +static std::string getVarName(InstrProfIncrementInst *Inc, StringRef VarName) { + auto *Arr = cast<ConstantDataArray>(Inc->getName()->getInitializer()); + StringRef Name = Arr->isCString() ? Arr->getAsCString() : Arr->getAsString(); + return ("__llvm_profile_" + VarName + "_" + Name).str(); +} + +GlobalVariable * +InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { + GlobalVariable *Name = Inc->getName(); + auto It = RegionCounters.find(Name); + if (It != RegionCounters.end()) + return It->second; + + // Move the name variable to the right section. + Name->setSection(getNameSection()); + Name->setAlignment(1); + + uint64_t NumCounters = Inc->getNumCounters()->getZExtValue(); + LLVMContext &Ctx = M->getContext(); + ArrayType *CounterTy = ArrayType::get(Type::getInt64Ty(Ctx), NumCounters); + + // Create the counters variable. 
+ auto *Counters = new GlobalVariable(*M, CounterTy, false, Name->getLinkage(), + Constant::getNullValue(CounterTy), + getVarName(Inc, "counters")); + Counters->setVisibility(Name->getVisibility()); + Counters->setSection(getCountersSection()); + Counters->setAlignment(8); + + RegionCounters[Inc->getName()] = Counters; + + // Create data variable. + auto *NameArrayTy = Name->getType()->getPointerElementType(); + auto *Int32Ty = Type::getInt32Ty(Ctx); + auto *Int64Ty = Type::getInt64Ty(Ctx); + auto *Int8PtrTy = Type::getInt8PtrTy(Ctx); + auto *Int64PtrTy = Type::getInt64PtrTy(Ctx); + + Type *DataTypes[] = {Int32Ty, Int32Ty, Int64Ty, Int8PtrTy, Int64PtrTy}; + auto *DataTy = StructType::get(Ctx, makeArrayRef(DataTypes)); + Constant *DataVals[] = { + ConstantInt::get(Int32Ty, NameArrayTy->getArrayNumElements()), + ConstantInt::get(Int32Ty, NumCounters), + ConstantInt::get(Int64Ty, Inc->getHash()->getZExtValue()), + ConstantExpr::getBitCast(Name, Int8PtrTy), + ConstantExpr::getBitCast(Counters, Int64PtrTy)}; + auto *Data = new GlobalVariable(*M, DataTy, true, Name->getLinkage(), + ConstantStruct::get(DataTy, DataVals), + getVarName(Inc, "data")); + Data->setVisibility(Name->getVisibility()); + Data->setSection(getDataSection()); + Data->setAlignment(8); + + // Mark the data variable as used so that it isn't stripped out. + UsedVars.push_back(Data); + + return Counters; +} + +void InstrProfiling::emitRegistration() { + // Don't do this for Darwin. compiler-rt uses linker magic. + if (Triple(M->getTargetTriple()).isOSDarwin()) + return; + + // Construct the function. + auto *VoidTy = Type::getVoidTy(M->getContext()); + auto *VoidPtrTy = Type::getInt8PtrTy(M->getContext()); + auto *RegisterFTy = FunctionType::get(VoidTy, false); + auto *RegisterF = Function::Create(RegisterFTy, GlobalValue::InternalLinkage, + "__llvm_profile_register_functions", M); + RegisterF->setUnnamedAddr(true); + if (Options.NoRedZone) + RegisterF->addFnAttr(Attribute::NoRedZone); + + auto *RuntimeRegisterTy = llvm::FunctionType::get(VoidTy, VoidPtrTy, false); + auto *RuntimeRegisterF = + Function::Create(RuntimeRegisterTy, GlobalVariable::ExternalLinkage, + "__llvm_profile_register_function", M); + + IRBuilder<> IRB(BasicBlock::Create(M->getContext(), "", RegisterF)); + for (Value *Data : UsedVars) + IRB.CreateCall(RuntimeRegisterF, IRB.CreateBitCast(Data, VoidPtrTy)); + IRB.CreateRetVoid(); +} + +void InstrProfiling::emitRuntimeHook() { + const char *const RuntimeVarName = "__llvm_profile_runtime"; + const char *const RuntimeUserName = "__llvm_profile_runtime_user"; + + // If the module's provided its own runtime, we don't need to do anything. + if (M->getGlobalVariable(RuntimeVarName)) + return; + + // Declare an external variable that will pull in the runtime initialization. + auto *Int32Ty = Type::getInt32Ty(M->getContext()); + auto *Var = + new GlobalVariable(*M, Int32Ty, false, GlobalValue::ExternalLinkage, + nullptr, RuntimeVarName); + + // Make a function that uses it. + auto *User = + Function::Create(FunctionType::get(Int32Ty, false), + GlobalValue::LinkOnceODRLinkage, RuntimeUserName, M); + User->addFnAttr(Attribute::NoInline); + if (Options.NoRedZone) + User->addFnAttr(Attribute::NoRedZone); + + IRBuilder<> IRB(BasicBlock::Create(M->getContext(), "", User)); + auto *Load = IRB.CreateLoad(Var); + IRB.CreateRet(Load); + + // Mark the user variable as used so that it isn't stripped out. 
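  // (How the hook forces linking: __llvm_profile_runtime is declared external
  // here and is defined by the profile runtime library, so keeping this
  // noinline reader alive through llvm.used makes the linker resolve that
  // symbol and pull the runtime, with its initialization, into the image.)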
+ UsedVars.push_back(User); +} + +void InstrProfiling::emitUses() { + if (UsedVars.empty()) + return; + + GlobalVariable *LLVMUsed = M->getGlobalVariable("llvm.used"); + std::vector<Constant*> MergedVars; + if (LLVMUsed) { + // Collect the existing members of llvm.used. + ConstantArray *Inits = cast<ConstantArray>(LLVMUsed->getInitializer()); + for (unsigned I = 0, E = Inits->getNumOperands(); I != E; ++I) + MergedVars.push_back(Inits->getOperand(I)); + LLVMUsed->eraseFromParent(); + } + + Type *i8PTy = Type::getInt8PtrTy(M->getContext()); + // Add uses for our data. + for (auto *Value : UsedVars) + MergedVars.push_back( + ConstantExpr::getBitCast(cast<llvm::Constant>(Value), i8PTy)); + + // Recreate llvm.used. + ArrayType *ATy = ArrayType::get(i8PTy, MergedVars.size()); + LLVMUsed = new llvm::GlobalVariable( + *M, ATy, false, llvm::GlobalValue::AppendingLinkage, + llvm::ConstantArray::get(ATy, MergedVars), "llvm.used"); + + LLVMUsed->setSection("llvm.metadata"); +} + +void InstrProfiling::emitInitialization() { + Constant *RegisterF = M->getFunction("__llvm_profile_register_functions"); + if (!RegisterF) + return; + + // Create the initialization function. + auto *VoidTy = Type::getVoidTy(M->getContext()); + auto *F = + Function::Create(FunctionType::get(VoidTy, false), + GlobalValue::InternalLinkage, "__llvm_profile_init", M); + F->setUnnamedAddr(true); + F->addFnAttr(Attribute::NoInline); + if (Options.NoRedZone) + F->addFnAttr(Attribute::NoRedZone); + + // Add the basic block and the necessary calls. + IRBuilder<> IRB(BasicBlock::Create(M->getContext(), "", F)); + IRB.CreateCall(RegisterF); + IRB.CreateRetVoid(); + + appendToGlobalCtors(*M, F, 0); +} diff --git a/lib/Transforms/Instrumentation/Instrumentation.cpp b/lib/Transforms/Instrumentation/Instrumentation.cpp index ac1dd43c3ae4..a91fc0ec2a48 100644 --- a/lib/Transforms/Instrumentation/Instrumentation.cpp +++ b/lib/Transforms/Instrumentation/Instrumentation.cpp @@ -25,8 +25,10 @@ void llvm::initializeInstrumentation(PassRegistry &Registry) { initializeAddressSanitizerModulePass(Registry); initializeBoundsCheckingPass(Registry); initializeGCOVProfilerPass(Registry); + initializeInstrProfilingPass(Registry); initializeMemorySanitizerPass(Registry); initializeThreadSanitizerPass(Registry); + initializeSanitizerCoverageModulePass(Registry); initializeDataFlowSanitizerPass(Registry); } diff --git a/lib/Transforms/Instrumentation/LLVMBuild.txt b/lib/Transforms/Instrumentation/LLVMBuild.txt index 99e95dfa375a..59249e6ab443 100644 --- a/lib/Transforms/Instrumentation/LLVMBuild.txt +++ b/lib/Transforms/Instrumentation/LLVMBuild.txt @@ -19,4 +19,4 @@ type = Library name = Instrumentation parent = Transforms -required_libraries = Analysis Core Support Target TransformUtils +required_libraries = Analysis Core MC Support Target TransformUtils diff --git a/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/lib/Transforms/Instrumentation/MemorySanitizer.cpp index 57e308c20dba..9f00d3d6c824 100644 --- a/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -120,13 +120,13 @@ using namespace llvm; #define DEBUG_TYPE "msan" -static const uint64_t kShadowMask32 = 1ULL << 31; -static const uint64_t kShadowMask64 = 1ULL << 46; -static const uint64_t kOriginOffset32 = 1ULL << 30; -static const uint64_t kOriginOffset64 = 1ULL << 45; static const unsigned kMinOriginAlignment = 4; static const unsigned kShadowTLSAlignment = 8; +// These constants must be kept in sync with the ones in msan.h. 
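// For scale: 800 bytes of parameter TLS is 100 eight-byte shadow slots.
// Arguments whose shadow would start beyond that limit take the "Overflow"
// path added later in this patch and simply get a clean (all-zero) shadow.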
+static const unsigned kParamTLSSize = 800; +static const unsigned kRetvalTLSSize = 800; + // Accesses sizes are powers of two: 1, 2, 4, 8. static const size_t kNumberOfAccessSizes = 4; @@ -183,20 +183,73 @@ static cl::opt<int> ClInstrumentationWithCallThreshold( "inline checks (-1 means never use callbacks)."), cl::Hidden, cl::init(3500)); -// Experimental. Wraps all indirect calls in the instrumented code with -// a call to the given function. This is needed to assist the dynamic -// helper tool (MSanDR) to regain control on transition between instrumented and -// non-instrumented code. -static cl::opt<std::string> ClWrapIndirectCalls("msan-wrap-indirect-calls", - cl::desc("Wrap indirect calls with a given function"), - cl::Hidden); - -static cl::opt<bool> ClWrapIndirectCallsFast("msan-wrap-indirect-calls-fast", - cl::desc("Do not wrap indirect calls with target in the same module"), - cl::Hidden, cl::init(true)); +// This is an experiment to enable handling of cases where shadow is a non-zero +// compile-time constant. For some unexplainable reason they were silently +// ignored in the instrumentation. +static cl::opt<bool> ClCheckConstantShadow("msan-check-constant-shadow", + cl::desc("Insert checks for constant shadow values"), + cl::Hidden, cl::init(false)); namespace { +// Memory map parameters used in application-to-shadow address calculation. +// Offset = (Addr & ~AndMask) ^ XorMask +// Shadow = ShadowBase + Offset +// Origin = OriginBase + Offset +struct MemoryMapParams { + uint64_t AndMask; + uint64_t XorMask; + uint64_t ShadowBase; + uint64_t OriginBase; +}; + +struct PlatformMemoryMapParams { + const MemoryMapParams *bits32; + const MemoryMapParams *bits64; +}; + +// i386 Linux +static const MemoryMapParams LinuxMemoryMapParams32 = { + 0x000080000000, // AndMask + 0, // XorMask (not used) + 0, // ShadowBase (not used) + 0x000040000000, // OriginBase +}; + +// x86_64 Linux +static const MemoryMapParams LinuxMemoryMapParams64 = { + 0x400000000000, // AndMask + 0, // XorMask (not used) + 0, // ShadowBase (not used) + 0x200000000000, // OriginBase +}; + +// i386 FreeBSD +static const MemoryMapParams FreeBSDMemoryMapParams32 = { + 0x000180000000, // AndMask + 0x000040000000, // XorMask + 0x000020000000, // ShadowBase + 0x000700000000, // OriginBase +}; + +// x86_64 FreeBSD +static const MemoryMapParams FreeBSDMemoryMapParams64 = { + 0xc00000000000, // AndMask + 0x200000000000, // XorMask + 0x100000000000, // ShadowBase + 0x380000000000, // OriginBase +}; + +static const PlatformMemoryMapParams LinuxMemoryMapParams = { + &LinuxMemoryMapParams32, + &LinuxMemoryMapParams64, +}; + +static const PlatformMemoryMapParams FreeBSDMemoryMapParams = { + &FreeBSDMemoryMapParams32, + &FreeBSDMemoryMapParams64, +}; + /// \brief An instrumentation pass implementing detection of uninitialized /// reads. /// @@ -208,8 +261,7 @@ class MemorySanitizer : public FunctionPass { : FunctionPass(ID), TrackOrigins(std::max(TrackOrigins, (int)ClTrackOrigins)), DL(nullptr), - WarningFn(nullptr), - WrapIndirectCalls(!ClWrapIndirectCalls.empty()) {} + WarningFn(nullptr) {} const char *getPassName() const override { return "MemorySanitizer"; } bool runOnFunction(Function &F) override; bool doInitialization(Module &M) override; @@ -243,9 +295,6 @@ class MemorySanitizer : public FunctionPass { /// function. GlobalVariable *OriginTLS; - GlobalVariable *MsandrModuleStart; - GlobalVariable *MsandrModuleEnd; - /// \brief The run-time callback to print a warning. 
Value *WarningFn; // These arrays are indexed by log2(AccessSize). @@ -263,25 +312,15 @@ class MemorySanitizer : public FunctionPass { /// \brief MSan runtime replacements for memmove, memcpy and memset. Value *MemmoveFn, *MemcpyFn, *MemsetFn; - /// \brief Address mask used in application-to-shadow address calculation. - /// ShadowAddr is computed as ApplicationAddr & ~ShadowMask. - uint64_t ShadowMask; - /// \brief Offset of the origin shadow from the "normal" shadow. - /// OriginAddr is computed as (ShadowAddr + OriginOffset) & ~3ULL - uint64_t OriginOffset; - /// \brief Branch weights for error reporting. + /// \brief Memory map parameters used in application-to-shadow calculation. + const MemoryMapParams *MapParams; + MDNode *ColdCallWeights; /// \brief Branch weights for origin store. MDNode *OriginStoreWeights; /// \brief An empty volatile inline asm that prevents callback merge. InlineAsm *EmptyAsm; - bool WrapIndirectCalls; - /// \brief Run-time wrapper for indirect calls. - Value *IndirectCallWrapperFn; - // Argument and return type of IndirectCallWrapperFn: void (*f)(void). - Type *AnyFunctionPtrTy; - friend struct MemorySanitizerVisitor; friend struct VarArgAMD64Helper; }; @@ -321,7 +360,7 @@ void MemorySanitizer::initializeCallbacks(Module &M) { // which is not yet implemented. StringRef WarningFnName = ClKeepGoing ? "__msan_warning" : "__msan_warning_noreturn"; - WarningFn = M.getOrInsertFunction(WarningFnName, IRB.getVoidTy(), NULL); + WarningFn = M.getOrInsertFunction(WarningFnName, IRB.getVoidTy(), nullptr); for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes; AccessSizeIndex++) { @@ -329,34 +368,35 @@ void MemorySanitizer::initializeCallbacks(Module &M) { std::string FunctionName = "__msan_maybe_warning_" + itostr(AccessSize); MaybeWarningFn[AccessSizeIndex] = M.getOrInsertFunction( FunctionName, IRB.getVoidTy(), IRB.getIntNTy(AccessSize * 8), - IRB.getInt32Ty(), NULL); + IRB.getInt32Ty(), nullptr); FunctionName = "__msan_maybe_store_origin_" + itostr(AccessSize); MaybeStoreOriginFn[AccessSizeIndex] = M.getOrInsertFunction( FunctionName, IRB.getVoidTy(), IRB.getIntNTy(AccessSize * 8), - IRB.getInt8PtrTy(), IRB.getInt32Ty(), NULL); + IRB.getInt8PtrTy(), IRB.getInt32Ty(), nullptr); } MsanSetAllocaOrigin4Fn = M.getOrInsertFunction( "__msan_set_alloca_origin4", IRB.getVoidTy(), IRB.getInt8PtrTy(), IntptrTy, - IRB.getInt8PtrTy(), IntptrTy, NULL); - MsanPoisonStackFn = M.getOrInsertFunction( - "__msan_poison_stack", IRB.getVoidTy(), IRB.getInt8PtrTy(), IntptrTy, NULL); + IRB.getInt8PtrTy(), IntptrTy, nullptr); + MsanPoisonStackFn = + M.getOrInsertFunction("__msan_poison_stack", IRB.getVoidTy(), + IRB.getInt8PtrTy(), IntptrTy, nullptr); MsanChainOriginFn = M.getOrInsertFunction( - "__msan_chain_origin", IRB.getInt32Ty(), IRB.getInt32Ty(), NULL); + "__msan_chain_origin", IRB.getInt32Ty(), IRB.getInt32Ty(), nullptr); MemmoveFn = M.getOrInsertFunction( "__msan_memmove", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IntptrTy, NULL); + IRB.getInt8PtrTy(), IntptrTy, nullptr); MemcpyFn = M.getOrInsertFunction( "__msan_memcpy", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IntptrTy, NULL); + IntptrTy, nullptr); MemsetFn = M.getOrInsertFunction( "__msan_memset", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt32Ty(), - IntptrTy, NULL); + IntptrTy, nullptr); // Create globals. 
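  // These thread-local globals carry the shadow calling convention: callers
  // store each argument's shadow into __msan_param_tls at 8-byte-aligned
  // offsets before a call, the callee reads the same offsets back, and
  // return-value shadow travels through __msan_retval_tls.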
RetvalTLS = new GlobalVariable( - M, ArrayType::get(IRB.getInt64Ty(), 8), false, + M, ArrayType::get(IRB.getInt64Ty(), kRetvalTLSSize / 8), false, GlobalVariable::ExternalLinkage, nullptr, "__msan_retval_tls", nullptr, GlobalVariable::InitialExecTLSModel); RetvalOriginTLS = new GlobalVariable( @@ -364,16 +404,16 @@ void MemorySanitizer::initializeCallbacks(Module &M) { "__msan_retval_origin_tls", nullptr, GlobalVariable::InitialExecTLSModel); ParamTLS = new GlobalVariable( - M, ArrayType::get(IRB.getInt64Ty(), 1000), false, + M, ArrayType::get(IRB.getInt64Ty(), kParamTLSSize / 8), false, GlobalVariable::ExternalLinkage, nullptr, "__msan_param_tls", nullptr, GlobalVariable::InitialExecTLSModel); ParamOriginTLS = new GlobalVariable( - M, ArrayType::get(OriginTy, 1000), false, GlobalVariable::ExternalLinkage, - nullptr, "__msan_param_origin_tls", nullptr, - GlobalVariable::InitialExecTLSModel); + M, ArrayType::get(OriginTy, kParamTLSSize / 4), false, + GlobalVariable::ExternalLinkage, nullptr, "__msan_param_origin_tls", + nullptr, GlobalVariable::InitialExecTLSModel); VAArgTLS = new GlobalVariable( - M, ArrayType::get(IRB.getInt64Ty(), 1000), false, + M, ArrayType::get(IRB.getInt64Ty(), kParamTLSSize / 8), false, GlobalVariable::ExternalLinkage, nullptr, "__msan_va_arg_tls", nullptr, GlobalVariable::InitialExecTLSModel); VAArgOverflowSizeTLS = new GlobalVariable( @@ -388,24 +428,6 @@ void MemorySanitizer::initializeCallbacks(Module &M) { EmptyAsm = InlineAsm::get(FunctionType::get(IRB.getVoidTy(), false), StringRef(""), StringRef(""), /*hasSideEffects=*/true); - - if (WrapIndirectCalls) { - AnyFunctionPtrTy = - PointerType::getUnqual(FunctionType::get(IRB.getVoidTy(), false)); - IndirectCallWrapperFn = M.getOrInsertFunction( - ClWrapIndirectCalls, AnyFunctionPtrTy, AnyFunctionPtrTy, NULL); - } - - if (WrapIndirectCalls && ClWrapIndirectCallsFast) { - MsandrModuleStart = new GlobalVariable( - M, IRB.getInt32Ty(), false, GlobalValue::ExternalLinkage, - nullptr, "__executable_start"); - MsandrModuleStart->setVisibility(GlobalVariable::HiddenVisibility); - MsandrModuleEnd = new GlobalVariable( - M, IRB.getInt32Ty(), false, GlobalValue::ExternalLinkage, - nullptr, "_end"); - MsandrModuleEnd->setVisibility(GlobalVariable::HiddenVisibility); - } } /// \brief Module-level initialization. @@ -417,16 +439,21 @@ bool MemorySanitizer::doInitialization(Module &M) { report_fatal_error("data layout missing"); DL = &DLP->getDataLayout(); + Triple TargetTriple(M.getTargetTriple()); + const PlatformMemoryMapParams *PlatformMapParams; + if (TargetTriple.getOS() == Triple::FreeBSD) + PlatformMapParams = &FreeBSDMemoryMapParams; + else + PlatformMapParams = &LinuxMemoryMapParams; + C = &(M.getContext()); unsigned PtrSize = DL->getPointerSizeInBits(/* AddressSpace */0); switch (PtrSize) { case 64: - ShadowMask = kShadowMask64; - OriginOffset = kOriginOffset64; + MapParams = PlatformMapParams->bits64; break; case 32: - ShadowMask = kShadowMask32; - OriginOffset = kOriginOffset32; + MapParams = PlatformMapParams->bits32; break; default: report_fatal_error("unsupported pointer size"); @@ -442,7 +469,7 @@ bool MemorySanitizer::doInitialization(Module &M) { // Insert a call to __msan_init/__msan_track_origins into the module's CTORs. 
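  // This places __msan_init in llvm.global_ctors at priority 0, so the
  // shadow memory mappings should be established before user code,
  // including other static constructors, starts running.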
appendToGlobalCtors(M, cast<Function>(M.getOrInsertFunction( - "__msan_init", IRB.getVoidTy(), NULL)), 0); + "__msan_init", IRB.getVoidTy(), nullptr)), 0); if (TrackOrigins) new GlobalVariable(M, IRB.getInt32Ty(), true, GlobalValue::WeakODRLinkage, @@ -525,7 +552,6 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { }; SmallVector<ShadowOriginAndInsertPoint, 16> InstrumentationList; SmallVector<Instruction*, 16> StoreList; - SmallVector<CallSite, 16> IndirectCallList; MemorySanitizerVisitor(Function &F, MemorySanitizer &MS) : F(F), MS(MS), VAHelper(CreateVarArgHelper(F, MS, *this)) { @@ -551,15 +577,18 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void storeOrigin(IRBuilder<> &IRB, Value *Addr, Value *Shadow, Value *Origin, unsigned Alignment, bool AsCall) { + unsigned OriginAlignment = std::max(kMinOriginAlignment, Alignment); if (isa<StructType>(Shadow->getType())) { - IRB.CreateAlignedStore(updateOrigin(Origin, IRB), getOriginPtr(Addr, IRB), - Alignment); + IRB.CreateAlignedStore(updateOrigin(Origin, IRB), + getOriginPtr(Addr, IRB, Alignment), + OriginAlignment); } else { Value *ConvertedShadow = convertToShadowTyNoVec(Shadow, IRB); // TODO(eugenis): handle non-zero constant shadow by inserting an // unconditional check (can not simply fail compilation as this could // be in the dead code). - if (isa<Constant>(ConvertedShadow)) return; + if (!ClCheckConstantShadow) + if (isa<Constant>(ConvertedShadow)) return; unsigned TypeSizeInBits = MS.DL->getTypeSizeInBits(ConvertedShadow->getType()); unsigned SizeIndex = TypeSizeToSizeIndex(TypeSizeInBits); @@ -577,7 +606,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Cmp, IRB.GetInsertPoint(), false, MS.OriginStoreWeights); IRBuilder<> IRBNew(CheckTerm); IRBNew.CreateAlignedStore(updateOrigin(Origin, IRBNew), - getOriginPtr(Addr, IRBNew), Alignment); + getOriginPtr(Addr, IRBNew, Alignment), + OriginAlignment); } } } @@ -601,11 +631,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (SI.isAtomic()) SI.setOrdering(addReleaseOrdering(SI.getOrdering())); - if (MS.TrackOrigins) { - unsigned Alignment = std::max(kMinOriginAlignment, SI.getAlignment()); - storeOrigin(IRB, Addr, Shadow, getOrigin(Val), Alignment, + if (MS.TrackOrigins) + storeOrigin(IRB, Addr, Shadow, getOrigin(Val), SI.getAlignment(), InstrumentWithCalls); - } } } @@ -615,8 +643,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { DEBUG(dbgs() << " SHAD0 : " << *Shadow << "\n"); Value *ConvertedShadow = convertToShadowTyNoVec(Shadow, IRB); DEBUG(dbgs() << " SHAD1 : " << *ConvertedShadow << "\n"); - // See the comment in materializeStores(). - if (isa<Constant>(ConvertedShadow)) return; + // See the comment in storeOrigin(). + if (!ClCheckConstantShadow) + if (isa<Constant>(ConvertedShadow)) return; unsigned TypeSizeInBits = MS.DL->getTypeSizeInBits(ConvertedShadow->getType()); unsigned SizeIndex = TypeSizeToSizeIndex(TypeSizeInBits); @@ -655,47 +684,6 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { DEBUG(dbgs() << "DONE:\n" << F); } - void materializeIndirectCalls() { - for (auto &CS : IndirectCallList) { - Instruction *I = CS.getInstruction(); - BasicBlock *B = I->getParent(); - IRBuilder<> IRB(I); - Value *Fn0 = CS.getCalledValue(); - Value *Fn = IRB.CreateBitCast(Fn0, MS.AnyFunctionPtrTy); - - if (ClWrapIndirectCallsFast) { - // Check that call target is inside this module limits. 
- Value *Start = - IRB.CreateBitCast(MS.MsandrModuleStart, MS.AnyFunctionPtrTy); - Value *End = IRB.CreateBitCast(MS.MsandrModuleEnd, MS.AnyFunctionPtrTy); - - Value *NotInThisModule = IRB.CreateOr(IRB.CreateICmpULT(Fn, Start), - IRB.CreateICmpUGE(Fn, End)); - - PHINode *NewFnPhi = - IRB.CreatePHI(Fn0->getType(), 2, "msandr.indirect_target"); - - Instruction *CheckTerm = SplitBlockAndInsertIfThen( - NotInThisModule, NewFnPhi, - /* Unreachable */ false, MS.ColdCallWeights); - - IRB.SetInsertPoint(CheckTerm); - // Slow path: call wrapper function to possibly transform the call - // target. - Value *NewFn = IRB.CreateBitCast( - IRB.CreateCall(MS.IndirectCallWrapperFn, Fn), Fn0->getType()); - - NewFnPhi->addIncoming(Fn0, B); - NewFnPhi->addIncoming(NewFn, dyn_cast<Instruction>(NewFn)->getParent()); - CS.setCalledFunction(NewFnPhi); - } else { - Value *NewFn = IRB.CreateBitCast( - IRB.CreateCall(MS.IndirectCallWrapperFn, Fn), Fn0->getType()); - CS.setCalledFunction(NewFn); - } - } - } - /// \brief Add MemorySanitizer instrumentation to a function. bool runOnFunction() { MS.initializeCallbacks(*F.getParent()); @@ -738,9 +726,6 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // Insert shadow value checks. materializeChecks(InstrumentWithCalls); - // Wrap indirect calls. - materializeIndirectCalls(); - return true; } @@ -763,6 +748,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return VectorType::get(IntegerType::get(*MS.C, EltSize), VT->getNumElements()); } + if (ArrayType *AT = dyn_cast<ArrayType>(OrigTy)) { + return ArrayType::get(getShadowTy(AT->getElementType()), + AT->getNumElements()); + } if (StructType *ST = dyn_cast<StructType>(OrigTy)) { SmallVector<Type*, 4> Elements; for (unsigned i = 0, n = ST->getNumElements(); i < n; i++) @@ -790,32 +779,57 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return IRB.CreateBitCast(V, NoVecTy); } + /// \brief Compute the integer shadow offset that corresponds to a given + /// application address. + /// + /// Offset = (Addr & ~AndMask) ^ XorMask + Value *getShadowPtrOffset(Value *Addr, IRBuilder<> &IRB) { + uint64_t AndMask = MS.MapParams->AndMask; + assert(AndMask != 0 && "AndMask shall be specified"); + Value *OffsetLong = + IRB.CreateAnd(IRB.CreatePointerCast(Addr, MS.IntptrTy), + ConstantInt::get(MS.IntptrTy, ~AndMask)); + + uint64_t XorMask = MS.MapParams->XorMask; + if (XorMask != 0) + OffsetLong = IRB.CreateXor(OffsetLong, + ConstantInt::get(MS.IntptrTy, XorMask)); + return OffsetLong; + } + /// \brief Compute the shadow address that corresponds to a given application /// address. /// - /// Shadow = Addr & ~ShadowMask. + /// Shadow = ShadowBase + Offset Value *getShadowPtr(Value *Addr, Type *ShadowTy, IRBuilder<> &IRB) { - Value *ShadowLong = - IRB.CreateAnd(IRB.CreatePointerCast(Addr, MS.IntptrTy), - ConstantInt::get(MS.IntptrTy, ~MS.ShadowMask)); + Value *ShadowLong = getShadowPtrOffset(Addr, IRB); + uint64_t ShadowBase = MS.MapParams->ShadowBase; + if (ShadowBase != 0) + ShadowLong = + IRB.CreateAdd(ShadowLong, + ConstantInt::get(MS.IntptrTy, ShadowBase)); return IRB.CreateIntToPtr(ShadowLong, PointerType::get(ShadowTy, 0)); } /// \brief Compute the origin address that corresponds to a given application /// address. 
/// - /// OriginAddr = (ShadowAddr + OriginOffset) & ~3ULL - Value *getOriginPtr(Value *Addr, IRBuilder<> &IRB) { - Value *ShadowLong = - IRB.CreateAnd(IRB.CreatePointerCast(Addr, MS.IntptrTy), - ConstantInt::get(MS.IntptrTy, ~MS.ShadowMask)); - Value *Add = - IRB.CreateAdd(ShadowLong, - ConstantInt::get(MS.IntptrTy, MS.OriginOffset)); - Value *SecondAnd = - IRB.CreateAnd(Add, ConstantInt::get(MS.IntptrTy, ~3ULL)); - return IRB.CreateIntToPtr(SecondAnd, PointerType::get(IRB.getInt32Ty(), 0)); + /// OriginAddr = (OriginBase + Offset) & ~3ULL + Value *getOriginPtr(Value *Addr, IRBuilder<> &IRB, unsigned Alignment) { + Value *OriginLong = getShadowPtrOffset(Addr, IRB); + uint64_t OriginBase = MS.MapParams->OriginBase; + if (OriginBase != 0) + OriginLong = + IRB.CreateAdd(OriginLong, + ConstantInt::get(MS.IntptrTy, OriginBase)); + if (Alignment < kMinOriginAlignment) { + uint64_t Mask = kMinOriginAlignment - 1; + OriginLong = IRB.CreateAnd(OriginLong, + ConstantInt::get(MS.IntptrTy, ~Mask)); + } + return IRB.CreateIntToPtr(OriginLong, + PointerType::get(IRB.getInt32Ty(), 0)); } /// \brief Compute the shadow address for a given function argument. @@ -882,11 +896,18 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { assert(ShadowTy); if (isa<IntegerType>(ShadowTy) || isa<VectorType>(ShadowTy)) return Constant::getAllOnesValue(ShadowTy); - StructType *ST = cast<StructType>(ShadowTy); - SmallVector<Constant *, 4> Vals; - for (unsigned i = 0, n = ST->getNumElements(); i < n; i++) - Vals.push_back(getPoisonedShadow(ST->getElementType(i))); - return ConstantStruct::get(ST, Vals); + if (ArrayType *AT = dyn_cast<ArrayType>(ShadowTy)) { + SmallVector<Constant *, 4> Vals(AT->getNumElements(), + getPoisonedShadow(AT->getElementType())); + return ConstantArray::get(AT, Vals); + } + if (StructType *ST = dyn_cast<StructType>(ShadowTy)) { + SmallVector<Constant *, 4> Vals; + for (unsigned i = 0, n = ST->getNumElements(); i < n; i++) + Vals.push_back(getPoisonedShadow(ST->getElementType(i))); + return ConstantStruct::get(ST, Vals); + } + llvm_unreachable("Unexpected shadow type"); } /// \brief Create a dirty shadow for a given value. @@ -941,6 +962,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { ? MS.DL->getTypeAllocSize(FArg.getType()->getPointerElementType()) : MS.DL->getTypeAllocSize(FArg.getType()); if (A == &FArg) { + bool Overflow = ArgOffset + Size > kParamTLSSize; Value *Base = getShadowPtrForArgument(&FArg, EntryIRB, ArgOffset); if (FArg.hasByValAttr()) { // ByVal pointer itself has clean shadow. We copy the actual @@ -951,25 +973,40 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Type *EltType = A->getType()->getPointerElementType(); ArgAlign = MS.DL->getABITypeAlignment(EltType); } - unsigned CopyAlign = std::min(ArgAlign, kShadowTLSAlignment); - Value *Cpy = EntryIRB.CreateMemCpy( - getShadowPtr(V, EntryIRB.getInt8Ty(), EntryIRB), Base, Size, - CopyAlign); - DEBUG(dbgs() << " ByValCpy: " << *Cpy << "\n"); - (void)Cpy; + if (Overflow) { + // ParamTLS overflow. 
+ EntryIRB.CreateMemSet( + getShadowPtr(V, EntryIRB.getInt8Ty(), EntryIRB), + Constant::getNullValue(EntryIRB.getInt8Ty()), Size, ArgAlign); + } else { + unsigned CopyAlign = std::min(ArgAlign, kShadowTLSAlignment); + Value *Cpy = EntryIRB.CreateMemCpy( + getShadowPtr(V, EntryIRB.getInt8Ty(), EntryIRB), Base, Size, + CopyAlign); + DEBUG(dbgs() << " ByValCpy: " << *Cpy << "\n"); + (void)Cpy; + } *ShadowPtr = getCleanShadow(V); } else { - *ShadowPtr = EntryIRB.CreateAlignedLoad(Base, kShadowTLSAlignment); + if (Overflow) { + // ParamTLS overflow. + *ShadowPtr = getCleanShadow(V); + } else { + *ShadowPtr = + EntryIRB.CreateAlignedLoad(Base, kShadowTLSAlignment); + } } DEBUG(dbgs() << " ARG: " << FArg << " ==> " << **ShadowPtr << "\n"); - if (MS.TrackOrigins) { + if (MS.TrackOrigins && !Overflow) { Value *OriginPtr = getOriginPtrForArgument(&FArg, EntryIRB, ArgOffset); setOrigin(A, EntryIRB.CreateLoad(OriginPtr)); + } else { + setOrigin(A, getCleanOrigin()); } } - ArgOffset += DataLayout::RoundUpAlignment(Size, kShadowTLSAlignment); + ArgOffset += RoundUpToAlignment(Size, kShadowTLSAlignment); } assert(*ShadowPtr && "Could not find shadow for an argument"); return *ShadowPtr; @@ -986,15 +1023,13 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// \brief Get the origin for a value. Value *getOrigin(Value *V) { if (!MS.TrackOrigins) return nullptr; - if (isa<Instruction>(V) || isa<Argument>(V)) { - Value *Origin = OriginMap[V]; - if (!Origin) { - DEBUG(dbgs() << "NO ORIGIN: " << *V << "\n"); - Origin = getCleanOrigin(); - } - return Origin; - } - return getCleanOrigin(); + if (!PropagateShadow) return getCleanOrigin(); + if (isa<Constant>(V)) return getCleanOrigin(); + assert((isa<Instruction>(V) || isa<Argument>(V)) && + "Unexpected value type in getOrigin()"); + Value *Origin = OriginMap[V]; + assert(Origin && "Missing origin"); + return Origin; } /// \brief Get the origin for i-th argument of the instruction I. @@ -1024,9 +1059,16 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// UMR warning in runtime if the value is not fully defined. 
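The Overflow checks in the argument-shadow code above guard a fixed-size thread-local buffer: an argument whose shadow would extend past the buffer gets clean shadow instead of an out-of-bounds access. A sketch of the bookkeeping; the 800-byte size is an assumption about the runtime's __msan_param_tls, used here only for illustration:

// Assumed size of the per-thread parameter shadow buffer (illustrative).
static const unsigned kParamTLSSize = 800;

// An argument's shadow is usable only if it lies entirely in the buffer;
// this is the negation of the "ArgOffset + Size > kParamTLSSize" guard.
static bool fitsInParamTLS(unsigned ArgOffset, unsigned Size) {
  return ArgOffset + Size <= kParamTLSSize;
}

// Successive arguments land at 8-byte-aligned offsets, matching
// RoundUpToAlignment(Size, kShadowTLSAlignment) in the code above.
static unsigned nextArgOffset(unsigned ArgOffset, unsigned Size) {
  return ArgOffset + ((Size + 7) / 8) * 8;
}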
void insertShadowCheck(Value *Val, Instruction *OrigIns) { assert(Val); - Instruction *Shadow = dyn_cast_or_null<Instruction>(getShadow(Val)); - if (!Shadow) return; - Instruction *Origin = dyn_cast_or_null<Instruction>(getOrigin(Val)); + Value *Shadow, *Origin; + if (ClCheckConstantShadow) { + Shadow = getShadow(Val); + if (!Shadow) return; + Origin = getOrigin(Val); + } else { + Shadow = dyn_cast_or_null<Instruction>(getShadow(Val)); + if (!Shadow) return; + Origin = dyn_cast_or_null<Instruction>(getOrigin(Val)); + } insertShadowCheck(Shadow, Origin, OrigIns); } @@ -1075,7 +1117,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { IRBuilder<> IRB(I.getNextNode()); Type *ShadowTy = getShadowTy(&I); Value *Addr = I.getPointerOperand(); - if (PropagateShadow) { + if (PropagateShadow && !I.getMetadata("nosanitize")) { Value *ShadowPtr = getShadowPtr(Addr, ShadowTy, IRB); setShadow(&I, IRB.CreateAlignedLoad(ShadowPtr, I.getAlignment(), "_msld")); @@ -1091,9 +1133,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (MS.TrackOrigins) { if (PropagateShadow) { - unsigned Alignment = std::max(kMinOriginAlignment, I.getAlignment()); - setOrigin(&I, - IRB.CreateAlignedLoad(getOriginPtr(Addr, IRB), Alignment)); + unsigned Alignment = I.getAlignment(); + unsigned OriginAlignment = std::max(kMinOriginAlignment, Alignment); + setOrigin(&I, IRB.CreateAlignedLoad(getOriginPtr(Addr, IRB, Alignment), + OriginAlignment)); } else { setOrigin(&I, getCleanOrigin()); } @@ -1127,6 +1170,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { IRB.CreateStore(getCleanShadow(&I), ShadowPtr); setShadow(&I, getCleanShadow(&I)); + setOrigin(&I, getCleanOrigin()); } void visitAtomicRMWInst(AtomicRMWInst &I) { @@ -1744,7 +1788,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // FIXME: use ClStoreCleanOrigin // FIXME: factor out common code from materializeStores if (MS.TrackOrigins) - IRB.CreateStore(getOrigin(&I, 1), getOriginPtr(Addr, IRB)); + IRB.CreateStore(getOrigin(&I, 1), getOriginPtr(Addr, IRB, 1)); return true; } @@ -1771,7 +1815,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (MS.TrackOrigins) { if (PropagateShadow) - setOrigin(&I, IRB.CreateLoad(getOriginPtr(Addr, IRB))); + setOrigin(&I, IRB.CreateLoad(getOriginPtr(Addr, IRB, 1))); else setOrigin(&I, getCleanOrigin()); } @@ -1859,7 +1903,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Value *Op = I.getArgOperand(0); Type *OpType = Op->getType(); Function *BswapFunc = Intrinsic::getDeclaration( - F.getParent(), Intrinsic::bswap, ArrayRef<Type*>(&OpType, 1)); + F.getParent(), Intrinsic::bswap, makeArrayRef(&OpType, 1)); setShadow(&I, IRB.CreateCall(BswapFunc, getShadow(Op))); setOrigin(&I, getOrigin(Op)); } @@ -1935,6 +1979,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { setOrigin(&I, getOrigin(CopyOp)); } else { setShadow(&I, getCleanShadow(&I)); + setOrigin(&I, getCleanOrigin()); } } @@ -2291,9 +2336,6 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } IRBuilder<> IRB(&I); - if (MS.WrapIndirectCalls && !CS.getCalledFunction()) - IndirectCallList.push_back(CS); - unsigned ArgOffset = 0; DEBUG(dbgs() << " CallSite: " << I << "\n"); for (CallSite::arg_iterator ArgIt = CS.arg_begin(), End = CS.arg_end(); @@ -2318,12 +2360,15 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { 
assert(A->getType()->isPointerTy() && "ByVal argument is not a pointer!"); Size = MS.DL->getTypeAllocSize(A->getType()->getPointerElementType()); - unsigned Alignment = CS.getParamAlignment(i + 1); + if (ArgOffset + Size > kParamTLSSize) break; + unsigned ParamAlignment = CS.getParamAlignment(i + 1); + unsigned Alignment = std::min(ParamAlignment, kShadowTLSAlignment); Store = IRB.CreateMemCpy(ArgShadowBase, getShadowPtr(A, Type::getInt8Ty(*MS.C), IRB), Size, Alignment); } else { Size = MS.DL->getTypeAllocSize(A->getType()); + if (ArgOffset + Size > kParamTLSSize) break; Store = IRB.CreateAlignedStore(ArgShadow, ArgShadowBase, kShadowTLSAlignment); Constant *Cst = dyn_cast<Constant>(ArgShadow); @@ -2335,7 +2380,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { (void)Store; assert(Size != 0 && Store != nullptr); DEBUG(dbgs() << " Param:" << *Store << "\n"); - ArgOffset += DataLayout::RoundUpAlignment(Size, 8); + ArgOffset += RoundUpToAlignment(Size, 8); } DEBUG(dbgs() << " done with call args\n"); @@ -2399,6 +2444,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { IRBuilder<> IRB(&I); if (!PropagateShadow) { setShadow(&I, getCleanShadow(&I)); + setOrigin(&I, getCleanOrigin()); return; } @@ -2412,6 +2458,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void visitAllocaInst(AllocaInst &I) { setShadow(&I, getCleanShadow(&I)); + setOrigin(&I, getCleanOrigin()); IRBuilder<> IRB(I.getNextNode()); uint64_t Size = MS.DL->getTypeAllocSize(I.getAllocatedType()); if (PoisonStack && ClPoisonStackWithCall) { @@ -2425,7 +2472,6 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } if (PoisonStack && MS.TrackOrigins) { - setOrigin(&I, getCleanOrigin()); SmallString<2048> StackDescriptionStorage; raw_svector_ostream StackDescription(StackDescriptionStorage); // We create a string with a description of the stack allocation and @@ -2491,9 +2537,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } // a = select b, c, d // Oa = Sb ? Ob : (b ? 
Oc : Od) - setOrigin(&I, IRB.CreateSelect( - Sb, getOrigin(I.getCondition()), - IRB.CreateSelect(B, getOrigin(C), getOrigin(D)))); + setOrigin( + &I, IRB.CreateSelect(Sb, getOrigin(I.getCondition()), + IRB.CreateSelect(B, getOrigin(I.getTrueValue()), + getOrigin(I.getFalseValue())))); } } @@ -2616,7 +2663,7 @@ struct VarArgAMD64Helper : public VarArgHelper { Type *RealTy = A->getType()->getPointerElementType(); uint64_t ArgSize = MS.DL->getTypeAllocSize(RealTy); Value *Base = getShadowPtrForVAArgument(RealTy, IRB, OverflowOffset); - OverflowOffset += DataLayout::RoundUpAlignment(ArgSize, 8); + OverflowOffset += RoundUpToAlignment(ArgSize, 8); IRB.CreateMemCpy(Base, MSV.getShadowPtr(A, IRB.getInt8Ty(), IRB), ArgSize, kShadowTLSAlignment); } else { @@ -2638,7 +2685,7 @@ struct VarArgAMD64Helper : public VarArgHelper { case AK_Memory: uint64_t ArgSize = MS.DL->getTypeAllocSize(A->getType()); Base = getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset); - OverflowOffset += DataLayout::RoundUpAlignment(ArgSize, 8); + OverflowOffset += RoundUpToAlignment(ArgSize, 8); } IRB.CreateAlignedStore(MSV.getShadow(A), Base, kShadowTLSAlignment); } diff --git a/lib/Transforms/Instrumentation/SanitizerCoverage.cpp b/lib/Transforms/Instrumentation/SanitizerCoverage.cpp new file mode 100644 index 000000000000..c048a99f8880 --- /dev/null +++ b/lib/Transforms/Instrumentation/SanitizerCoverage.cpp @@ -0,0 +1,314 @@ +//===-- SanitizerCoverage.cpp - coverage instrumentation for sanitizers ---===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Coverage instrumentation that works with AddressSanitizer +// and potentially with other Sanitizers. +// +// We create a Guard variable with the same linkage +// as the function and inject this code into the entry block (CoverageLevel=1) +// or all blocks (CoverageLevel>=2): +// if (Guard < 0) { +// __sanitizer_cov(&Guard); +// } +// The accesses to Guard are atomic. The rest of the logic is +// in __sanitizer_cov (it's fine to call it more than once). +// +// With CoverageLevel>=3 we also split critical edges, thus effectively +// instrumenting all edges. +// +// CoverageLevel>=4 adds indirect call profiling implemented as a function call. +// +// This coverage implementation provides very limited data: +// it only tells if a given function (block) was ever executed. No counters. +// But for many use cases this is what we need and the added slowdown is small.
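In source terms the injected check described above is roughly the following; __sanitizer_cov is the real runtime entry point, the other names are illustrative:

extern "C" void __sanitizer_cov(int *Guard); // sanitizer runtime hook

static int Guard = -1; // one guard per instrumented block

void instrumented_block() {
  // The real load is an atomic monotonic (relaxed) load; __sanitizer_cov
  // flips the guard, so the slow path runs only until the first report.
  if (__atomic_load_n(&Guard, __ATOMIC_RELAXED) < 0)
    __sanitizer_cov(&Guard);
  // ... original block body ...
}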
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Instrumentation.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InlineAsm.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" + +using namespace llvm; + +#define DEBUG_TYPE "sancov" + +static const char *const kSanCovModuleInitName = "__sanitizer_cov_module_init"; +static const char *const kSanCovName = "__sanitizer_cov"; +static const char *const kSanCovIndirCallName = "__sanitizer_cov_indir_call16"; +static const char *const kSanCovTraceEnter = "__sanitizer_cov_trace_func_enter"; +static const char *const kSanCovTraceBB = "__sanitizer_cov_trace_basic_block"; +static const char *const kSanCovModuleCtorName = "sancov.module_ctor"; +static const uint64_t kSanCtorAndDtorPriority = 1; + +static cl::opt<int> ClCoverageLevel("sanitizer-coverage-level", + cl::desc("Sanitizer Coverage. 0: none, 1: entry block, 2: all blocks, " + "3: all blocks and critical edges, " + "4: above plus indirect calls"), + cl::Hidden, cl::init(0)); + +static cl::opt<int> ClCoverageBlockThreshold( + "sanitizer-coverage-block-threshold", + cl::desc("Add coverage instrumentation only to the entry block if there " + "are more than this number of blocks."), + cl::Hidden, cl::init(1500)); + +static cl::opt<bool> + ClExperimentalTracing("sanitizer-coverage-experimental-tracing", + cl::desc("Experimental basic-block tracing: insert " + "callbacks at every basic block"), + cl::Hidden, cl::init(false)); + +namespace { + +class SanitizerCoverageModule : public ModulePass { + public: + SanitizerCoverageModule(int CoverageLevel = 0) + : ModulePass(ID), + CoverageLevel(std::max(CoverageLevel, (int)ClCoverageLevel)) {} + bool runOnModule(Module &M) override; + bool runOnFunction(Function &F); + static char ID; // Pass identification, replacement for typeid + const char *getPassName() const override { + return "SanitizerCoverageModule"; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<DataLayoutPass>(); + } + + private: + void InjectCoverageForIndirectCalls(Function &F, + ArrayRef<Instruction *> IndirCalls); + bool InjectCoverage(Function &F, ArrayRef<BasicBlock *> AllBlocks, + ArrayRef<Instruction *> IndirCalls); + void InjectCoverageAtBlock(Function &F, BasicBlock &BB); + Function *SanCovFunction; + Function *SanCovIndirCallFunction; + Function *SanCovModuleInit; + Function *SanCovTraceEnter, *SanCovTraceBB; + InlineAsm *EmptyAsm; + Type *IntptrTy; + LLVMContext *C; + + GlobalVariable *GuardArray; + + int CoverageLevel; +}; + +} // namespace + +static Function *checkInterfaceFunction(Constant *FuncOrBitcast) { + if (Function *F = dyn_cast<Function>(FuncOrBitcast)) + return F; + std::string Err; + raw_string_ostream Stream(Err); + Stream << "SanitizerCoverage interface function redefined: " + << *FuncOrBitcast; + report_fatal_error(Err); +} + +bool SanitizerCoverageModule::runOnModule(Module &M) { + if (!CoverageLevel) return false; + C = &(M.getContext()); + DataLayoutPass *DLP = 
&getAnalysis<DataLayoutPass>(); + IntptrTy = Type::getIntNTy(*C, DLP->getDataLayout().getPointerSizeInBits()); + Type *VoidTy = Type::getVoidTy(*C); + IRBuilder<> IRB(*C); + Type *Int32PtrTy = PointerType::getUnqual(IRB.getInt32Ty()); + + Function *CtorFunc = + Function::Create(FunctionType::get(VoidTy, false), + GlobalValue::InternalLinkage, kSanCovModuleCtorName, &M); + ReturnInst::Create(*C, BasicBlock::Create(*C, "", CtorFunc)); + appendToGlobalCtors(M, CtorFunc, kSanCtorAndDtorPriority); + + SanCovFunction = checkInterfaceFunction( + M.getOrInsertFunction(kSanCovName, VoidTy, Int32PtrTy, nullptr)); + SanCovIndirCallFunction = checkInterfaceFunction(M.getOrInsertFunction( + kSanCovIndirCallName, VoidTy, IntptrTy, IntptrTy, nullptr)); + SanCovModuleInit = checkInterfaceFunction( + M.getOrInsertFunction(kSanCovModuleInitName, Type::getVoidTy(*C), + Int32PtrTy, IntptrTy, nullptr)); + SanCovModuleInit->setLinkage(Function::ExternalLinkage); + // We insert an empty inline asm after cov callbacks to avoid callback merge. + EmptyAsm = InlineAsm::get(FunctionType::get(IRB.getVoidTy(), false), + StringRef(""), StringRef(""), + /*hasSideEffects=*/true); + + if (ClExperimentalTracing) { + SanCovTraceEnter = checkInterfaceFunction( + M.getOrInsertFunction(kSanCovTraceEnter, VoidTy, Int32PtrTy, nullptr)); + SanCovTraceBB = checkInterfaceFunction( + M.getOrInsertFunction(kSanCovTraceBB, VoidTy, Int32PtrTy, nullptr)); + } + + // At this point we create a dummy array of guards because we don't + // know how many elements we will need. + Type *Int32Ty = IRB.getInt32Ty(); + GuardArray = + new GlobalVariable(M, Int32Ty, false, GlobalValue::ExternalLinkage, + nullptr, "__sancov_gen_cov_tmp"); + + for (auto &F : M) + runOnFunction(F); + + // Now we know how many elements we need. Create an array of guards + // with one extra element at the beginning for the size. + Type *Int32ArrayNTy = + ArrayType::get(Int32Ty, SanCovFunction->getNumUses() + 1); + GlobalVariable *RealGuardArray = new GlobalVariable( + M, Int32ArrayNTy, false, GlobalValue::PrivateLinkage, + Constant::getNullValue(Int32ArrayNTy), "__sancov_gen_cov"); + + // Replace the dummy array with the real one. + GuardArray->replaceAllUsesWith( + IRB.CreatePointerCast(RealGuardArray, Int32PtrTy)); + GuardArray->eraseFromParent(); + + // Call __sanitizer_cov_module_init + IRB.SetInsertPoint(CtorFunc->getEntryBlock().getTerminator()); + IRB.CreateCall2(SanCovModuleInit, + IRB.CreatePointerCast(RealGuardArray, Int32PtrTy), + ConstantInt::get(IntptrTy, SanCovFunction->getNumUses())); + return true; +} + +bool SanitizerCoverageModule::runOnFunction(Function &F) { + if (F.empty()) return false; + if (F.getName().find(".module_ctor") != std::string::npos) + return false; // Should not instrument sanitizer init functions. 
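What runOnModule generates can be pictured at the source level like this; kNumGuards stands in for SanCovFunction->getNumUses(), which is only known after instrumentation, hence the dummy-array-then-replaceAllUsesWith sequence above:

extern "C" void __sanitizer_cov_module_init(int *Guards, unsigned long N);

static const unsigned long kNumGuards = 42; // placeholder for getNumUses()

// Zero-initialized, with one extra leading element for the size, mirroring
// ArrayType::get(Int32Ty, SanCovFunction->getNumUses() + 1) above.
static int SancovGenCov[kNumGuards + 1];

// Corresponds to sancov.module_ctor, appended at ctor priority 1.
__attribute__((constructor)) static void SancovModuleCtor() {
  __sanitizer_cov_module_init(SancovGenCov, kNumGuards);
}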
+ if (CoverageLevel >= 3) + SplitAllCriticalEdges(F, this); + SmallVector<Instruction*, 8> IndirCalls; + SmallVector<BasicBlock*, 16> AllBlocks; + for (auto &BB : F) { + AllBlocks.push_back(&BB); + if (CoverageLevel >= 4) + for (auto &Inst : BB) { + CallSite CS(&Inst); + if (CS && !CS.getCalledFunction()) + IndirCalls.push_back(&Inst); + } + } + InjectCoverage(F, AllBlocks, IndirCalls); + return true; +} + +bool +SanitizerCoverageModule::InjectCoverage(Function &F, + ArrayRef<BasicBlock *> AllBlocks, + ArrayRef<Instruction *> IndirCalls) { + if (!CoverageLevel) return false; + + if (CoverageLevel == 1 || + (unsigned)ClCoverageBlockThreshold < AllBlocks.size()) { + InjectCoverageAtBlock(F, F.getEntryBlock()); + } else { + for (auto BB : AllBlocks) + InjectCoverageAtBlock(F, *BB); + } + InjectCoverageForIndirectCalls(F, IndirCalls); + return true; +} + +// On every indirect call we call a run-time function +// __sanitizer_cov_indir_call* with two parameters: +// - callee address, +// - global cache array that contains kCacheSize pointers (zero-initialized). +// The cache is used to speed up recording the caller-callee pairs. +// The address of the caller is passed implicitly via caller PC. +// kCacheSize is encoded in the name of the run-time function. +void SanitizerCoverageModule::InjectCoverageForIndirectCalls( + Function &F, ArrayRef<Instruction *> IndirCalls) { + if (IndirCalls.empty()) return; + const int kCacheSize = 16; + const int kCacheAlignment = 64; // Align for better performance. + Type *Ty = ArrayType::get(IntptrTy, kCacheSize); + for (auto I : IndirCalls) { + IRBuilder<> IRB(I); + CallSite CS(I); + Value *Callee = CS.getCalledValue(); + if (dyn_cast<InlineAsm>(Callee)) continue; + GlobalVariable *CalleeCache = new GlobalVariable( + *F.getParent(), Ty, false, GlobalValue::PrivateLinkage, + Constant::getNullValue(Ty), "__sancov_gen_callee_cache"); + CalleeCache->setAlignment(kCacheAlignment); + IRB.CreateCall2(SanCovIndirCallFunction, + IRB.CreatePointerCast(Callee, IntptrTy), + IRB.CreatePointerCast(CalleeCache, IntptrTy)); + } +} + +void SanitizerCoverageModule::InjectCoverageAtBlock(Function &F, + BasicBlock &BB) { + BasicBlock::iterator IP = BB.getFirstInsertionPt(), BE = BB.end(); + // Skip static allocas at the top of the entry block so they don't become + // dynamic when we split the block. If we used our optimized stack layout, + // then there will only be one alloca and it will come first. + for (; IP != BE; ++IP) { + AllocaInst *AI = dyn_cast<AllocaInst>(IP); + if (!AI || !AI->isStaticAlloca()) + break; + } + + bool IsEntryBB = &BB == &F.getEntryBlock(); + DebugLoc EntryLoc = + IsEntryBB ? 
IP->getDebugLoc().getFnDebugLoc(*C) : IP->getDebugLoc(); + IRBuilder<> IRB(IP); + IRB.SetCurrentDebugLocation(EntryLoc); + SmallVector<Value *, 1> Indices; + Value *GuardP = IRB.CreateAdd( + IRB.CreatePointerCast(GuardArray, IntptrTy), + ConstantInt::get(IntptrTy, (1 + SanCovFunction->getNumUses()) * 4)); + Type *Int32PtrTy = PointerType::getUnqual(IRB.getInt32Ty()); + GuardP = IRB.CreateIntToPtr(GuardP, Int32PtrTy); + LoadInst *Load = IRB.CreateLoad(GuardP); + Load->setAtomic(Monotonic); + Load->setAlignment(4); + Load->setMetadata(F.getParent()->getMDKindID("nosanitize"), + MDNode::get(*C, None)); + Value *Cmp = IRB.CreateICmpSGE(Constant::getNullValue(Load->getType()), Load); + Instruction *Ins = SplitBlockAndInsertIfThen( + Cmp, IP, false, MDBuilder(*C).createBranchWeights(1, 100000)); + IRB.SetInsertPoint(Ins); + IRB.SetCurrentDebugLocation(EntryLoc); + // __sanitizer_cov gets the PC of the instruction using GET_CALLER_PC. + IRB.CreateCall(SanCovFunction, GuardP); + IRB.CreateCall(EmptyAsm); // Avoids callback merge. + + if (ClExperimentalTracing) { + // Experimental support for tracing. + // Insert a callback with the same guard variable as used for coverage. + IRB.SetInsertPoint(IP); + IRB.CreateCall(IsEntryBB ? SanCovTraceEnter : SanCovTraceBB, GuardP); + } +} + +char SanitizerCoverageModule::ID = 0; +INITIALIZE_PASS(SanitizerCoverageModule, "sancov", + "SanitizerCoverage: TODO." + "ModulePass", false, false) +ModulePass *llvm::createSanitizerCoverageModulePass(int CoverageLevel) { + return new SanitizerCoverageModule(CoverageLevel); +} diff --git a/lib/Transforms/Instrumentation/ThreadSanitizer.cpp b/lib/Transforms/Instrumentation/ThreadSanitizer.cpp index 89386a6a86de..1b86ae5acf7d 100644 --- a/lib/Transforms/Instrumentation/ThreadSanitizer.cpp +++ b/lib/Transforms/Instrumentation/ThreadSanitizer.cpp @@ -135,33 +135,33 @@ void ThreadSanitizer::initializeCallbacks(Module &M) { IRBuilder<> IRB(M.getContext()); // Initialize the callbacks. 
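For reference, the runtime interface bound below looks roughly like this in C; a sketch inferred from the getOrInsertFunction shapes, showing only one of the per-size read/write and atomic entry points:

extern "C" {
void __tsan_func_entry(void *ReturnPC);
void __tsan_func_exit(void);
void __tsan_read4(void *Addr);   // one entry point per access size
void __tsan_write4(void *Addr);
int  __tsan_atomic32_load(const volatile int *Addr, int Order);
void __tsan_atomic32_store(volatile int *Addr, int Value, int Order);
}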
TsanFuncEntry = checkInterfaceFunction(M.getOrInsertFunction( - "__tsan_func_entry", IRB.getVoidTy(), IRB.getInt8PtrTy(), NULL)); + "__tsan_func_entry", IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); TsanFuncExit = checkInterfaceFunction(M.getOrInsertFunction( - "__tsan_func_exit", IRB.getVoidTy(), NULL)); + "__tsan_func_exit", IRB.getVoidTy(), nullptr)); OrdTy = IRB.getInt32Ty(); for (size_t i = 0; i < kNumberOfAccessSizes; ++i) { const size_t ByteSize = 1 << i; const size_t BitSize = ByteSize * 8; SmallString<32> ReadName("__tsan_read" + itostr(ByteSize)); TsanRead[i] = checkInterfaceFunction(M.getOrInsertFunction( - ReadName, IRB.getVoidTy(), IRB.getInt8PtrTy(), NULL)); + ReadName, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); SmallString<32> WriteName("__tsan_write" + itostr(ByteSize)); TsanWrite[i] = checkInterfaceFunction(M.getOrInsertFunction( - WriteName, IRB.getVoidTy(), IRB.getInt8PtrTy(), NULL)); + WriteName, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); Type *Ty = Type::getIntNTy(M.getContext(), BitSize); Type *PtrTy = Ty->getPointerTo(); SmallString<32> AtomicLoadName("__tsan_atomic" + itostr(BitSize) + "_load"); TsanAtomicLoad[i] = checkInterfaceFunction(M.getOrInsertFunction( - AtomicLoadName, Ty, PtrTy, OrdTy, NULL)); + AtomicLoadName, Ty, PtrTy, OrdTy, nullptr)); SmallString<32> AtomicStoreName("__tsan_atomic" + itostr(BitSize) + "_store"); TsanAtomicStore[i] = checkInterfaceFunction(M.getOrInsertFunction( AtomicStoreName, IRB.getVoidTy(), PtrTy, Ty, OrdTy, - NULL)); + nullptr)); for (int op = AtomicRMWInst::FIRST_BINOP; op <= AtomicRMWInst::LAST_BINOP; ++op) { @@ -185,33 +185,33 @@ void ThreadSanitizer::initializeCallbacks(Module &M) { continue; SmallString<32> RMWName("__tsan_atomic" + itostr(BitSize) + NamePart); TsanAtomicRMW[op][i] = checkInterfaceFunction(M.getOrInsertFunction( - RMWName, Ty, PtrTy, Ty, OrdTy, NULL)); + RMWName, Ty, PtrTy, Ty, OrdTy, nullptr)); } SmallString<32> AtomicCASName("__tsan_atomic" + itostr(BitSize) + "_compare_exchange_val"); TsanAtomicCAS[i] = checkInterfaceFunction(M.getOrInsertFunction( - AtomicCASName, Ty, PtrTy, Ty, Ty, OrdTy, OrdTy, NULL)); + AtomicCASName, Ty, PtrTy, Ty, Ty, OrdTy, OrdTy, nullptr)); } TsanVptrUpdate = checkInterfaceFunction(M.getOrInsertFunction( "__tsan_vptr_update", IRB.getVoidTy(), IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), NULL)); + IRB.getInt8PtrTy(), nullptr)); TsanVptrLoad = checkInterfaceFunction(M.getOrInsertFunction( - "__tsan_vptr_read", IRB.getVoidTy(), IRB.getInt8PtrTy(), NULL)); + "__tsan_vptr_read", IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); TsanAtomicThreadFence = checkInterfaceFunction(M.getOrInsertFunction( - "__tsan_atomic_thread_fence", IRB.getVoidTy(), OrdTy, NULL)); + "__tsan_atomic_thread_fence", IRB.getVoidTy(), OrdTy, nullptr)); TsanAtomicSignalFence = checkInterfaceFunction(M.getOrInsertFunction( - "__tsan_atomic_signal_fence", IRB.getVoidTy(), OrdTy, NULL)); + "__tsan_atomic_signal_fence", IRB.getVoidTy(), OrdTy, nullptr)); MemmoveFn = checkInterfaceFunction(M.getOrInsertFunction( "memmove", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IntptrTy, NULL)); + IRB.getInt8PtrTy(), IntptrTy, nullptr)); MemcpyFn = checkInterfaceFunction(M.getOrInsertFunction( "memcpy", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IntptrTy, NULL)); + IntptrTy, nullptr)); MemsetFn = checkInterfaceFunction(M.getOrInsertFunction( "memset", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt32Ty(), - IntptrTy, NULL)); + IntptrTy, nullptr)); } bool 
ThreadSanitizer::doInitialization(Module &M) { @@ -224,7 +224,7 @@ bool ThreadSanitizer::doInitialization(Module &M) { IRBuilder<> IRB(M.getContext()); IntptrTy = IRB.getIntPtrTy(DL); Value *TsanInit = M.getOrInsertFunction("__tsan_init", - IRB.getVoidTy(), NULL); + IRB.getVoidTy(), nullptr); appendToGlobalCtors(M, cast<Function>(TsanInit), 0); return true; @@ -422,7 +422,7 @@ bool ThreadSanitizer::instrumentLoadOrStore(Instruction *I) { static ConstantInt *createOrdering(IRBuilder<> *IRB, AtomicOrdering ord) { uint32_t v = 0; switch (ord) { - case NotAtomic: assert(false); + case NotAtomic: llvm_unreachable("unexpected atomic ordering!"); case Unordered: // Fall-through. case Monotonic: v = 0; break; // case Consume: v = 1; break; // Not specified yet. @@ -481,8 +481,7 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I) { Type *PtrTy = Ty->getPointerTo(); Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy), createOrdering(&IRB, LI->getOrdering())}; - CallInst *C = CallInst::Create(TsanAtomicLoad[Idx], - ArrayRef<Value*>(Args)); + CallInst *C = CallInst::Create(TsanAtomicLoad[Idx], Args); ReplaceInstWithInst(I, C); } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) { @@ -497,8 +496,7 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I) { Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy), IRB.CreateIntCast(SI->getValueOperand(), Ty, false), createOrdering(&IRB, SI->getOrdering())}; - CallInst *C = CallInst::Create(TsanAtomicStore[Idx], - ArrayRef<Value*>(Args)); + CallInst *C = CallInst::Create(TsanAtomicStore[Idx], Args); ReplaceInstWithInst(I, C); } else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(I)) { Value *Addr = RMWI->getPointerOperand(); @@ -515,7 +513,7 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I) { Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy), IRB.CreateIntCast(RMWI->getValOperand(), Ty, false), createOrdering(&IRB, RMWI->getOrdering())}; - CallInst *C = CallInst::Create(F, ArrayRef<Value*>(Args)); + CallInst *C = CallInst::Create(F, Args); ReplaceInstWithInst(I, C); } else if (AtomicCmpXchgInst *CASI = dyn_cast<AtomicCmpXchgInst>(I)) { Value *Addr = CASI->getPointerOperand(); @@ -543,7 +541,7 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I) { Value *Args[] = {createOrdering(&IRB, FI->getOrdering())}; Function *F = FI->getSynchScope() == SingleThread ? 
TsanAtomicSignalFence : TsanAtomicThreadFence; - CallInst *C = CallInst::Create(F, ArrayRef<Value*>(Args)); + CallInst *C = CallInst::Create(F, Args); ReplaceInstWithInst(I, C); } return true; diff --git a/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h b/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h index 409842863073..e286dbc64a86 100644 --- a/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h +++ b/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h @@ -19,8 +19,8 @@ /// //===----------------------------------------------------------------------===// -#ifndef LLVM_TRANSFORMS_SCALAR_ARCRUNTIMEENTRYPOINTS_H -#define LLVM_TRANSFORMS_SCALAR_ARCRUNTIMEENTRYPOINTS_H +#ifndef LLVM_LIB_TRANSFORMS_OBJCARC_ARCRUNTIMEENTRYPOINTS_H +#define LLVM_LIB_TRANSFORMS_OBJCARC_ARCRUNTIMEENTRYPOINTS_H #include "ObjCARC.h" @@ -183,4 +183,4 @@ private: } // namespace objcarc } // namespace llvm -#endif // LLVM_TRANSFORMS_SCALAR_ARCRUNTIMEENTRYPOINTS_H +#endif diff --git a/lib/Transforms/ObjCARC/CMakeLists.txt b/lib/Transforms/ObjCARC/CMakeLists.txt index 233deb398011..b449fac13860 100644 --- a/lib/Transforms/ObjCARC/CMakeLists.txt +++ b/lib/Transforms/ObjCARC/CMakeLists.txt @@ -8,6 +8,7 @@ add_llvm_library(LLVMObjCARCOpts ObjCARCContract.cpp DependencyAnalysis.cpp ProvenanceAnalysis.cpp + ProvenanceAnalysisEvaluator.cpp ) add_dependencies(LLVMObjCARCOpts intrinsics_gen) diff --git a/lib/Transforms/ObjCARC/DependencyAnalysis.cpp b/lib/Transforms/ObjCARC/DependencyAnalysis.cpp index 08c884293cc5..f6c236c31ef8 100644 --- a/lib/Transforms/ObjCARC/DependencyAnalysis.cpp +++ b/lib/Transforms/ObjCARC/DependencyAnalysis.cpp @@ -206,8 +206,8 @@ void llvm::objcarc::FindDependencies(DependenceKind Flavor, const Value *Arg, BasicBlock *StartBB, Instruction *StartInst, - SmallPtrSet<Instruction *, 4> &DependingInsts, - SmallPtrSet<const BasicBlock *, 4> &Visited, + SmallPtrSetImpl<Instruction *> &DependingInsts, + SmallPtrSetImpl<const BasicBlock *> &Visited, ProvenanceAnalysis &PA) { BasicBlock::iterator StartPos = StartInst; @@ -229,7 +229,7 @@ llvm::objcarc::FindDependencies(DependenceKind Flavor, // Add the predecessors to the worklist. do { BasicBlock *PredBB = *PI; - if (Visited.insert(PredBB)) + if (Visited.insert(PredBB).second) Worklist.push_back(std::make_pair(PredBB, PredBB->end())); } while (++PI != PE); break; @@ -246,9 +246,7 @@ llvm::objcarc::FindDependencies(DependenceKind Flavor, // Determine whether the original StartBB post-dominates all of the blocks we // visited. If not, insert a sentinal indicating that most optimizations are // not safe. 
- for (SmallPtrSet<const BasicBlock *, 4>::const_iterator I = Visited.begin(), - E = Visited.end(); I != E; ++I) { - const BasicBlock *BB = *I; + for (const BasicBlock *BB : Visited) { if (BB == StartBB) continue; const TerminatorInst *TI = cast<TerminatorInst>(&BB->back()); diff --git a/lib/Transforms/ObjCARC/DependencyAnalysis.h b/lib/Transforms/ObjCARC/DependencyAnalysis.h index 617cdf3843b8..7b5601ad6d5d 100644 --- a/lib/Transforms/ObjCARC/DependencyAnalysis.h +++ b/lib/Transforms/ObjCARC/DependencyAnalysis.h @@ -20,8 +20,8 @@ /// //===----------------------------------------------------------------------===// -#ifndef LLVM_TRANSFORMS_OBJCARC_DEPEDENCYANALYSIS_H -#define LLVM_TRANSFORMS_OBJCARC_DEPEDENCYANALYSIS_H +#ifndef LLVM_LIB_TRANSFORMS_OBJCARC_DEPENDENCYANALYSIS_H +#define LLVM_LIB_TRANSFORMS_OBJCARC_DEPENDENCYANALYSIS_H #include "llvm/ADT/SmallPtrSet.h" @@ -53,8 +53,8 @@ enum DependenceKind { void FindDependencies(DependenceKind Flavor, const Value *Arg, BasicBlock *StartBB, Instruction *StartInst, - SmallPtrSet<Instruction *, 4> &DependingInstructions, - SmallPtrSet<const BasicBlock *, 4> &Visited, + SmallPtrSetImpl<Instruction *> &DependingInstructions, + SmallPtrSetImpl<const BasicBlock *> &Visited, ProvenanceAnalysis &PA); bool @@ -76,4 +76,4 @@ CanAlterRefCount(const Instruction *Inst, const Value *Ptr, } // namespace objcarc } // namespace llvm -#endif // LLVM_TRANSFORMS_OBJCARC_DEPEDENCYANALYSIS_H +#endif diff --git a/lib/Transforms/ObjCARC/ObjCARC.cpp b/lib/Transforms/ObjCARC/ObjCARC.cpp index 373168e89888..6ea038b8ba8c 100644 --- a/lib/Transforms/ObjCARC/ObjCARC.cpp +++ b/lib/Transforms/ObjCARC/ObjCARC.cpp @@ -42,6 +42,7 @@ void llvm::initializeObjCARCOpts(PassRegistry &Registry) { initializeObjCARCExpandPass(Registry); initializeObjCARCContractPass(Registry); initializeObjCARCOptPass(Registry); + initializePAEvalPass(Registry); } void LLVMInitializeObjCARCOpts(LLVMPassRegistryRef R) { diff --git a/lib/Transforms/ObjCARC/ObjCARC.h b/lib/Transforms/ObjCARC/ObjCARC.h index f71cf2bd4399..7a7eae84a1e2 100644 --- a/lib/Transforms/ObjCARC/ObjCARC.h +++ b/lib/Transforms/ObjCARC/ObjCARC.h @@ -20,8 +20,8 @@ /// //===----------------------------------------------------------------------===// -#ifndef LLVM_TRANSFORMS_SCALAR_OBJCARC_H -#define LLVM_TRANSFORMS_SCALAR_OBJCARC_H +#ifndef LLVM_LIB_TRANSFORMS_OBJCARC_OBJCARC_H +#define LLVM_LIB_TRANSFORMS_OBJCARC_OBJCARC_H #include "llvm/ADT/StringSwitch.h" #include "llvm/Analysis/AliasAnalysis.h" @@ -380,11 +380,15 @@ static inline bool IsObjCIdentifiedObject(const Value *V) { StringRef Name = GV->getName(); // These special variables are known to hold values which are not // reference-counted pointers. 
- if (Name.startswith("\01L_OBJC_SELECTOR_REFERENCES_") || - Name.startswith("\01L_OBJC_CLASSLIST_REFERENCES_") || - Name.startswith("\01L_OBJC_CLASSLIST_SUP_REFS_$_") || - Name.startswith("\01L_OBJC_METH_VAR_NAME_") || - Name.startswith("\01l_objc_msgSend_fixup_")) + if (Name.startswith("\01l_objc_msgSend_fixup_")) + return true; + + StringRef Section = GV->getSection(); + if (Section.find("__message_refs") != StringRef::npos || + Section.find("__objc_classrefs") != StringRef::npos || + Section.find("__objc_superrefs") != StringRef::npos || + Section.find("__objc_methname") != StringRef::npos || + Section.find("__cstring") != StringRef::npos) return true; } } @@ -395,4 +399,4 @@ static inline bool IsObjCIdentifiedObject(const Value *V) { } // end namespace objcarc } // end namespace llvm -#endif // LLVM_TRANSFORMS_SCALAR_OBJCARC_H +#endif diff --git a/lib/Transforms/ObjCARC/ObjCARCAliasAnalysis.cpp b/lib/Transforms/ObjCARC/ObjCARCAliasAnalysis.cpp index 2c09e70cc9c6..c61b6b0e6dd9 100644 --- a/lib/Transforms/ObjCARC/ObjCARCAliasAnalysis.cpp +++ b/lib/Transforms/ObjCARC/ObjCARCAliasAnalysis.cpp @@ -62,8 +62,8 @@ ObjCARCAliasAnalysis::alias(const Location &LocA, const Location &LocB) { const Value *SA = StripPointerCastsAndObjCCalls(LocA.Ptr); const Value *SB = StripPointerCastsAndObjCCalls(LocB.Ptr); AliasResult Result = - AliasAnalysis::alias(Location(SA, LocA.Size, LocA.TBAATag), - Location(SB, LocB.Size, LocB.TBAATag)); + AliasAnalysis::alias(Location(SA, LocA.Size, LocA.AATags), + Location(SB, LocB.Size, LocB.AATags)); if (Result != MayAlias) return Result; @@ -93,7 +93,7 @@ ObjCARCAliasAnalysis::pointsToConstantMemory(const Location &Loc, // First, strip off no-ops, including ObjC-specific no-ops, and try making // a precise alias query. const Value *S = StripPointerCastsAndObjCCalls(Loc.Ptr); - if (AliasAnalysis::pointsToConstantMemory(Location(S, Loc.Size, Loc.TBAATag), + if (AliasAnalysis::pointsToConstantMemory(Location(S, Loc.Size, Loc.AATags), OrLocal)) return true; diff --git a/lib/Transforms/ObjCARC/ObjCARCAliasAnalysis.h b/lib/Transforms/ObjCARC/ObjCARCAliasAnalysis.h index 97b565be0d2a..3fcea4e9b86d 100644 --- a/lib/Transforms/ObjCARC/ObjCARCAliasAnalysis.h +++ b/lib/Transforms/ObjCARC/ObjCARCAliasAnalysis.h @@ -20,8 +20,8 @@ /// //===----------------------------------------------------------------------===// -#ifndef LLVM_TRANSFORMS_OBJCARC_OBJCARCALIASANALYSIS_H -#define LLVM_TRANSFORMS_OBJCARC_OBJCARCALIASANALYSIS_H +#ifndef LLVM_LIB_TRANSFORMS_OBJCARC_OBJCARCALIASANALYSIS_H +#define LLVM_LIB_TRANSFORMS_OBJCARC_OBJCARCALIASANALYSIS_H #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Pass.h" @@ -71,4 +71,4 @@ namespace objcarc { } // namespace objcarc } // namespace llvm -#endif // LLVM_TRANSFORMS_OBJCARC_OBJCARCALIASANALYSIS_H +#endif diff --git a/lib/Transforms/ObjCARC/ObjCARCContract.cpp b/lib/Transforms/ObjCARC/ObjCARCContract.cpp index f48d53d11b71..eb325eb9417f 100644 --- a/lib/Transforms/ObjCARC/ObjCARCContract.cpp +++ b/lib/Transforms/ObjCARC/ObjCARCContract.cpp @@ -72,9 +72,9 @@ namespace { bool ContractAutorelease(Function &F, Instruction *Autorelease, InstructionClass Class, - SmallPtrSet<Instruction *, 4> + SmallPtrSetImpl<Instruction *> &DependingInstructions, - SmallPtrSet<const BasicBlock *, 4> + SmallPtrSetImpl<const BasicBlock *> &Visited); void ContractRelease(Instruction *Release, @@ -150,9 +150,9 @@ ObjCARCContract::OptimizeRetainCall(Function &F, Instruction *Retain) { bool ObjCARCContract::ContractAutorelease(Function &F, Instruction 
*Autorelease, InstructionClass Class, - SmallPtrSet<Instruction *, 4> + SmallPtrSetImpl<Instruction *> &DependingInstructions, - SmallPtrSet<const BasicBlock *, 4> + SmallPtrSetImpl<const BasicBlock *> &Visited) { const Value *Arg = GetObjCArg(Autorelease); @@ -508,9 +508,8 @@ bool ObjCARCContract::runOnFunction(Function &F) { // If this function has no escaping allocas or suspicious vararg usage, // objc_storeStrong calls can be marked with the "tail" keyword. if (TailOkForStoreStrongs) - for (SmallPtrSet<CallInst *, 8>::iterator I = StoreStrongCalls.begin(), - E = StoreStrongCalls.end(); I != E; ++I) - (*I)->setTailCall(); + for (CallInst *CI : StoreStrongCalls) + CI->setTailCall(); StoreStrongCalls.clear(); return Changed; diff --git a/lib/Transforms/ObjCARC/ObjCARCOpts.cpp b/lib/Transforms/ObjCARC/ObjCARCOpts.cpp index dd4dd50f0ba5..76932e6b600b 100644 --- a/lib/Transforms/ObjCARC/ObjCARCOpts.cpp +++ b/lib/Transforms/ObjCARC/ObjCARCOpts.cpp @@ -188,7 +188,7 @@ static inline bool AreAnyUnderlyingObjectsAnAlloca(const Value *V) { if (isa<AllocaInst>(P)) return true; - if (!Visited.insert(P)) + if (!Visited.insert(P).second) continue; if (const SelectInst *SI = dyn_cast<const SelectInst>(P)) { @@ -411,10 +411,8 @@ bool RRInfo::Merge(const RRInfo &Other) { // Merge the insert point sets. If there are any differences, // that makes this a partial merge. bool Partial = ReverseInsertPts.size() != Other.ReverseInsertPts.size(); - for (SmallPtrSet<Instruction *, 2>::const_iterator - I = Other.ReverseInsertPts.begin(), - E = Other.ReverseInsertPts.end(); I != E; ++I) - Partial |= ReverseInsertPts.insert(*I); + for (Instruction *Inst : Other.ReverseInsertPts) + Partial |= ReverseInsertPts.insert(Inst).second; return Partial; } @@ -882,13 +880,10 @@ static void AppendMDNodeToInstForPtr(unsigned NodeId, Sequence OldSeq, Sequence NewSeq) { MDNode *Node = nullptr; - Value *tmp[3] = {PtrSourceMDNodeID, - SequenceToMDString(Inst->getContext(), - OldSeq), - SequenceToMDString(Inst->getContext(), - NewSeq)}; - Node = MDNode::get(Inst->getContext(), - ArrayRef<Value*>(tmp, 3)); + Metadata *tmp[3] = {PtrSourceMDNodeID, + SequenceToMDString(Inst->getContext(), OldSeq), + SequenceToMDString(Inst->getContext(), NewSeq)}; + Node = MDNode::get(Inst->getContext(), tmp); Inst->setMetadata(NodeId, Node); } @@ -908,8 +903,7 @@ static void GenerateARCBBEntranceAnnotation(const char *Name, BasicBlock *BB, Type *I8X = PointerType::getUnqual(Type::getInt8Ty(C)); Type *I8XX = PointerType::getUnqual(I8X); Type *Params[] = {I8XX, I8XX}; - FunctionType *FTy = FunctionType::get(Type::getVoidTy(C), - ArrayRef<Type*>(Params, 2), + FunctionType *FTy = FunctionType::get(Type::getVoidTy(C), Params, /*isVarArg=*/false); Constant *Callee = M->getOrInsertFunction(Name, FTy); @@ -951,8 +945,7 @@ static void GenerateARCBBTerminatorAnnotation(const char *Name, BasicBlock *BB, Type *I8X = PointerType::getUnqual(Type::getInt8Ty(C)); Type *I8XX = PointerType::getUnqual(I8X); Type *Params[] = {I8XX, I8XX}; - FunctionType *FTy = FunctionType::get(Type::getVoidTy(C), - ArrayRef<Type*>(Params, 2), + FunctionType *FTy = FunctionType::get(Type::getVoidTy(C), Params, /*isVarArg=*/false); Constant *Callee = M->getOrInsertFunction(Name, FTy); @@ -2199,7 +2192,7 @@ ComputePostOrders(Function &F, while (SuccStack.back().second != SE) { BasicBlock *SuccBB = *SuccStack.back().second++; - if (Visited.insert(SuccBB)) { + if (Visited.insert(SuccBB).second) { TerminatorInst *TI = cast<TerminatorInst>(&SuccBB->back()); 
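The recurring .insert(X) to .insert(X).second rewrite in these files tracks an API change: SmallPtrSet::insert now returns std::pair<iterator, bool> rather than bool. In miniature:

#include "llvm/ADT/SmallPtrSet.h"

static void visitOnce(llvm::SmallPtrSet<int *, 4> &Visited, int *P) {
  // insert() returns std::pair<iterator, bool>; the bool is true only
  // when P was newly inserted, which is what the old bool return meant.
  if (Visited.insert(P).second) {
    // first visit of P
  }
}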
SuccStack.push_back(std::make_pair(SuccBB, succ_iterator(TI))); BBStates[CurrBB].addSucc(SuccBB); @@ -2240,7 +2233,7 @@ ComputePostOrders(Function &F, BBState::edge_iterator PE = BBStates[PredStack.back().first].pred_end(); while (PredStack.back().second != PE) { BasicBlock *BB = *PredStack.back().second++; - if (Visited.insert(BB)) { + if (Visited.insert(BB).second) { PredStack.push_back(std::make_pair(BB, BBStates[BB].pred_begin())); goto reverse_dfs_next_succ; } @@ -2299,10 +2292,7 @@ void ObjCARCOpt::MoveCalls(Value *Arg, DEBUG(dbgs() << "== ObjCARCOpt::MoveCalls ==\n"); // Insert the new retain and release calls. - for (SmallPtrSet<Instruction *, 2>::const_iterator - PI = ReleasesToMove.ReverseInsertPts.begin(), - PE = ReleasesToMove.ReverseInsertPts.end(); PI != PE; ++PI) { - Instruction *InsertPt = *PI; + for (Instruction *InsertPt : ReleasesToMove.ReverseInsertPts) { Value *MyArg = ArgTy == ParamTy ? Arg : new BitCastInst(Arg, ParamTy, "", InsertPt); Constant *Decl = EP.get(ARCRuntimeEntryPoints::EPT_Retain); @@ -2313,10 +2303,7 @@ void ObjCARCOpt::MoveCalls(Value *Arg, DEBUG(dbgs() << "Inserting new Retain: " << *Call << "\n" "At insertion point: " << *InsertPt << "\n"); } - for (SmallPtrSet<Instruction *, 2>::const_iterator - PI = RetainsToMove.ReverseInsertPts.begin(), - PE = RetainsToMove.ReverseInsertPts.end(); PI != PE; ++PI) { - Instruction *InsertPt = *PI; + for (Instruction *InsertPt : RetainsToMove.ReverseInsertPts) { Value *MyArg = ArgTy == ParamTy ? Arg : new BitCastInst(Arg, ParamTy, "", InsertPt); Constant *Decl = EP.get(ARCRuntimeEntryPoints::EPT_Release); @@ -2333,18 +2320,12 @@ void ObjCARCOpt::MoveCalls(Value *Arg, } // Delete the original retain and release calls. - for (SmallPtrSet<Instruction *, 2>::const_iterator - AI = RetainsToMove.Calls.begin(), - AE = RetainsToMove.Calls.end(); AI != AE; ++AI) { - Instruction *OrigRetain = *AI; + for (Instruction *OrigRetain : RetainsToMove.Calls) { Retains.blot(OrigRetain); DeadInsts.push_back(OrigRetain); DEBUG(dbgs() << "Deleting retain: " << *OrigRetain << "\n"); } - for (SmallPtrSet<Instruction *, 2>::const_iterator - AI = ReleasesToMove.Calls.begin(), - AE = ReleasesToMove.Calls.end(); AI != AE; ++AI) { - Instruction *OrigRelease = *AI; + for (Instruction *OrigRelease : ReleasesToMove.Calls) { Releases.erase(OrigRelease); DeadInsts.push_back(OrigRelease); DEBUG(dbgs() << "Deleting release: " << *OrigRelease << "\n"); @@ -2392,10 +2373,7 @@ ObjCARCOpt::ConnectTDBUTraversals(DenseMap<const BasicBlock *, BBState> KnownSafeTD &= NewRetainRRI.KnownSafe; MultipleOwners = MultipleOwners || MultiOwnersSet.count(GetObjCArg(NewRetain)); - for (SmallPtrSet<Instruction *, 2>::const_iterator - LI = NewRetainRRI.Calls.begin(), - LE = NewRetainRRI.Calls.end(); LI != LE; ++LI) { - Instruction *NewRetainRelease = *LI; + for (Instruction *NewRetainRelease : NewRetainRRI.Calls) { DenseMap<Value *, RRInfo>::const_iterator Jt = Releases.find(NewRetainRelease); if (Jt == Releases.end()) @@ -2410,7 +2388,7 @@ ObjCARCOpt::ConnectTDBUTraversals(DenseMap<const BasicBlock *, BBState> if (!NewRetainReleaseRRI.Calls.count(NewRetain)) return false; - if (ReleasesToMove.Calls.insert(NewRetainRelease)) { + if (ReleasesToMove.Calls.insert(NewRetainRelease).second) { // If we overflow when we compute the path count, don't remove/move // anything. @@ -2441,12 +2419,8 @@ ObjCARCOpt::ConnectTDBUTraversals(DenseMap<const BasicBlock *, BBState> // Collect the optimal insertion points. 
if (!KnownSafe) - for (SmallPtrSet<Instruction *, 2>::const_iterator - RI = NewRetainReleaseRRI.ReverseInsertPts.begin(), - RE = NewRetainReleaseRRI.ReverseInsertPts.end(); - RI != RE; ++RI) { - Instruction *RIP = *RI; - if (ReleasesToMove.ReverseInsertPts.insert(RIP)) { + for (Instruction *RIP : NewRetainReleaseRRI.ReverseInsertPts) { + if (ReleasesToMove.ReverseInsertPts.insert(RIP).second) { // If we overflow when we compute the path count, don't // remove/move anything. const BBState &RIPBBState = BBStates[RIP->getParent()]; @@ -2476,10 +2450,7 @@ ObjCARCOpt::ConnectTDBUTraversals(DenseMap<const BasicBlock *, BBState> const RRInfo &NewReleaseRRI = It->second; KnownSafeBU &= NewReleaseRRI.KnownSafe; CFGHazardAfflicted |= NewReleaseRRI.CFGHazardAfflicted; - for (SmallPtrSet<Instruction *, 2>::const_iterator - LI = NewReleaseRRI.Calls.begin(), - LE = NewReleaseRRI.Calls.end(); LI != LE; ++LI) { - Instruction *NewReleaseRetain = *LI; + for (Instruction *NewReleaseRetain : NewReleaseRRI.Calls) { MapVector<Value *, RRInfo>::const_iterator Jt = Retains.find(NewReleaseRetain); if (Jt == Retains.end()) @@ -2494,7 +2465,7 @@ ObjCARCOpt::ConnectTDBUTraversals(DenseMap<const BasicBlock *, BBState> if (!NewReleaseRetainRRI.Calls.count(NewRelease)) return false; - if (RetainsToMove.Calls.insert(NewReleaseRetain)) { + if (RetainsToMove.Calls.insert(NewReleaseRetain).second) { // If we overflow when we compute the path count, don't remove/move // anything. const BBState &NRRBBState = BBStates[NewReleaseRetain->getParent()]; @@ -2509,12 +2480,8 @@ ObjCARCOpt::ConnectTDBUTraversals(DenseMap<const BasicBlock *, BBState> // Collect the optimal insertion points. if (!KnownSafe) - for (SmallPtrSet<Instruction *, 2>::const_iterator - RI = NewReleaseRetainRRI.ReverseInsertPts.begin(), - RE = NewReleaseRetainRRI.ReverseInsertPts.end(); - RI != RE; ++RI) { - Instruction *RIP = *RI; - if (RetainsToMove.ReverseInsertPts.insert(RIP)) { + for (Instruction *RIP : NewReleaseRetainRRI.ReverseInsertPts) { + if (RetainsToMove.ReverseInsertPts.insert(RIP).second) { // If we overflow when we compute the path count, don't // remove/move anything. const BBState &RIPBBState = BBStates[RIP->getParent()]; @@ -2850,8 +2817,8 @@ bool ObjCARCOpt::OptimizeSequences(Function &F) { /// shared pointer argument. Note that Retain need not be in BB. 
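For orientation, the traversals above implement the core ARC optimization: a retain/release pair whose in-between region provably cannot alter or observe the reference count is deleted outright. In source terms (signatures simplified; the real runtime functions take and return id):

extern "C" {
void *objc_retain(void *Obj);
void objc_release(void *Obj);
}

static void useBriefly(void *Obj) {
  objc_retain(Obj);  // paired by the top-down traversal ...
  // ... code that neither escapes Obj nor touches its ref count ...
  objc_release(Obj); // ... with this release by the bottom-up traversal;
                     // ObjCARCOpt then removes both calls.
}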
static bool HasSafePathToPredecessorCall(const Value *Arg, Instruction *Retain, - SmallPtrSet<Instruction *, 4> &DepInsts, - SmallPtrSet<const BasicBlock *, 4> &Visited, + SmallPtrSetImpl<Instruction *> &DepInsts, + SmallPtrSetImpl<const BasicBlock *> &Visited, ProvenanceAnalysis &PA) { FindDependencies(CanChangeRetainCount, Arg, Retain->getParent(), Retain, DepInsts, Visited, PA); @@ -2879,8 +2846,8 @@ HasSafePathToPredecessorCall(const Value *Arg, Instruction *Retain, static CallInst * FindPredecessorRetainWithSafePath(const Value *Arg, BasicBlock *BB, Instruction *Autorelease, - SmallPtrSet<Instruction *, 4> &DepInsts, - SmallPtrSet<const BasicBlock *, 4> &Visited, + SmallPtrSetImpl<Instruction *> &DepInsts, + SmallPtrSetImpl<const BasicBlock *> &Visited, ProvenanceAnalysis &PA) { FindDependencies(CanChangeRetainCount, Arg, BB, Autorelease, DepInsts, Visited, PA); @@ -2906,8 +2873,8 @@ FindPredecessorRetainWithSafePath(const Value *Arg, BasicBlock *BB, static CallInst * FindPredecessorAutoreleaseWithSafePath(const Value *Arg, BasicBlock *BB, ReturnInst *Ret, - SmallPtrSet<Instruction *, 4> &DepInsts, - SmallPtrSet<const BasicBlock *, 4> &V, + SmallPtrSetImpl<Instruction *> &DepInsts, + SmallPtrSetImpl<const BasicBlock *> &V, ProvenanceAnalysis &PA) { FindDependencies(NeedsPositiveRetainCount, Arg, BB, Ret, DepInsts, V, PA); diff --git a/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp b/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp index 22be6fdf45f9..410abfc354a0 100644 --- a/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp +++ b/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp @@ -62,7 +62,7 @@ bool ProvenanceAnalysis::relatedPHI(const PHINode *A, SmallPtrSet<const Value *, 4> UniqueSrc; for (unsigned i = 0, e = A->getNumIncomingValues(); i != e; ++i) { const Value *PV1 = A->getIncomingValue(i); - if (UniqueSrc.insert(PV1) && related(PV1, B)) + if (UniqueSrc.insert(PV1).second && related(PV1, B)) return true; } @@ -94,7 +94,7 @@ static bool IsStoredObjCPointer(const Value *P) { if (isa<PtrToIntInst>(P)) // Assume the worst. return true; - if (Visited.insert(Ur)) + if (Visited.insert(Ur).second) Worklist.push_back(Ur); } } while (!Worklist.empty()); diff --git a/lib/Transforms/ObjCARC/ProvenanceAnalysis.h b/lib/Transforms/ObjCARC/ProvenanceAnalysis.h index a13fb9e9b029..782046812f05 100644 --- a/lib/Transforms/ObjCARC/ProvenanceAnalysis.h +++ b/lib/Transforms/ObjCARC/ProvenanceAnalysis.h @@ -22,8 +22,8 @@ /// //===----------------------------------------------------------------------===// -#ifndef LLVM_TRANSFORMS_OBJCARC_PROVENANCEANALYSIS_H -#define LLVM_TRANSFORMS_OBJCARC_PROVENANCEANALYSIS_H +#ifndef LLVM_LIB_TRANSFORMS_OBJCARC_PROVENANCEANALYSIS_H +#define LLVM_LIB_TRANSFORMS_OBJCARC_PROVENANCEANALYSIS_H #include "llvm/ADT/DenseMap.h" @@ -77,4 +77,4 @@ public: } // end namespace objcarc } // end namespace llvm -#endif // LLVM_TRANSFORMS_OBJCARC_PROVENANCEANALYSIS_H +#endif diff --git a/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp b/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp new file mode 100644 index 000000000000..d836632dc617 --- /dev/null +++ b/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp @@ -0,0 +1,92 @@ +//===- ProvenanceAnalysisEvaluator.cpp - ObjC ARC Optimization ------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// + +#include "ProvenanceAnalysis.h" +#include "llvm/Pass.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Function.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; +using namespace llvm::objcarc; + +namespace { +class PAEval : public FunctionPass { + +public: + static char ID; + PAEval(); + void getAnalysisUsage(AnalysisUsage &AU) const override; + bool runOnFunction(Function &F) override; +}; +} + +char PAEval::ID = 0; +PAEval::PAEval() : FunctionPass(ID) {} + +void PAEval::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<AliasAnalysis>(); +} + +static StringRef getName(Value *V) { + StringRef Name = V->getName(); + if (Name.startswith("\1")) + return Name.substr(1); + return Name; +} + +static void insertIfNamed(SetVector<Value *> &Values, Value *V) { + if (!V->hasName()) + return; + Values.insert(V); +} + +bool PAEval::runOnFunction(Function &F) { + SetVector<Value *> Values; + + for (auto &Arg : F.args()) + insertIfNamed(Values, &Arg); + + for (auto I = inst_begin(F), E = inst_end(F); I != E; ++I) { + insertIfNamed(Values, &*I); + + for (auto &Op : I->operands()) + insertIfNamed(Values, Op); + } + + ProvenanceAnalysis PA; + PA.setAA(&getAnalysis<AliasAnalysis>()); + + for (Value *V1 : Values) { + StringRef NameV1 = getName(V1); + for (Value *V2 : Values) { + StringRef NameV2 = getName(V2); + if (NameV1 >= NameV2) + continue; + errs() << NameV1 << " and " << NameV2; + if (PA.related(V1, V2)) + errs() << " are related.\n"; + else + errs() << " are not related.\n"; + } + } + + return false; +} + +FunctionPass *llvm::createPAEvalPass() { return new PAEval(); } + +INITIALIZE_PASS_BEGIN(PAEval, "pa-eval", + "Evaluate ProvenanceAnalysis on all pairs", false, true) +INITIALIZE_AG_DEPENDENCY(AliasAnalysis) +INITIALIZE_PASS_END(PAEval, "pa-eval", + "Evaluate ProvenanceAnalysis on all pairs", false, true) diff --git a/lib/Transforms/Scalar/ADCE.cpp b/lib/Transforms/Scalar/ADCE.cpp index 1a3a4aadce6a..3d9198469bc5 100644 --- a/lib/Transforms/Scalar/ADCE.cpp +++ b/lib/Transforms/Scalar/ADCE.cpp @@ -73,7 +73,7 @@ bool ADCE::runOnFunction(Function& F) { for (Instruction::op_iterator OI = curr->op_begin(), OE = curr->op_end(); OI != OE; ++OI) if (Instruction* Inst = dyn_cast<Instruction>(OI)) - if (alive.insert(Inst)) + if (alive.insert(Inst).second) worklist.push_back(Inst); } diff --git a/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp new file mode 100644 index 000000000000..f48cefaa4fba --- /dev/null +++ b/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp @@ -0,0 +1,428 @@ +//===----------------------- AlignmentFromAssumptions.cpp -----------------===// +// Set Load/Store Alignments From Assumptions +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements a ScalarEvolution-based transformation to set +// the alignments of load, stores and memory intrinsics based on the truth +// expressions of assume intrinsics. The primary motivation is to handle +// complex alignment assumptions that apply to vector loads and stores that +// appear after vectorization and unrolling. 
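A concrete trigger: such assumptions typically originate from __builtin_assume_aligned, which Clang lowers to an llvm.assume over a ptrtoint/and/icmp pattern, the same "x & mask == 0" shape that extractAlignmentInfo matches below. For example:

float sum(const float *A, int N) {
  // Lowered (at this time) to roughly:
  //   %i = ptrtoint float* %A to i64
  //   %m = and i64 %i, 31
  //   %c = icmp eq i64 %m, 0
  //   call void @llvm.assume(i1 %c)
  A = static_cast<const float *>(__builtin_assume_aligned(A, 32));
  float R = 0;
  for (int I = 0; I < N; ++I)
    R += A[I];
  return R;
}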
+// +//===----------------------------------------------------------------------===// + +#define AA_NAME "alignment-from-assumptions" +#define DEBUG_TYPE AA_NAME +#include "llvm/Transforms/Scalar.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; + +STATISTIC(NumLoadAlignChanged, + "Number of loads changed by alignment assumptions"); +STATISTIC(NumStoreAlignChanged, + "Number of stores changed by alignment assumptions"); +STATISTIC(NumMemIntAlignChanged, + "Number of memory intrinsics changed by alignment assumptions"); + +namespace { +struct AlignmentFromAssumptions : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + AlignmentFromAssumptions() : FunctionPass(ID) { + initializeAlignmentFromAssumptionsPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<ScalarEvolution>(); + AU.addRequired<DominatorTreeWrapperPass>(); + + AU.setPreservesCFG(); + AU.addPreserved<LoopInfo>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<ScalarEvolution>(); + } + + // For memory transfers, we need a common alignment for both the source and + // destination. If we have a new alignment for only one operand of a transfer + // instruction, save it in these maps. If we reach the other operand through + // another assumption later, then we may change the alignment at that point. + DenseMap<MemTransferInst *, unsigned> NewDestAlignments, NewSrcAlignments; + + ScalarEvolution *SE; + DominatorTree *DT; + const DataLayout *DL; + + bool extractAlignmentInfo(CallInst *I, Value *&AAPtr, const SCEV *&AlignSCEV, + const SCEV *&OffSCEV); + bool processAssumption(CallInst *I); +}; +} + +char AlignmentFromAssumptions::ID = 0; +static const char aip_name[] = "Alignment from assumptions"; +INITIALIZE_PASS_BEGIN(AlignmentFromAssumptions, AA_NAME, + aip_name, false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolution) +INITIALIZE_PASS_END(AlignmentFromAssumptions, AA_NAME, + aip_name, false, false) + +FunctionPass *llvm::createAlignmentFromAssumptionsPass() { + return new AlignmentFromAssumptions(); +} + +// Given an expression for the (constant) alignment, AlignSCEV, and an +// expression for the displacement between a pointer and the aligned address, +// DiffSCEV, compute the alignment of the displaced pointer if it can be reduced +// to a constant. Using SCEV to compute alignment handles the case where +// DiffSCEV is a recurrence with constant start such that the aligned offset +// is constant. e.g. {16,+,32} % 32 -> 16. 
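Stripped of SCEV, the computation below reduces to remainder arithmetic. A scalar model with a couple of worked values:

#include <cstdint>
#include <cstdlib>

// Scalar model of getNewAlignmentDiff: the provable alignment of a pointer
// displaced Diff bytes from an Align-byte-aligned address, or 0 if unknown.
static unsigned newAlignmentDiff(int64_t Diff, int64_t Align) {
  int64_t DiffUnits = Diff % Align;
  if (DiffUnits == 0)
    return (unsigned)Align;    // exact multiple: alignment is preserved
  uint64_t Abs = (uint64_t)std::llabs(DiffUnits);
  if ((Abs & (Abs - 1)) == 0)  // the remainder must be a power of two
    return (unsigned)Abs;
  return 0;
}
// newAlignmentDiff(48, 32) == 16, newAlignmentDiff(64, 32) == 32,
// newAlignmentDiff(20, 32) == 0 (20 is not a power of two).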
+static unsigned getNewAlignmentDiff(const SCEV *DiffSCEV, + const SCEV *AlignSCEV, + ScalarEvolution *SE) { + // DiffUnits = Diff % int64_t(Alignment) + const SCEV *DiffAlignDiv = SE->getUDivExpr(DiffSCEV, AlignSCEV); + const SCEV *DiffAlign = SE->getMulExpr(DiffAlignDiv, AlignSCEV); + const SCEV *DiffUnitsSCEV = SE->getMinusSCEV(DiffAlign, DiffSCEV); + + DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is " << + *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n"); + + if (const SCEVConstant *ConstDUSCEV = + dyn_cast<SCEVConstant>(DiffUnitsSCEV)) { + int64_t DiffUnits = ConstDUSCEV->getValue()->getSExtValue(); + + // If the displacement is an exact multiple of the alignment, then the + // displaced pointer has the same alignment as the aligned pointer, so + // return the alignment value. + if (!DiffUnits) + return (unsigned) + cast<SCEVConstant>(AlignSCEV)->getValue()->getSExtValue(); + + // If the displacement is not an exact multiple, but the remainder is a + // constant, then return this remainder (but only if it is a power of 2). + uint64_t DiffUnitsAbs = abs64(DiffUnits); + if (isPowerOf2_64(DiffUnitsAbs)) + return (unsigned) DiffUnitsAbs; + } + + return 0; +} + +// There is an address given by an offset OffSCEV from AASCEV which has an +// alignment AlignSCEV. Use that information, if possible, to compute a new +// alignment for Ptr. +static unsigned getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV, + const SCEV *OffSCEV, Value *Ptr, + ScalarEvolution *SE) { + const SCEV *PtrSCEV = SE->getSCEV(Ptr); + const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV); + + // On 32-bit platforms, DiffSCEV might now have type i32 -- we've always + // sign-extended OffSCEV to i64, so make sure they agree again. + DiffSCEV = SE->getNoopOrSignExtend(DiffSCEV, OffSCEV->getType()); + + // What we really want to know is the overall offset to the aligned + // address. This address is displaced by the provided offset. + DiffSCEV = SE->getMinusSCEV(DiffSCEV, OffSCEV); + + DEBUG(dbgs() << "AFI: alignment of " << *Ptr << " relative to " << + *AlignSCEV << " and offset " << *OffSCEV << + " using diff " << *DiffSCEV << "\n"); + + unsigned NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE); + DEBUG(dbgs() << "\tnew alignment: " << NewAlignment << "\n"); + + if (NewAlignment) { + return NewAlignment; + } else if (const SCEVAddRecExpr *DiffARSCEV = + dyn_cast<SCEVAddRecExpr>(DiffSCEV)) { + // The relative offset to the alignment assumption did not yield a constant, + // but we should try harder: if we assume that a is 32-byte aligned, then in + // for (i = 0; i < 1024; i += 4) r += a[i]; not all of the loads from a are + // 32-byte aligned, but instead alternate between 32 and 16-byte alignment. + // As a result, the new alignment will not be a constant, but can still + // be improved over the default (of 4) to 16. + + const SCEV *DiffStartSCEV = DiffARSCEV->getStart(); + const SCEV *DiffIncSCEV = DiffARSCEV->getStepRecurrence(*SE); + + DEBUG(dbgs() << "\ttrying start/inc alignment using start " << + *DiffStartSCEV << " and inc " << *DiffIncSCEV << "\n"); + + // Now compute the new alignment using the displacement to the value in the + // first iteration, and also the alignment using the per-iteration delta. + // If these are the same, then use that answer. Otherwise, use the smaller + // one, but only if it divides the larger one. 
+ NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE); + unsigned NewIncAlignment = getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE); + + DEBUG(dbgs() << "\tnew start alignment: " << NewAlignment << "\n"); + DEBUG(dbgs() << "\tnew inc alignment: " << NewIncAlignment << "\n"); + + if (!NewAlignment || !NewIncAlignment) { + return 0; + } else if (NewAlignment > NewIncAlignment) { + if (NewAlignment % NewIncAlignment == 0) { + DEBUG(dbgs() << "\tnew start/inc alignment: " << + NewIncAlignment << "\n"); + return NewIncAlignment; + } + } else if (NewIncAlignment > NewAlignment) { + if (NewIncAlignment % NewAlignment == 0) { + DEBUG(dbgs() << "\tnew start/inc alignment: " << + NewAlignment << "\n"); + return NewAlignment; + } + } else if (NewIncAlignment == NewAlignment) { + DEBUG(dbgs() << "\tnew start/inc alignment: " << + NewAlignment << "\n"); + return NewAlignment; + } + } + + return 0; +} + +bool AlignmentFromAssumptions::extractAlignmentInfo(CallInst *I, + Value *&AAPtr, const SCEV *&AlignSCEV, + const SCEV *&OffSCEV) { + // An alignment assume must be a statement about the least-significant + // bits of the pointer being zero, possibly with some offset. + ICmpInst *ICI = dyn_cast<ICmpInst>(I->getArgOperand(0)); + if (!ICI) + return false; + + // This must be an expression of the form: x & m == 0. + if (ICI->getPredicate() != ICmpInst::ICMP_EQ) + return false; + + // Swap things around so that the RHS is 0. + Value *CmpLHS = ICI->getOperand(0); + Value *CmpRHS = ICI->getOperand(1); + const SCEV *CmpLHSSCEV = SE->getSCEV(CmpLHS); + const SCEV *CmpRHSSCEV = SE->getSCEV(CmpRHS); + if (CmpLHSSCEV->isZero()) + std::swap(CmpLHS, CmpRHS); + else if (!CmpRHSSCEV->isZero()) + return false; + + BinaryOperator *CmpBO = dyn_cast<BinaryOperator>(CmpLHS); + if (!CmpBO || CmpBO->getOpcode() != Instruction::And) + return false; + + // Swap things around so that the right operand of the and is a constant + // (the mask); we cannot deal with variable masks. + Value *AndLHS = CmpBO->getOperand(0); + Value *AndRHS = CmpBO->getOperand(1); + const SCEV *AndLHSSCEV = SE->getSCEV(AndLHS); + const SCEV *AndRHSSCEV = SE->getSCEV(AndRHS); + if (isa<SCEVConstant>(AndLHSSCEV)) { + std::swap(AndLHS, AndRHS); + std::swap(AndLHSSCEV, AndRHSSCEV); + } + + const SCEVConstant *MaskSCEV = dyn_cast<SCEVConstant>(AndRHSSCEV); + if (!MaskSCEV) + return false; + + // The mask must have some trailing ones (otherwise the condition is + // trivial and tells us nothing about the alignment of the left operand). + unsigned TrailingOnes = + MaskSCEV->getValue()->getValue().countTrailingOnes(); + if (!TrailingOnes) + return false; + + // Cap the alignment at the maximum with which LLVM can deal (and make sure + // we don't overflow the shift). + uint64_t Alignment; + TrailingOnes = std::min(TrailingOnes, + unsigned(sizeof(unsigned) * CHAR_BIT - 1)); + Alignment = std::min(1u << TrailingOnes, +Value::MaximumAlignment); + + Type *Int64Ty = Type::getInt64Ty(I->getParent()->getParent()->getContext()); + AlignSCEV = SE->getConstant(Int64Ty, Alignment); + + // The LHS might be a ptrtoint instruction, or it might be the pointer + // with an offset. + AAPtr = nullptr; + OffSCEV = nullptr; + if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(AndLHS)) { + AAPtr = PToI->getPointerOperand(); + OffSCEV = SE->getConstant(Int64Ty, 0); + } else if (const SCEVAddExpr* AndLHSAddSCEV = + dyn_cast<SCEVAddExpr>(AndLHSSCEV)) { + // Try to find the ptrtoint; subtract it and the rest is the offset. 
+ for (SCEVAddExpr::op_iterator J = AndLHSAddSCEV->op_begin(), + JE = AndLHSAddSCEV->op_end(); J != JE; ++J) + if (const SCEVUnknown *OpUnk = dyn_cast<SCEVUnknown>(*J)) + if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(OpUnk->getValue())) { + AAPtr = PToI->getPointerOperand(); + OffSCEV = SE->getMinusSCEV(AndLHSAddSCEV, *J); + break; + } + } + + if (!AAPtr) + return false; + + // Sign extend the offset to 64 bits (so that it is like all of the other + // expressions). + unsigned OffSCEVBits = OffSCEV->getType()->getPrimitiveSizeInBits(); + if (OffSCEVBits < 64) + OffSCEV = SE->getSignExtendExpr(OffSCEV, Int64Ty); + else if (OffSCEVBits > 64) + return false; + + AAPtr = AAPtr->stripPointerCasts(); + return true; +} + +bool AlignmentFromAssumptions::processAssumption(CallInst *ACall) { + Value *AAPtr; + const SCEV *AlignSCEV, *OffSCEV; + if (!extractAlignmentInfo(ACall, AAPtr, AlignSCEV, OffSCEV)) + return false; + + const SCEV *AASCEV = SE->getSCEV(AAPtr); + + // Apply the assumption to all other users of the specified pointer. + SmallPtrSet<Instruction *, 32> Visited; + SmallVector<Instruction*, 16> WorkList; + for (User *J : AAPtr->users()) { + if (J == ACall) + continue; + + if (Instruction *K = dyn_cast<Instruction>(J)) + if (isValidAssumeForContext(ACall, K, DL, DT)) + WorkList.push_back(K); + } + + while (!WorkList.empty()) { + Instruction *J = WorkList.pop_back_val(); + + if (LoadInst *LI = dyn_cast<LoadInst>(J)) { + unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, + LI->getPointerOperand(), SE); + + if (NewAlignment > LI->getAlignment()) { + LI->setAlignment(NewAlignment); + ++NumLoadAlignChanged; + } + } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) { + unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, + SI->getPointerOperand(), SE); + + if (NewAlignment > SI->getAlignment()) { + SI->setAlignment(NewAlignment); + ++NumStoreAlignChanged; + } + } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) { + unsigned NewDestAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, + MI->getDest(), SE); + + // For memory transfers, we need a common alignment for both the + // source and destination. If we have a new alignment for this + // instruction, but only for one operand, save it. If we reach the + // other operand through another assumption later, then we may + // change the alignment at that point. + if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) { + unsigned NewSrcAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, + MTI->getSource(), SE); + + DenseMap<MemTransferInst *, unsigned>::iterator DI = + NewDestAlignments.find(MTI); + unsigned AltDestAlignment = (DI == NewDestAlignments.end()) ? + 0 : DI->second; + + DenseMap<MemTransferInst *, unsigned>::iterator SI = + NewSrcAlignments.find(MTI); + unsigned AltSrcAlignment = (SI == NewSrcAlignments.end()) ? + 0 : SI->second; + + DEBUG(dbgs() << "\tmem trans: " << NewDestAlignment << " " << + AltDestAlignment << " " << NewSrcAlignment << + " " << AltSrcAlignment << "\n"); + + // Of these four alignments, pick the largest possible... 
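The code that follows implements this pick. As a standalone check (illustrative only, not part of the patch), the four guarded comparisons are equivalent to taking the best known alignment on each side and then the minimum of the two, i.e. the largest alignment that provably holds for both the source and the destination of the transfer:

#include <algorithm>
#include <cassert>

// Standalone sketch mirroring the selection logic below: a destination
// candidate may only win if some source candidate is at least as strong,
// and vice versa, which collapses to min(best dest, best src).
static unsigned pick(unsigned NewDest, unsigned AltDest, unsigned NewSrc,
                     unsigned AltSrc) {
  unsigned NewAlignment = 0;
  if (NewDest <= std::max(NewSrc, AltSrc))
    NewAlignment = std::max(NewAlignment, NewDest);
  if (AltDest <= std::max(NewSrc, AltSrc))
    NewAlignment = std::max(NewAlignment, AltDest);
  if (NewSrc <= std::max(NewDest, AltDest))
    NewAlignment = std::max(NewAlignment, NewSrc);
  if (AltSrc <= std::max(NewDest, AltDest))
    NewAlignment = std::max(NewAlignment, AltSrc);
  return NewAlignment;
}

int main() {
  // Exhaustively confirm the equivalence on small values.
  for (unsigned A = 0; A <= 3; ++A)
    for (unsigned B = 0; B <= 3; ++B)
      for (unsigned C = 0; C <= 3; ++C)
        for (unsigned D = 0; D <= 3; ++D)
          assert(pick(A, B, C, D) ==
                 std::min(std::max(A, B), std::max(C, D)));
  return 0;
}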
+ unsigned NewAlignment = 0; + if (NewDestAlignment <= std::max(NewSrcAlignment, AltSrcAlignment)) + NewAlignment = std::max(NewAlignment, NewDestAlignment); + if (AltDestAlignment <= std::max(NewSrcAlignment, AltSrcAlignment)) + NewAlignment = std::max(NewAlignment, AltDestAlignment); + if (NewSrcAlignment <= std::max(NewDestAlignment, AltDestAlignment)) + NewAlignment = std::max(NewAlignment, NewSrcAlignment); + if (AltSrcAlignment <= std::max(NewDestAlignment, AltDestAlignment)) + NewAlignment = std::max(NewAlignment, AltSrcAlignment); + + if (NewAlignment > MI->getAlignment()) { + MI->setAlignment(ConstantInt::get(Type::getInt32Ty( + MI->getParent()->getContext()), NewAlignment)); + ++NumMemIntAlignChanged; + } + + NewDestAlignments.insert(std::make_pair(MTI, NewDestAlignment)); + NewSrcAlignments.insert(std::make_pair(MTI, NewSrcAlignment)); + } else if (NewDestAlignment > MI->getAlignment()) { + assert((!isa<MemIntrinsic>(MI) || isa<MemSetInst>(MI)) && + "Unknown memory intrinsic"); + + MI->setAlignment(ConstantInt::get(Type::getInt32Ty( + MI->getParent()->getContext()), NewDestAlignment)); + ++NumMemIntAlignChanged; + } + } + + // Now that we've updated that use of the pointer, look for other uses of + // the pointer to update. + Visited.insert(J); + for (User *UJ : J->users()) { + Instruction *K = cast<Instruction>(UJ); + if (!Visited.count(K) && isValidAssumeForContext(ACall, K, DL, DT)) + WorkList.push_back(K); + } + } + + return true; +} + +bool AlignmentFromAssumptions::runOnFunction(Function &F) { + bool Changed = false; + auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + SE = &getAnalysis<ScalarEvolution>(); + DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); + DL = DLP ? &DLP->getDataLayout() : nullptr; + + NewDestAlignments.clear(); + NewSrcAlignments.clear(); + + for (auto &AssumeVH : AC.assumptions()) + if (AssumeVH) + Changed |= processAssumption(cast<CallInst>(AssumeVH)); + + return Changed; +} + diff --git a/lib/Transforms/Scalar/CMakeLists.txt b/lib/Transforms/Scalar/CMakeLists.txt index 261ddda30150..b3ee11ed67cd 100644 --- a/lib/Transforms/Scalar/CMakeLists.txt +++ b/lib/Transforms/Scalar/CMakeLists.txt @@ -1,5 +1,6 @@ add_llvm_library(LLVMScalarOpts ADCE.cpp + AlignmentFromAssumptions.cpp ConstantHoisting.cpp ConstantProp.cpp CorrelatedValuePropagation.cpp diff --git a/lib/Transforms/Scalar/ConstantHoisting.cpp b/lib/Transforms/Scalar/ConstantHoisting.cpp index 763d02b9fcd6..27c177a542e3 100644 --- a/lib/Transforms/Scalar/ConstantHoisting.cpp +++ b/lib/Transforms/Scalar/ConstantHoisting.cpp @@ -91,7 +91,7 @@ struct RebasedConstantInfo { Constant *Offset; RebasedConstantInfo(ConstantUseListType &&Uses, Constant *Offset) - : Uses(Uses), Offset(Offset) { } + : Uses(std::move(Uses)), Offset(Offset) { } }; /// \brief A base constant and all its rebased constants. 
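Both std::move changes in this file address the same C++ subtlety: a named rvalue-reference parameter is itself an lvalue inside the constructor, so initializing the member as Uses(Uses) silently copies the whole list. A standalone sketch with invented names (not LLVM code):

#include <cassert>
#include <utility>
#include <vector>

// Standalone sketch: U has rvalue-reference type, but the name U is an
// lvalue, so Uses(U) would invoke the copy constructor. Wrapping the
// initializer in std::move selects the move constructor instead.
struct Info {
  std::vector<int> Uses;
  Info(std::vector<int> &&U) : Uses(std::move(U)) {}
};

int main() {
  std::vector<int> V(1000, 42);
  Info I(std::move(V)); // ownership transferred, no element-wise copy
  assert(I.Uses.size() == 1000);
  return 0;
}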
@@ -395,7 +395,7 @@ void ConstantHoisting::findAndMakeBaseConstant(ConstCandVecType::iterator S, ConstInfo.RebasedConstants.push_back( RebasedConstantInfo(std::move(ConstCand->Uses), Offset)); } - ConstantVec.push_back(ConstInfo); + ConstantVec.push_back(std::move(ConstInfo)); } /// \brief Finds and combines constant candidates that can be easily diff --git a/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp index 082946229b35..5a3b5cf34cc3 100644 --- a/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp +++ b/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp @@ -73,7 +73,7 @@ bool CorrelatedValuePropagation::processSelect(SelectInst *S) { if (S->getType()->isVectorTy()) return false; if (isa<Constant>(S->getOperand(0))) return false; - Constant *C = LVI->getConstant(S->getOperand(0), S->getParent()); + Constant *C = LVI->getConstant(S->getOperand(0), S->getParent(), S); if (!C) return false; ConstantInt *CI = dyn_cast<ConstantInt>(C); @@ -100,7 +100,7 @@ bool CorrelatedValuePropagation::processPHI(PHINode *P) { Value *Incoming = P->getIncomingValue(i); if (isa<Constant>(Incoming)) continue; - Value *V = LVI->getConstantOnEdge(Incoming, P->getIncomingBlock(i), BB); + Value *V = LVI->getConstantOnEdge(Incoming, P->getIncomingBlock(i), BB, P); // Look if the incoming value is a select with a constant but LVI tells us // that the incoming value can never be that constant. In that case replace @@ -114,7 +114,7 @@ bool CorrelatedValuePropagation::processPHI(PHINode *P) { if (!C) continue; if (LVI->getPredicateOnEdge(ICmpInst::ICMP_EQ, SI, C, - P->getIncomingBlock(i), BB) != + P->getIncomingBlock(i), BB, P) != LazyValueInfo::False) continue; @@ -126,6 +126,7 @@ bool CorrelatedValuePropagation::processPHI(PHINode *P) { Changed = true; } + // FIXME: Provide DL, TLI, DT, AT to SimplifyInstruction. if (Value *V = SimplifyInstruction(P)) { P->replaceAllUsesWith(V); P->eraseFromParent(); @@ -147,7 +148,7 @@ bool CorrelatedValuePropagation::processMemAccess(Instruction *I) { if (isa<Constant>(Pointer)) return false; - Constant *C = LVI->getConstant(Pointer, I->getParent()); + Constant *C = LVI->getConstant(Pointer, I->getParent(), I); if (!C) return false; ++NumMemAccess; @@ -173,13 +174,15 @@ bool CorrelatedValuePropagation::processCmp(CmpInst *C) { if (PI == PE) return false; LazyValueInfo::Tristate Result = LVI->getPredicateOnEdge(C->getPredicate(), - C->getOperand(0), Op1, *PI, C->getParent()); + C->getOperand(0), Op1, *PI, + C->getParent(), C); if (Result == LazyValueInfo::Unknown) return false; ++PI; while (PI != PE) { LazyValueInfo::Tristate Res = LVI->getPredicateOnEdge(C->getPredicate(), - C->getOperand(0), Op1, *PI, C->getParent()); + C->getOperand(0), Op1, *PI, + C->getParent(), C); if (Res != Result) return false; ++PI; } @@ -229,7 +232,8 @@ bool CorrelatedValuePropagation::processSwitch(SwitchInst *SI) { for (pred_iterator PI = PB; PI != PE; ++PI) { // Is the switch condition equal to the case value? LazyValueInfo::Tristate Value = LVI->getPredicateOnEdge(CmpInst::ICMP_EQ, - Cond, Case, *PI, BB); + Cond, Case, *PI, + BB, SI); // Give up on this case if nothing is known. 
if (Value == LazyValueInfo::Unknown) { State = LazyValueInfo::Unknown; diff --git a/lib/Transforms/Scalar/DeadStoreElimination.cpp b/lib/Transforms/Scalar/DeadStoreElimination.cpp index 3af8ee7546fb..a1ddc00da5ba 100644 --- a/lib/Transforms/Scalar/DeadStoreElimination.cpp +++ b/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -356,15 +356,8 @@ static OverwriteResult isOverwrite(const AliasAnalysis::Location &Later, // If we don't know the sizes of either access, then we can't do a // comparison. if (Later.Size == AliasAnalysis::UnknownSize || - Earlier.Size == AliasAnalysis::UnknownSize) { - // If we have no DataLayout information around, then the size of the store - // is inferrable from the pointee type. If they are the same type, then - // we know that the store is safe. - if (DL == nullptr && Later.Ptr->getType() == Earlier.Ptr->getType()) - return OverwriteComplete; - + Earlier.Size == AliasAnalysis::UnknownSize) return OverwriteUnknown; - } // Make sure that the Later size is >= the Earlier size. if (Later.Size >= Earlier.Size) diff --git a/lib/Transforms/Scalar/EarlyCSE.cpp b/lib/Transforms/Scalar/EarlyCSE.cpp index 735f5c194cb5..394b0d3de7bd 100644 --- a/lib/Transforms/Scalar/EarlyCSE.cpp +++ b/lib/Transforms/Scalar/EarlyCSE.cpp @@ -16,17 +16,21 @@ #include "llvm/ADT/Hashing.h" #include "llvm/ADT/ScopedHashTable.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/RecyclingAllocator.h" #include "llvm/Target/TargetLibraryInfo.h" #include "llvm/Transforms/Utils/Local.h" -#include <vector> +#include <deque> using namespace llvm; +using namespace llvm::PatternMatch; #define DEBUG_TYPE "early-cse" @@ -266,6 +270,7 @@ public: const DataLayout *DL; const TargetLibraryInfo *TLI; DominatorTree *DT; + AssumptionCache *AC; typedef RecyclingAllocator<BumpPtrAllocator, ScopedHashTableVal<SimpleValue, Value*> > AllocatorTy; typedef ScopedHashTable<SimpleValue, Value*, DenseMapInfo<SimpleValue>, @@ -378,6 +383,7 @@ private: // This transformation requires dominator postdominator info void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<TargetLibraryInfo>(); AU.setPreservesCFG(); @@ -393,6 +399,7 @@ FunctionPass *llvm::createEarlyCSEPass() { } INITIALIZE_PASS_BEGIN(EarlyCSE, "early-cse", "Early CSE", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) INITIALIZE_PASS_END(EarlyCSE, "early-cse", "Early CSE", false, false) @@ -431,9 +438,18 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { continue; } + // Skip assume intrinsics, they don't really have side effects (although + // they're marked as such to ensure preservation of control dependencies), + // and this pass will not disturb any of the assumption's control + // dependencies. + if (match(Inst, m_Intrinsic<Intrinsic::assume>())) { + DEBUG(dbgs() << "EarlyCSE skipping assumption: " << *Inst << '\n'); + continue; + } + // If the instruction can be simplified (e.g. X+0 = X) then replace it with // its simpler value. 
- if (Value *V = SimplifyInstruction(Inst, DL, TLI, DT)) {
+ if (Value *V = SimplifyInstruction(Inst, DL, TLI, DT, AC)) {
 DEBUG(dbgs() << "EarlyCSE Simplify: " << *Inst << " to: " << *V << '\n');
 Inst->replaceAllUsesWith(V);
 Inst->eraseFromParent();
@@ -530,7 +546,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
 Changed = true;
 ++NumDSE;
 LastStore = nullptr;
- continue;
+ // fallthrough - we can exploit information about this store
 }
 
 // Okay, we just invalidated anything we knew about loaded values. Try
@@ -556,12 +572,17 @@ bool EarlyCSE::runOnFunction(Function &F) {
 if (skipOptnoneFunction(F))
 return false;
 
- std::vector<StackNode *> nodesToProcess;
+ // Note: deque is being used here because there are significant performance gains
+ // over vector when the container becomes very large due to the specific access
+ // patterns. For more information see the mailing list discussion on this:
+ // http://lists.cs.uiuc.edu/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html
+ std::deque<StackNode *> nodesToProcess;
 
 DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>();
 DL = DLP ? &DLP->getDataLayout() : nullptr;
 TLI = &getAnalysis<TargetLibraryInfo>();
 DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
 
 // Tables that the pass uses when walking the domtree.
 ScopedHTType AVTable;
diff --git a/lib/Transforms/Scalar/GVN.cpp b/lib/Transforms/Scalar/GVN.cpp
index 106eba099ca0..b814b2525dca 100644
--- a/lib/Transforms/Scalar/GVN.cpp
+++ b/lib/Transforms/Scalar/GVN.cpp
@@ -20,10 +20,12 @@
 #include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/CFG.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
@@ -45,6 +47,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Target/TargetLibraryInfo.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/SSAUpdater.h"
 #include <vector>
 using namespace llvm;
@@ -590,6 +593,7 @@ namespace {
 DominatorTree *DT;
 const DataLayout *DL;
 const TargetLibraryInfo *TLI;
+ AssumptionCache *AC;
 SetVector<BasicBlock *> DeadBlocks;
 ValueTable VN;
@@ -679,6 +683,7 @@ namespace {
 // This transformation requires dominator postdominator info
 void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<AssumptionCacheTracker>();
 AU.addRequired<DominatorTreeWrapperPass>();
 AU.addRequired<TargetLibraryInfo>();
 if (!NoLoads)
@@ -705,6 +710,7 @@ namespace {
 void dump(DenseMap<uint32_t, Value*> &d);
 bool iterateOnFunction(Function &F);
 bool performPRE(Function &F);
+ bool performScalarPRE(Instruction *I);
 Value *findLeader(const BasicBlock *BB, uint32_t num);
 void cleanupGlobalSets();
 void verifyRemoved(const Instruction *I) const;
@@ -727,6 +733,7 @@ FunctionPass *llvm::createGVNPass(bool NoLoads) {
 }
 
 INITIALIZE_PASS_BEGIN(GVN, "gvn", "Global Value Numbering", false, false)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
 INITIALIZE_PASS_DEPENDENCY(MemoryDependenceAnalysis)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo)
@@ -1616,7 +1623,7 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock,
 // If all
preds have a single successor, then we know it is safe to insert // the load on the pred (?!?), so we can insert code to materialize the // pointer if it is not available. - PHITransAddr Address(LI->getPointerOperand(), DL); + PHITransAddr Address(LI->getPointerOperand(), DL, AC); Value *LoadPtr = nullptr; LoadPtr = Address.PHITranslateWithInsertion(LoadBB, UnavailablePred, *DT, NewInsts); @@ -1669,9 +1676,11 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, LI->getAlignment(), UnavailablePred->getTerminator()); - // Transfer the old load's TBAA tag to the new load. - if (MDNode *Tag = LI->getMetadata(LLVMContext::MD_tbaa)) - NewLoad->setMetadata(LLVMContext::MD_tbaa, Tag); + // Transfer the old load's AA tags to the new load. + AAMDNodes Tags; + LI->getAAMetadata(Tags); + if (Tags) + NewLoad->setAAMetadata(Tags); // Transfer DebugLoc. NewLoad->setDebugLoc(LI->getDebugLoc()); @@ -1700,8 +1709,7 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, bool GVN::processNonLocalLoad(LoadInst *LI) { // Step 1: Find the non-local dependencies of the load. LoadDepVect Deps; - AliasAnalysis::Location Loc = VN.getAliasAnalysis()->getLocation(LI); - MD->getNonLocalPointerDependency(Loc, true, LI->getParent(), Deps); + MD->getNonLocalPointerDependency(LI, Deps); // If we had to process more than one hundred blocks to find the // dependencies, this load isn't worth worrying about. Optimizing @@ -1722,6 +1730,15 @@ bool GVN::processNonLocalLoad(LoadInst *LI) { return false; } + // If this load follows a GEP, see if we can PRE the indices before analyzing. + if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0))) { + for (GetElementPtrInst::op_iterator OI = GEP->idx_begin(), + OE = GEP->idx_end(); + OI != OE; ++OI) + if (Instruction *I = dyn_cast<Instruction>(OI->get())) + performScalarPRE(I); + } + // Step 2: Analyze the availability of the load AvailValInBlkVect ValuesPerBlock; UnavailBlkVect UnavailableBlocks; @@ -1774,36 +1791,24 @@ static void patchReplacementInstruction(Instruction *I, Value *Repl) { ReplOp->setHasNoUnsignedWrap(false); } if (Instruction *ReplInst = dyn_cast<Instruction>(Repl)) { - SmallVector<std::pair<unsigned, MDNode*>, 4> Metadata; - ReplInst->getAllMetadataOtherThanDebugLoc(Metadata); - for (int i = 0, n = Metadata.size(); i < n; ++i) { - unsigned Kind = Metadata[i].first; - MDNode *IMD = I->getMetadata(Kind); - MDNode *ReplMD = Metadata[i].second; - switch(Kind) { - default: - ReplInst->setMetadata(Kind, nullptr); // Remove unknown metadata - break; - case LLVMContext::MD_dbg: - llvm_unreachable("getAllMetadataOtherThanDebugLoc returned a MD_dbg"); - case LLVMContext::MD_tbaa: - ReplInst->setMetadata(Kind, MDNode::getMostGenericTBAA(IMD, ReplMD)); - break; - case LLVMContext::MD_range: - ReplInst->setMetadata(Kind, MDNode::getMostGenericRange(IMD, ReplMD)); - break; - case LLVMContext::MD_prof: - llvm_unreachable("MD_prof in a non-terminator instruction"); - break; - case LLVMContext::MD_fpmath: - ReplInst->setMetadata(Kind, MDNode::getMostGenericFPMath(IMD, ReplMD)); - break; - case LLVMContext::MD_invariant_load: - // Only set the !invariant.load if it is present in both instructions. 
- ReplInst->setMetadata(Kind, IMD);
- break;
- }
- }
+ // FIXME: If both the original and replacement value are part of the
+ // same control-flow region (meaning that the execution of one
+ // guarantees the execution of the other), then we can combine the
+ // noalias scopes here and do better than the general conservative
+ // answer used in combineMetadata().
+
+ // In general, GVN unifies expressions over different control-flow
+ // regions, and so we need a conservative combination of the noalias
+ // scopes.
+ unsigned KnownIDs[] = {
+ LLVMContext::MD_tbaa,
+ LLVMContext::MD_alias_scope,
+ LLVMContext::MD_noalias,
+ LLVMContext::MD_range,
+ LLVMContext::MD_fpmath,
+ LLVMContext::MD_invariant_load,
+ };
+ combineMetadata(ReplInst, I, KnownIDs);
 }
}
@@ -2101,15 +2106,15 @@ bool GVN::propagateEquality(Value *LHS, Value *RHS,
 std::swap(LHS, RHS);
 assert((isa<Argument>(LHS) || isa<Instruction>(LHS)) && "Unexpected value!");
 
- // If there is no obvious reason to prefer the left-hand side over the right-
- // hand side, ensure the longest lived term is on the right-hand side, so the
- // shortest lived term will be replaced by the longest lived. This tends to
- // expose more simplifications.
+ // If there is no obvious reason to prefer the left-hand side over the
+ // right-hand side, ensure the longest lived term is on the right-hand side,
+ // so the shortest lived term will be replaced by the longest lived.
+ // This tends to expose more simplifications.
 uint32_t LVN = VN.lookup_or_add(LHS);
 if ((isa<Argument>(LHS) && isa<Argument>(RHS)) ||
 (isa<Instruction>(LHS) && isa<Instruction>(RHS))) {
- // Move the 'oldest' value to the right-hand side, using the value number as
- // a proxy for age.
+ // Move the 'oldest' value to the right-hand side, using the value number
+ // as a proxy for age.
 uint32_t RVN = VN.lookup_or_add(RHS);
 if (LVN < RVN) {
 std::swap(LHS, RHS);
@@ -2138,10 +2143,10 @@ bool GVN::propagateEquality(Value *LHS, Value *RHS,
 NumGVNEqProp += NumReplacements;
 }
 
- // Now try to deduce additional equalities from this one. For example, if the
- // known equality was "(A != B)" == "false" then it follows that A and B are
- // equal in the scope. Only boolean equalities with an explicit true or false
- // RHS are currently supported.
+ // Now try to deduce additional equalities from this one. For example, if
+ // the known equality was "(A != B)" == "false" then it follows that A and B
+ // are equal in the scope. Only boolean equalities with an explicit true or
+ // false RHS are currently supported.
 if (!RHS->getType()->isIntegerTy(1))
 // Not a boolean equality - bail out.
 continue;
@@ -2166,7 +2171,7 @@ bool GVN::propagateEquality(Value *LHS, Value *RHS,
 // If we are propagating an equality like "(A == B)" == "true" then also
 // propagate the equality A == B. When propagating a comparison such as
 // "(A >= B)" == "true", replace all instances of "A < B" with "false".
- if (ICmpInst *Cmp = dyn_cast<ICmpInst>(LHS)) {
+ if (CmpInst *Cmp = dyn_cast<CmpInst>(LHS)) {
 Value *Op0 = Cmp->getOperand(0), *Op1 = Cmp->getOperand(1);
 
 // If "A == B" is known true, or "A != B" is known false, then replace
@@ -2175,12 +2180,17 @@ bool GVN::propagateEquality(Value *LHS, Value *RHS,
 (isKnownFalse && Cmp->getPredicate() == CmpInst::ICMP_NE))
 Worklist.push_back(std::make_pair(Op0, Op1));
 
+ // Handle the floating point versions of equality comparisons too.
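The change from ICmpInst to CmpInst above is what lets the code below propagate the floating point equality forms as well. A standalone sketch (illustrative only, not LLVM code) of the IEEE semantics that make oeq-true and une-false the safe trigger conditions:

#include <cassert>
#include <cmath>

// Standalone sketch: in C++, == on doubles is an ordered-equal (oeq)
// comparison and != is unordered-not-equal (une). With a NaN operand, oeq
// is never true and une is never false, so neither trigger condition in
// the code that follows can fire on unordered operands.
int main() {
  double QNaN = std::nan("");
  double X = 1.0;
  assert(!(QNaN == X)); // oeq is false when unordered
  assert(QNaN != X);    // une is true when unordered
  assert(X == X && !(X != X));
  return 0;
}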
+ if ((isKnownTrue && Cmp->getPredicate() == CmpInst::FCMP_OEQ) || + (isKnownFalse && Cmp->getPredicate() == CmpInst::FCMP_UNE)) + Worklist.push_back(std::make_pair(Op0, Op1)); + // If "A >= B" is known true, replace "A < B" with false everywhere. CmpInst::Predicate NotPred = Cmp->getInversePredicate(); Constant *NotVal = ConstantInt::get(Cmp->getType(), isKnownFalse); - // Since we don't have the instruction "A < B" immediately to hand, work out - // the value number that it would have and use that to find an appropriate - // instruction (if any). + // Since we don't have the instruction "A < B" immediately to hand, work + // out the value number that it would have and use that to find an + // appropriate instruction (if any). uint32_t NextNum = VN.getNextUnusedValueNumber(); uint32_t Num = VN.lookup_or_add_cmp(Cmp->getOpcode(), NotPred, Op0, Op1); // If the number we were assigned was brand new then there is no point in @@ -2219,7 +2229,7 @@ bool GVN::processInstruction(Instruction *I) { // to value numbering it. Value numbering often exposes redundancies, for // example if it determines that %y is equal to %x then the instruction // "%z = and i32 %x, %y" becomes "%z = and i32 %x, %x" which we now simplify. - if (Value *V = SimplifyInstruction(I, DL, TLI, DT)) { + if (Value *V = SimplifyInstruction(I, DL, TLI, DT, AC)) { I->replaceAllUsesWith(V); if (MD && V->getType()->getScalarType()->isPointerTy()) MD->invalidateCachedPointerInfo(V); @@ -2339,6 +2349,7 @@ bool GVN::runOnFunction(Function& F) { DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); DL = DLP ? &DLP->getDataLayout() : nullptr; + AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); TLI = &getAnalysis<TargetLibraryInfo>(); VN.setAliasAnalysis(&getAnalysis<AliasAnalysis>()); VN.setMemDep(MD); @@ -2435,175 +2446,182 @@ bool GVN::processBlock(BasicBlock *BB) { return ChangedFunction; } -/// performPRE - Perform a purely local form of PRE that looks for diamond -/// control flow patterns and attempts to perform simple PRE at the join point. -bool GVN::performPRE(Function &F) { - bool Changed = false; +bool GVN::performScalarPRE(Instruction *CurInst) { SmallVector<std::pair<Value*, BasicBlock*>, 8> predMap; - for (BasicBlock *CurrentBlock : depth_first(&F.getEntryBlock())) { - // Nothing to PRE in the entry block. - if (CurrentBlock == &F.getEntryBlock()) continue; - // Don't perform PRE on a landing pad. - if (CurrentBlock->isLandingPad()) continue; + if (isa<AllocaInst>(CurInst) || isa<TerminatorInst>(CurInst) || + isa<PHINode>(CurInst) || CurInst->getType()->isVoidTy() || + CurInst->mayReadFromMemory() || CurInst->mayHaveSideEffects() || + isa<DbgInfoIntrinsic>(CurInst)) + return false; - for (BasicBlock::iterator BI = CurrentBlock->begin(), - BE = CurrentBlock->end(); BI != BE; ) { - Instruction *CurInst = BI++; + // Don't do PRE on compares. The PHI would prevent CodeGenPrepare from + // sinking the compare again, and it would force the code generator to + // move the i1 from processor flags or predicate registers into a general + // purpose register. + if (isa<CmpInst>(CurInst)) + return false; - if (isa<AllocaInst>(CurInst) || - isa<TerminatorInst>(CurInst) || isa<PHINode>(CurInst) || - CurInst->getType()->isVoidTy() || - CurInst->mayReadFromMemory() || CurInst->mayHaveSideEffects() || - isa<DbgInfoIntrinsic>(CurInst)) - continue; + // We don't currently value number ANY inline asm calls. 
+ if (CallInst *CallI = dyn_cast<CallInst>(CurInst)) + if (CallI->isInlineAsm()) + return false; - // Don't do PRE on compares. The PHI would prevent CodeGenPrepare from - // sinking the compare again, and it would force the code generator to - // move the i1 from processor flags or predicate registers into a general - // purpose register. - if (isa<CmpInst>(CurInst)) - continue; + uint32_t ValNo = VN.lookup(CurInst); + + // Look for the predecessors for PRE opportunities. We're + // only trying to solve the basic diamond case, where + // a value is computed in the successor and one predecessor, + // but not the other. We also explicitly disallow cases + // where the successor is its own predecessor, because they're + // more complicated to get right. + unsigned NumWith = 0; + unsigned NumWithout = 0; + BasicBlock *PREPred = nullptr; + BasicBlock *CurrentBlock = CurInst->getParent(); + predMap.clear(); + + for (pred_iterator PI = pred_begin(CurrentBlock), PE = pred_end(CurrentBlock); + PI != PE; ++PI) { + BasicBlock *P = *PI; + // We're not interested in PRE where the block is its + // own predecessor, or in blocks with predecessors + // that are not reachable. + if (P == CurrentBlock) { + NumWithout = 2; + break; + } else if (!DT->isReachableFromEntry(P)) { + NumWithout = 2; + break; + } - // We don't currently value number ANY inline asm calls. - if (CallInst *CallI = dyn_cast<CallInst>(CurInst)) - if (CallI->isInlineAsm()) - continue; + Value *predV = findLeader(P, ValNo); + if (!predV) { + predMap.push_back(std::make_pair(static_cast<Value *>(nullptr), P)); + PREPred = P; + ++NumWithout; + } else if (predV == CurInst) { + /* CurInst dominates this predecessor. */ + NumWithout = 2; + break; + } else { + predMap.push_back(std::make_pair(predV, P)); + ++NumWith; + } + } - uint32_t ValNo = VN.lookup(CurInst); - - // Look for the predecessors for PRE opportunities. We're - // only trying to solve the basic diamond case, where - // a value is computed in the successor and one predecessor, - // but not the other. We also explicitly disallow cases - // where the successor is its own predecessor, because they're - // more complicated to get right. - unsigned NumWith = 0; - unsigned NumWithout = 0; - BasicBlock *PREPred = nullptr; - predMap.clear(); - - for (pred_iterator PI = pred_begin(CurrentBlock), - PE = pred_end(CurrentBlock); PI != PE; ++PI) { - BasicBlock *P = *PI; - // We're not interested in PRE where the block is its - // own predecessor, or in blocks with predecessors - // that are not reachable. - if (P == CurrentBlock) { - NumWithout = 2; - break; - } else if (!DT->isReachableFromEntry(P)) { - NumWithout = 2; - break; - } + // Don't do PRE when it might increase code size, i.e. when + // we would need to insert instructions in more than one pred. + if (NumWithout != 1 || NumWith == 0) + return false; - Value* predV = findLeader(P, ValNo); - if (!predV) { - predMap.push_back(std::make_pair(static_cast<Value *>(nullptr), P)); - PREPred = P; - ++NumWithout; - } else if (predV == CurInst) { - /* CurInst dominates this predecessor. */ - NumWithout = 2; - break; - } else { - predMap.push_back(std::make_pair(predV, P)); - ++NumWith; - } - } + // Don't do PRE across indirect branch. + if (isa<IndirectBrInst>(PREPred->getTerminator())) + return false; - // Don't do PRE when it might increase code size, i.e. when - // we would need to insert instructions in more than one pred. 
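For reference, here is a source-level sketch of the diamond this function targets (illustrative only, not from the patch); the NumWithout check that follows is what restricts insertion to exactly one predecessor:

#include <cassert>

// Standalone sketch of scalar PRE on the basic diamond. Before PRE, the
// join block re-evaluates a + b even though one predecessor already
// computed it; after PRE, the expression is inserted into the predecessor
// that lacked it and the join merges the two values (a PHI in SSA form).
static int before(int A, int B, bool C) {
  int X = C ? A + B : 0; // A + B available only in the 'true' predecessor
  return (A + B) + X;    // partially redundant evaluation of A + B
}

static int after(int A, int B, bool C) {
  int T, X;
  if (C) {
    X = A + B; // existing evaluation, the value-number leader
    T = X;
  } else {
    X = 0;
    T = A + B; // PRE inserts the expression in the predecessor lacking it
  }
  return T + X; // the join reuses the merged value instead of re-evaluating
}

int main() {
  for (int C = 0; C <= 1; ++C)
    assert(before(1, 2, C != 0) == after(1, 2, C != 0));
  return 0;
}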
- if (NumWithout != 1 || NumWith == 0) - continue; + // We can't do PRE safely on a critical edge, so instead we schedule + // the edge to be split and perform the PRE the next time we iterate + // on the function. + unsigned SuccNum = GetSuccessorNumber(PREPred, CurrentBlock); + if (isCriticalEdge(PREPred->getTerminator(), SuccNum)) { + toSplit.push_back(std::make_pair(PREPred->getTerminator(), SuccNum)); + return false; + } - // Don't do PRE across indirect branch. - if (isa<IndirectBrInst>(PREPred->getTerminator())) - continue; + // Instantiate the expression in the predecessor that lacked it. + // Because we are going top-down through the block, all value numbers + // will be available in the predecessor by the time we need them. Any + // that weren't originally present will have been instantiated earlier + // in this loop. + Instruction *PREInstr = CurInst->clone(); + bool success = true; + for (unsigned i = 0, e = CurInst->getNumOperands(); i != e; ++i) { + Value *Op = PREInstr->getOperand(i); + if (isa<Argument>(Op) || isa<Constant>(Op) || isa<GlobalValue>(Op)) + continue; - // We can't do PRE safely on a critical edge, so instead we schedule - // the edge to be split and perform the PRE the next time we iterate - // on the function. - unsigned SuccNum = GetSuccessorNumber(PREPred, CurrentBlock); - if (isCriticalEdge(PREPred->getTerminator(), SuccNum)) { - toSplit.push_back(std::make_pair(PREPred->getTerminator(), SuccNum)); - continue; - } + if (Value *V = findLeader(PREPred, VN.lookup(Op))) { + PREInstr->setOperand(i, V); + } else { + success = false; + break; + } + } - // Instantiate the expression in the predecessor that lacked it. - // Because we are going top-down through the block, all value numbers - // will be available in the predecessor by the time we need them. Any - // that weren't originally present will have been instantiated earlier - // in this loop. - Instruction *PREInstr = CurInst->clone(); - bool success = true; - for (unsigned i = 0, e = CurInst->getNumOperands(); i != e; ++i) { - Value *Op = PREInstr->getOperand(i); - if (isa<Argument>(Op) || isa<Constant>(Op) || isa<GlobalValue>(Op)) - continue; + // Fail out if we encounter an operand that is not available in + // the PRE predecessor. This is typically because of loads which + // are not value numbered precisely. + if (!success) { + DEBUG(verifyRemoved(PREInstr)); + delete PREInstr; + return false; + } - if (Value *V = findLeader(PREPred, VN.lookup(Op))) { - PREInstr->setOperand(i, V); - } else { - success = false; - break; - } - } + PREInstr->insertBefore(PREPred->getTerminator()); + PREInstr->setName(CurInst->getName() + ".pre"); + PREInstr->setDebugLoc(CurInst->getDebugLoc()); + VN.add(PREInstr, ValNo); + ++NumGVNPRE; - // Fail out if we encounter an operand that is not available in - // the PRE predecessor. This is typically because of loads which - // are not value numbered precisely. - if (!success) { - DEBUG(verifyRemoved(PREInstr)); - delete PREInstr; - continue; - } + // Update the availability map to include the new instruction. + addToLeaderTable(ValNo, PREInstr, PREPred); - PREInstr->insertBefore(PREPred->getTerminator()); - PREInstr->setName(CurInst->getName() + ".pre"); - PREInstr->setDebugLoc(CurInst->getDebugLoc()); - VN.add(PREInstr, ValNo); - ++NumGVNPRE; - - // Update the availability map to include the new instruction. - addToLeaderTable(ValNo, PREInstr, PREPred); - - // Create a PHI to make the value available in this block. 
- PHINode* Phi = PHINode::Create(CurInst->getType(), predMap.size(), - CurInst->getName() + ".pre-phi", - CurrentBlock->begin()); - for (unsigned i = 0, e = predMap.size(); i != e; ++i) { - if (Value *V = predMap[i].first) - Phi->addIncoming(V, predMap[i].second); - else - Phi->addIncoming(PREInstr, PREPred); - } + // Create a PHI to make the value available in this block. + PHINode *Phi = + PHINode::Create(CurInst->getType(), predMap.size(), + CurInst->getName() + ".pre-phi", CurrentBlock->begin()); + for (unsigned i = 0, e = predMap.size(); i != e; ++i) { + if (Value *V = predMap[i].first) + Phi->addIncoming(V, predMap[i].second); + else + Phi->addIncoming(PREInstr, PREPred); + } - VN.add(Phi, ValNo); - addToLeaderTable(ValNo, Phi, CurrentBlock); - Phi->setDebugLoc(CurInst->getDebugLoc()); - CurInst->replaceAllUsesWith(Phi); - if (Phi->getType()->getScalarType()->isPointerTy()) { - // Because we have added a PHI-use of the pointer value, it has now - // "escaped" from alias analysis' perspective. We need to inform - // AA of this. - for (unsigned ii = 0, ee = Phi->getNumIncomingValues(); ii != ee; - ++ii) { - unsigned jj = PHINode::getOperandNumForIncomingValue(ii); - VN.getAliasAnalysis()->addEscapingUse(Phi->getOperandUse(jj)); - } + VN.add(Phi, ValNo); + addToLeaderTable(ValNo, Phi, CurrentBlock); + Phi->setDebugLoc(CurInst->getDebugLoc()); + CurInst->replaceAllUsesWith(Phi); + if (Phi->getType()->getScalarType()->isPointerTy()) { + // Because we have added a PHI-use of the pointer value, it has now + // "escaped" from alias analysis' perspective. We need to inform + // AA of this. + for (unsigned ii = 0, ee = Phi->getNumIncomingValues(); ii != ee; ++ii) { + unsigned jj = PHINode::getOperandNumForIncomingValue(ii); + VN.getAliasAnalysis()->addEscapingUse(Phi->getOperandUse(jj)); + } - if (MD) - MD->invalidateCachedPointerInfo(Phi); - } - VN.erase(CurInst); - removeFromLeaderTable(ValNo, CurInst, CurrentBlock); + if (MD) + MD->invalidateCachedPointerInfo(Phi); + } + VN.erase(CurInst); + removeFromLeaderTable(ValNo, CurInst, CurrentBlock); + + DEBUG(dbgs() << "GVN PRE removed: " << *CurInst << '\n'); + if (MD) + MD->removeInstruction(CurInst); + DEBUG(verifyRemoved(CurInst)); + CurInst->eraseFromParent(); + return true; +} + +/// performPRE - Perform a purely local form of PRE that looks for diamond +/// control flow patterns and attempts to perform simple PRE at the join point. +bool GVN::performPRE(Function &F) { + bool Changed = false; + for (BasicBlock *CurrentBlock : depth_first(&F.getEntryBlock())) { + // Nothing to PRE in the entry block. + if (CurrentBlock == &F.getEntryBlock()) + continue; + + // Don't perform PRE on a landing pad. + if (CurrentBlock->isLandingPad()) + continue; - DEBUG(dbgs() << "GVN PRE removed: " << *CurInst << '\n'); - if (MD) MD->removeInstruction(CurInst); - DEBUG(verifyRemoved(CurInst)); - CurInst->eraseFromParent(); - Changed = true; + for (BasicBlock::iterator BI = CurrentBlock->begin(), + BE = CurrentBlock->end(); + BI != BE;) { + Instruction *CurInst = BI++; + Changed = performScalarPRE(CurInst); } } @@ -2641,25 +2659,21 @@ bool GVN::iterateOnFunction(Function &F) { // Top-down walk of the dominator tree bool Changed = false; -#if 0 - // Needed for value numbering with phi construction to work. 
- ReversePostOrderTraversal<Function*> RPOT(&F);
- for (ReversePostOrderTraversal<Function*>::rpo_iterator RI = RPOT.begin(),
- RE = RPOT.end(); RI != RE; ++RI)
- Changed |= processBlock(*RI);
-#else
 // Save the blocks this function have before transformation begins. GVN may
 // split critical edge, and hence may invalidate the RPO/DT iterator.
 //
 std::vector<BasicBlock *> BBVect;
 BBVect.reserve(256);
- for (DomTreeNode *x : depth_first(DT->getRootNode()))
- BBVect.push_back(x->getBlock());
+ // Needed for value numbering with phi construction to work.
+ ReversePostOrderTraversal<Function *> RPOT(&F);
+ for (ReversePostOrderTraversal<Function *>::rpo_iterator RI = RPOT.begin(),
+ RE = RPOT.end();
+ RI != RE; ++RI)
+ BBVect.push_back(*RI);
 
 for (std::vector<BasicBlock *>::iterator I = BBVect.begin(), E = BBVect.end();
 I != E; I++)
 Changed |= processBlock(*I);
-#endif
 
 return Changed;
 }
@@ -2802,7 +2816,7 @@ bool GVN::processFoldableCondBr(BranchInst *BI) {
 return true;
 }
 
-// performPRE() will trigger assert if it come across an instruciton without
+// performPRE() will trigger assert if it comes across an instruction without
 // associated val-num. As it normally has far more live instructions than dead
 // instructions, it makes more sense just to "fabricate" a val-number for the
 // dead code than checking if instruction involved is dead or not.
diff --git a/lib/Transforms/Scalar/IndVarSimplify.cpp b/lib/Transforms/Scalar/IndVarSimplify.cpp
index 9cf0ca0912f9..c01f57f26ea9 100644
--- a/lib/Transforms/Scalar/IndVarSimplify.cpp
+++ b/lib/Transforms/Scalar/IndVarSimplify.cpp
@@ -31,6 +31,7 @@
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/LoopPass.h"
 #include "llvm/Analysis/ScalarEvolutionExpander.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Constants.h"
@@ -69,11 +70,12 @@ static cl::opt<bool> ReduceLiveIVs("liv-reduce", cl::Hidden,
 
 namespace {
 class IndVarSimplify : public LoopPass {
- LoopInfo *LI;
- ScalarEvolution *SE;
- DominatorTree *DT;
- const DataLayout *DL;
- TargetLibraryInfo *TLI;
+ LoopInfo *LI;
+ ScalarEvolution *SE;
+ DominatorTree *DT;
+ const DataLayout *DL;
+ TargetLibraryInfo *TLI;
+ const TargetTransformInfo *TTI;
 
 SmallVector<WeakVH, 16> DeadInsts;
 bool Changed;
@@ -650,7 +652,7 @@ namespace {
 struct WideIVInfo {
 PHINode *NarrowIV;
 Type *WidestNativeType; // Widest integer type created [sz]ext
- bool IsSigned; // Was an sext user seen before a zext?
+ bool IsSigned; // Was a sext user seen before a zext?
 
 WideIVInfo() : NarrowIV(nullptr), WidestNativeType(nullptr),
 IsSigned(false) {}
@@ -661,7 +663,7 @@ namespace {
 /// extended by this sign or zero extend operation. This is used to determine
 /// the final width of the IV before actually widening it.
 static void visitIVCast(CastInst *Cast, WideIVInfo &WI, ScalarEvolution *SE,
- const DataLayout *DL) {
+ const DataLayout *DL, const TargetTransformInfo *TTI) {
 bool IsSigned = Cast->getOpcode() == Instruction::SExt;
 if (!IsSigned && Cast->getOpcode() != Instruction::ZExt)
 return;
@@ -671,6 +673,19 @@ static void visitIVCast(CastInst *Cast, WideIVInfo &WI, ScalarEvolution *SE,
 if (DL && !DL->isLegalInteger(Width))
 return;
 
+ // Cast is either an sext or zext up to this point.
+ // We should not widen an indvar if arithmetic on the wider indvar is more
+ // expensive than on the narrower indvar. We check only the cost of ADD
+ // because at least an ADD is required to increment the induction variable.
We + // could compute more comprehensively the cost of all instructions on the + // induction variable when necessary. + if (TTI && + TTI->getArithmeticInstrCost(Instruction::Add, Ty) > + TTI->getArithmeticInstrCost(Instruction::Add, + Cast->getOperand(0)->getType())) { + return; + } + if (!WI.WidestNativeType) { WI.WidestNativeType = SE->getEffectiveSCEVType(Ty); WI.IsSigned = IsSigned; @@ -757,8 +772,13 @@ protected: const SCEVAddRecExpr* GetExtendedOperandRecurrence(NarrowIVDefUse DU); + const SCEV *GetSCEVByOpCode(const SCEV *LHS, const SCEV *RHS, + unsigned OpCode) const; + Instruction *WidenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter); + bool WidenLoopCompare(NarrowIVDefUse DU); + void pushNarrowIVUsers(Instruction *NarrowDef, Instruction *WideDef); }; } // anonymous namespace @@ -833,18 +853,35 @@ Instruction *WidenIV::CloneIVUser(NarrowIVDefUse DU) { } } +const SCEV *WidenIV::GetSCEVByOpCode(const SCEV *LHS, const SCEV *RHS, + unsigned OpCode) const { + if (OpCode == Instruction::Add) + return SE->getAddExpr(LHS, RHS); + if (OpCode == Instruction::Sub) + return SE->getMinusSCEV(LHS, RHS); + if (OpCode == Instruction::Mul) + return SE->getMulExpr(LHS, RHS); + + llvm_unreachable("Unsupported opcode."); +} + /// No-wrap operations can transfer sign extension of their result to their /// operands. Generate the SCEV value for the widened operation without /// actually modifying the IR yet. If the expression after extending the /// operands is an AddRec for this loop, return it. const SCEVAddRecExpr* WidenIV::GetExtendedOperandRecurrence(NarrowIVDefUse DU) { + // Handle the common case of add<nsw/nuw> - if (DU.NarrowUse->getOpcode() != Instruction::Add) + const unsigned OpCode = DU.NarrowUse->getOpcode(); + // Only Add/Sub/Mul instructions supported yet. + if (OpCode != Instruction::Add && OpCode != Instruction::Sub && + OpCode != Instruction::Mul) return nullptr; // One operand (NarrowDef) has already been extended to WideDef. Now determine // if extending the other will lead to a recurrence. - unsigned ExtendOperIdx = DU.NarrowUse->getOperand(0) == DU.NarrowDef ? 1 : 0; + const unsigned ExtendOperIdx = + DU.NarrowUse->getOperand(0) == DU.NarrowDef ? 1 : 0; assert(DU.NarrowUse->getOperand(1-ExtendOperIdx) == DU.NarrowDef && "bad DU"); const SCEV *ExtendOperExpr = nullptr; @@ -859,13 +896,20 @@ const SCEVAddRecExpr* WidenIV::GetExtendedOperandRecurrence(NarrowIVDefUse DU) { else return nullptr; - // When creating this AddExpr, don't apply the current operations NSW or NUW + // When creating this SCEV expr, don't apply the current operations NSW or NUW // flags. This instruction may be guarded by control flow that the no-wrap // behavior depends on. Non-control-equivalent instructions can be mapped to // the same SCEV expression, and it would be incorrect to transfer NSW/NUW // semantics to those operations. - const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>( - SE->getAddExpr(SE->getSCEV(DU.WideDef), ExtendOperExpr)); + const SCEV *lhs = SE->getSCEV(DU.WideDef); + const SCEV *rhs = ExtendOperExpr; + + // Let's swap operands to the initial order for the case of non-commutative + // operations, like SUB. See PR21014. 
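Why the swap matters, as a standalone arithmetic sketch (illustrative only, not LLVM code): subtraction is not commutative, so recombining the widened operands in the wrong order would negate the expression, which is the miscompile PR21014 describes.

#include <cassert>

// Standalone sketch: if the narrow expression was Other - IV but the
// widened operands were recombined as IV - Other, the rebuilt expression
// would change sign. Restoring the original operand order keeps the
// rebuilt expression equal to the source expression.
int main() {
  long IV = 7, Other = 100;
  assert(Other - IV == 93);
  assert(IV - Other == -93);
  assert(Other - IV != IV - Other);
  return 0;
}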
+ if (ExtendOperIdx == 0)
+ std::swap(lhs, rhs);
+ const SCEVAddRecExpr *AddRec =
+ dyn_cast<SCEVAddRecExpr>(GetSCEVByOpCode(lhs, rhs, OpCode));
 if (!AddRec || AddRec->getLoop() != L)
 return nullptr;
@@ -908,6 +952,35 @@ static void truncateIVUse(NarrowIVDefUse DU, DominatorTree *DT) {
 DU.NarrowUse->replaceUsesOfWith(DU.NarrowDef, Trunc);
 }
 
+/// If the narrow use is a compare instruction, then widen the compare
+/// (and possibly the other operand). The extend operation is hoisted into the
+/// loop preheader as far as possible.
+bool WidenIV::WidenLoopCompare(NarrowIVDefUse DU) {
+ ICmpInst *Cmp = dyn_cast<ICmpInst>(DU.NarrowUse);
+ if (!Cmp)
+ return false;
+
+ // Sign of IV user and compare must match.
+ if (IsSigned != CmpInst::isSigned(Cmp->getPredicate()))
+ return false;
+
+ Value *Op = Cmp->getOperand(Cmp->getOperand(0) == DU.NarrowDef ? 1 : 0);
+ unsigned CastWidth = SE->getTypeSizeInBits(Op->getType());
+ unsigned IVWidth = SE->getTypeSizeInBits(WideType);
+ assert (CastWidth <= IVWidth && "Unexpected width while widening compare.");
+
+ // Widen the compare instruction.
+ IRBuilder<> Builder(getInsertPointForUses(DU.NarrowUse, DU.NarrowDef, DT));
+ DU.NarrowUse->replaceUsesOfWith(DU.NarrowDef, DU.WideDef);
+
+ // Widen the other operand of the compare, if necessary.
+ if (CastWidth < IVWidth) {
+ Value *ExtOp = getExtend(Op, WideType, IsSigned, Cmp);
+ DU.NarrowUse->replaceUsesOfWith(Op, ExtOp);
+ }
+ return true;
+}
+
 /// WidenIVUse - Determine whether an individual user of the narrow IV can be
 /// widened. If so, return the wide clone of the user.
 Instruction *WidenIV::WidenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) {
@@ -975,10 +1048,15 @@ Instruction *WidenIV::WidenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) {
 
 // Does this user itself evaluate to a recurrence after widening?
 const SCEVAddRecExpr *WideAddRec = GetWideRecurrence(DU.NarrowUse);
+ if (!WideAddRec)
+ WideAddRec = GetExtendedOperandRecurrence(DU);
+
 if (!WideAddRec) {
- WideAddRec = GetExtendedOperandRecurrence(DU);
- }
- if (!WideAddRec) {
+ // If use is a loop condition, try to promote the condition instead of
+ // truncating the IV first.
+ if (WidenLoopCompare(DU))
+ return nullptr;
+
 // This user does not evaluate to a recurrence after widening, so don't
 // follow it. Instead insert a Trunc to kill off the original use,
 // eventually isolating the original narrow IV so it can be removed.
@@ -1024,7 +1102,7 @@ void WidenIV::pushNarrowIVUsers(Instruction *NarrowDef, Instruction *WideDef) {
 Instruction *NarrowUser = cast<Instruction>(U);
 
 // Handle data flow merges and bizarre phi cycles.
- if (!Widened.insert(NarrowUser))
+ if (!Widened.insert(NarrowUser).second)
 continue;
 
 NarrowIVUsers.push_back(NarrowIVDefUse(NarrowDef, NarrowUser, WideDef));
@@ -1124,14 +1202,16 @@ namespace {
 class IndVarSimplifyVisitor : public IVVisitor {
 ScalarEvolution *SE;
 const DataLayout *DL;
+ const TargetTransformInfo *TTI;
 PHINode *IVPhi;
 
 public:
 WideIVInfo WI;
 
 IndVarSimplifyVisitor(PHINode *IV, ScalarEvolution *SCEV,
- const DataLayout *DL, const DominatorTree *DTree):
- SE(SCEV), DL(DL), IVPhi(IV) {
+ const DataLayout *DL, const TargetTransformInfo *TTI,
+ const DominatorTree *DTree)
+ : SE(SCEV), DL(DL), TTI(TTI), IVPhi(IV) {
 DT = DTree;
 WI.NarrowIV = IVPhi;
 if (ReduceLiveIVs)
@@ -1139,7 +1219,9 @@ namespace {
 }
 
 // Implement the interface used by simplifyUsersOfIV.
- void visitCast(CastInst *Cast) override { visitIVCast(Cast, WI, SE, DL); } + void visitCast(CastInst *Cast) override { + visitIVCast(Cast, WI, SE, DL, TTI); + } }; } @@ -1173,7 +1255,7 @@ void IndVarSimplify::SimplifyAndExtend(Loop *L, PHINode *CurrIV = LoopPhis.pop_back_val(); // Information about sign/zero extensions of CurrIV. - IndVarSimplifyVisitor Visitor(CurrIV, SE, DL, DT); + IndVarSimplifyVisitor Visitor(CurrIV, SE, DL, TTI, DT); Changed |= simplifyUsersOfIV(CurrIV, SE, &LPM, DeadInsts, &Visitor); @@ -1200,9 +1282,9 @@ void IndVarSimplify::SimplifyAndExtend(Loop *L, /// BackedgeTakenInfo. If these expressions have not been reduced, then /// expanding them may incur additional cost (albeit in the loop preheader). static bool isHighCostExpansion(const SCEV *S, BranchInst *BI, - SmallPtrSet<const SCEV*, 8> &Processed, + SmallPtrSetImpl<const SCEV*> &Processed, ScalarEvolution *SE) { - if (!Processed.insert(S)) + if (!Processed.insert(S).second) return false; // If the backedge-taken count is a UDiv, it's very likely a UDiv that @@ -1373,7 +1455,7 @@ static bool needsLFTR(Loop *L, DominatorTree *DT) { /// Recursive helper for hasConcreteDef(). Unfortunately, this currently boils /// down to checking that all operands are constant and listing instructions /// that may hide undef. -static bool hasConcreteDefImpl(Value *V, SmallPtrSet<Value*, 8> &Visited, +static bool hasConcreteDefImpl(Value *V, SmallPtrSetImpl<Value*> &Visited, unsigned Depth) { if (isa<Constant>(V)) return !isa<UndefValue>(V); @@ -1393,7 +1475,7 @@ static bool hasConcreteDefImpl(Value *V, SmallPtrSet<Value*, 8> &Visited, // Optimistically handle other instructions. for (User::op_iterator OI = I->op_begin(), E = I->op_end(); OI != E; ++OI) { - if (!Visited.insert(*OI)) + if (!Visited.insert(*OI).second) continue; if (!hasConcreteDefImpl(*OI, Visited, Depth+1)) return false; @@ -1637,8 +1719,29 @@ LinearFunctionTestReplace(Loop *L, // FIXME: In theory, SCEV could drop flags even though they exist in IR. // A more robust solution would involve getting a new expression for // CmpIndVar by applying non-NSW/NUW AddExprs. + auto WrappingFlags = + ScalarEvolution::setFlags(SCEV::FlagNUW, SCEV::FlagNSW); + const SCEV *IVInit = IncrementedIndvarSCEV->getStart(); + if (SE->getTypeSizeInBits(IVInit->getType()) > + SE->getTypeSizeInBits(IVCount->getType())) + IVInit = SE->getTruncateExpr(IVInit, IVCount->getType()); + unsigned BitWidth = SE->getTypeSizeInBits(IVCount->getType()); + Type *WideTy = IntegerType::get(SE->getContext(), BitWidth + 1); + // Check if InitIV + BECount+1 requires sign/zero extension. + // If not, clear the corresponding flag from WrappingFlags because it is not + // necessary for those flags in the IncrementedIndvarSCEV expression. + if (SE->getSignExtendExpr(SE->getAddExpr(IVInit, BackedgeTakenCount), + WideTy) == + SE->getAddExpr(SE->getSignExtendExpr(IVInit, WideTy), + SE->getSignExtendExpr(BackedgeTakenCount, WideTy))) + WrappingFlags = ScalarEvolution::clearFlags(WrappingFlags, SCEV::FlagNSW); + if (SE->getZeroExtendExpr(SE->getAddExpr(IVInit, BackedgeTakenCount), + WideTy) == + SE->getAddExpr(SE->getZeroExtendExpr(IVInit, WideTy), + SE->getZeroExtendExpr(BackedgeTakenCount, WideTy))) + WrappingFlags = ScalarEvolution::clearFlags(WrappingFlags, SCEV::FlagNUW); if (!ScalarEvolution::maskFlags(IncrementedIndvarSCEV->getNoWrapFlags(), - SCEV::FlagNUW | SCEV::FlagNSW)) { + WrappingFlags)) { // Add one to the "backedge-taken" count to get the trip count. 
// This addition may overflow, which is valid as long as the comparison is // truncated to BackedgeTakenCount->getType(). @@ -1832,6 +1935,7 @@ bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); DL = DLP ? &DLP->getDataLayout() : nullptr; TLI = getAnalysisIfAvailable<TargetLibraryInfo>(); + TTI = getAnalysisIfAvailable<TargetTransformInfo>(); DeadInsts.clear(); Changed = false; diff --git a/lib/Transforms/Scalar/JumpThreading.cpp b/lib/Transforms/Scalar/JumpThreading.cpp index 21f80385cf46..78beb3f98dcd 100644 --- a/lib/Transforms/Scalar/JumpThreading.cpp +++ b/lib/Transforms/Scalar/JumpThreading.cpp @@ -26,6 +26,7 @@ #include "llvm/IR/DataLayout.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" @@ -44,7 +45,7 @@ STATISTIC(NumFolds, "Number of terminators folded"); STATISTIC(NumDupes, "Number of branch blocks duplicated to eliminate phi"); static cl::opt<unsigned> -Threshold("jump-threading-threshold", +BBDuplicateThreshold("jump-threading-threshold", cl::desc("Max block size to duplicate for jump threading"), cl::init(6), cl::Hidden); @@ -87,6 +88,8 @@ namespace { #endif DenseSet<std::pair<Value*, BasicBlock*> > RecursionSet; + unsigned BBDupThreshold; + // RAII helper for updating the recursion stack. struct RecursionSetRemover { DenseSet<std::pair<Value*, BasicBlock*> > &TheSet; @@ -102,7 +105,8 @@ namespace { }; public: static char ID; // Pass identification - JumpThreading() : FunctionPass(ID) { + JumpThreading(int T = -1) : FunctionPass(ID) { + BBDupThreshold = (T == -1) ? BBDuplicateThreshold : unsigned(T); initializeJumpThreadingPass(*PassRegistry::getPassRegistry()); } @@ -123,9 +127,11 @@ namespace { bool ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB, PredValueInfo &Result, - ConstantPreference Preference); + ConstantPreference Preference, + Instruction *CxtI = nullptr); bool ProcessThreadableEdges(Value *Cond, BasicBlock *BB, - ConstantPreference Preference); + ConstantPreference Preference, + Instruction *CxtI = nullptr); bool ProcessBranchOnPHI(PHINode *PN); bool ProcessBranchOnXOR(BinaryOperator *BO); @@ -144,7 +150,7 @@ INITIALIZE_PASS_END(JumpThreading, "jump-threading", "Jump Threading", false, false) // Public interface to the Jump Threading pass -FunctionPass *llvm::createJumpThreadingPass() { return new JumpThreading(); } +FunctionPass *llvm::createJumpThreadingPass(int Threshold) { return new JumpThreading(Threshold); } /// runOnFunction - Top level algorithm. /// @@ -182,7 +188,7 @@ bool JumpThreading::runOnFunction(Function &F) { // If the block is trivially dead, zap it. This eliminates the successor // edges which simplifies the CFG. - if (pred_begin(BB) == pred_end(BB) && + if (pred_empty(BB) && BB != &BB->getParent()->getEntryBlock()) { DEBUG(dbgs() << " JT: Deleting dead block '" << BB->getName() << "' with terminator: " << *BB->getTerminator() << '\n'); @@ -339,7 +345,8 @@ static Constant *getKnownConstant(Value *Val, ConstantPreference Preference) { /// bool JumpThreading:: ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB, PredValueInfo &Result, - ConstantPreference Preference) { + ConstantPreference Preference, + Instruction *CxtI) { // This method walks up use-def chains recursively. Because of this, we could // get into an infinite loop going around loops in the use-def chain. 
// get into an infinite loop going around loops in the use-def chain.
To // prevent this, keep track of what (value, block) pairs we've already visited @@ -381,7 +388,7 @@ ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB, PredValueInfo &Result, BasicBlock *P = *PI; // If the value is known by LazyValueInfo to be a constant in a // predecessor, use that information to try to thread this block. - Constant *PredCst = LVI->getConstantOnEdge(V, P, BB); + Constant *PredCst = LVI->getConstantOnEdge(V, P, BB, CxtI); if (Constant *KC = getKnownConstant(PredCst, Preference)) Result.push_back(std::make_pair(KC, P)); } @@ -397,7 +404,8 @@ ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB, PredValueInfo &Result, Result.push_back(std::make_pair(KC, PN->getIncomingBlock(i))); } else { Constant *CI = LVI->getConstantOnEdge(InVal, - PN->getIncomingBlock(i), BB); + PN->getIncomingBlock(i), + BB, CxtI); if (Constant *KC = getKnownConstant(CI, Preference)) Result.push_back(std::make_pair(KC, PN->getIncomingBlock(i))); } @@ -416,9 +424,9 @@ ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB, PredValueInfo &Result, if (I->getOpcode() == Instruction::Or || I->getOpcode() == Instruction::And) { ComputeValueKnownInPredecessors(I->getOperand(0), BB, LHSVals, - WantInteger); + WantInteger, CxtI); ComputeValueKnownInPredecessors(I->getOperand(1), BB, RHSVals, - WantInteger); + WantInteger, CxtI); if (LHSVals.empty() && RHSVals.empty()) return false; @@ -459,7 +467,7 @@ ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB, PredValueInfo &Result, isa<ConstantInt>(I->getOperand(1)) && cast<ConstantInt>(I->getOperand(1))->isOne()) { ComputeValueKnownInPredecessors(I->getOperand(0), BB, Result, - WantInteger); + WantInteger, CxtI); if (Result.empty()) return false; @@ -477,7 +485,7 @@ ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB, PredValueInfo &Result, if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) { PredValueInfoTy LHSVals; ComputeValueKnownInPredecessors(BO->getOperand(0), BB, LHSVals, - WantInteger); + WantInteger, CxtI); // Try to use constant folding to simplify the binary operator. for (unsigned i = 0, e = LHSVals.size(); i != e; ++i) { @@ -511,7 +519,8 @@ ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB, PredValueInfo &Result, LazyValueInfo::Tristate ResT = LVI->getPredicateOnEdge(Cmp->getPredicate(), LHS, - cast<Constant>(RHS), PredBB, BB); + cast<Constant>(RHS), PredBB, BB, + CxtI ? CxtI : Cmp); if (ResT == LazyValueInfo::Unknown) continue; Res = ConstantInt::get(Type::getInt1Ty(LHS->getContext()), ResT); @@ -524,7 +533,6 @@ ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB, PredValueInfo &Result, return !Result.empty(); } - // If comparing a live-in value against a constant, see if we know the // live-in value on any predecessors. if (isa<Constant>(Cmp->getOperand(1)) && Cmp->getType()->isIntegerTy()) { @@ -538,7 +546,7 @@ ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB, PredValueInfo &Result, // predecessor, use that information to try to thread this block. LazyValueInfo::Tristate Res = LVI->getPredicateOnEdge(Cmp->getPredicate(), Cmp->getOperand(0), - RHSCst, P, BB); + RHSCst, P, BB, CxtI ? 
CxtI : Cmp); if (Res == LazyValueInfo::Unknown) continue; @@ -554,7 +562,7 @@ ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB, PredValueInfo &Result, if (Constant *CmpConst = dyn_cast<Constant>(Cmp->getOperand(1))) { PredValueInfoTy LHSVals; ComputeValueKnownInPredecessors(I->getOperand(0), BB, LHSVals, - WantInteger); + WantInteger, CxtI); for (unsigned i = 0, e = LHSVals.size(); i != e; ++i) { Constant *V = LHSVals[i].first; @@ -577,7 +585,7 @@ ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB, PredValueInfo &Result, PredValueInfoTy Conds; if ((TrueVal || FalseVal) && ComputeValueKnownInPredecessors(SI->getCondition(), BB, Conds, - WantInteger)) { + WantInteger, CxtI)) { for (unsigned i = 0, e = Conds.size(); i != e; ++i) { Constant *Cond = Conds[i].first; @@ -604,7 +612,7 @@ ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB, PredValueInfo &Result, } // If all else fails, see if LVI can figure out a constant value for us. - Constant *CI = LVI->getConstant(V, BB); + Constant *CI = LVI->getConstant(V, BB, CxtI); if (Constant *KC = getKnownConstant(CI, Preference)) { for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) Result.push_back(std::make_pair(KC, *PI)); @@ -654,7 +662,7 @@ static bool hasAddressTakenAndUsed(BasicBlock *BB) { bool JumpThreading::ProcessBlock(BasicBlock *BB) { // If the block is trivially dead, just return and let the caller nuke it. // This simplifies other transformations. - if (pred_begin(BB) == pred_end(BB) && + if (pred_empty(BB) && BB != &BB->getParent()->getEntryBlock()) return false; @@ -744,7 +752,7 @@ bool JumpThreading::ProcessBlock(BasicBlock *BB) { // All the rest of our checks depend on the condition being an instruction. if (!CondInst) { // FIXME: Unify this with code below. - if (ProcessThreadableEdges(Condition, BB, Preference)) + if (ProcessThreadableEdges(Condition, BB, Preference, Terminator)) return true; return false; } @@ -766,13 +774,14 @@ bool JumpThreading::ProcessBlock(BasicBlock *BB) { // FIXME: We could handle mixed true/false by duplicating code. LazyValueInfo::Tristate Baseline = LVI->getPredicateOnEdge(CondCmp->getPredicate(), CondCmp->getOperand(0), - CondConst, *PI, BB); + CondConst, *PI, BB, CondCmp); if (Baseline != LazyValueInfo::Unknown) { // Check that all remaining incoming values match the first one. while (++PI != PE) { LazyValueInfo::Tristate Ret = LVI->getPredicateOnEdge(CondCmp->getPredicate(), - CondCmp->getOperand(0), CondConst, *PI, BB); + CondCmp->getOperand(0), CondConst, *PI, BB, + CondCmp); if (Ret != Baseline) break; } @@ -787,6 +796,21 @@ bool JumpThreading::ProcessBlock(BasicBlock *BB) { } } + } else if (CondBr && CondConst && CondBr->isConditional()) { + // There might be an invariant in the same block with the conditional + // that can determine the predicate. + + LazyValueInfo::Tristate Ret = + LVI->getPredicateAt(CondCmp->getPredicate(), CondCmp->getOperand(0), + CondConst, CondCmp); + if (Ret != LazyValueInfo::Unknown) { + unsigned ToRemove = Ret == LazyValueInfo::True ? 1 : 0; + unsigned ToKeep = Ret == LazyValueInfo::True ? 0 : 1; + CondBr->getSuccessor(ToRemove)->removePredecessor(BB, true); + BranchInst::Create(CondBr->getSuccessor(ToKeep), CondBr); + CondBr->eraseFromParent(); + return true; + } } if (CondBr && CondConst && TryToUnfoldSelect(CondCmp, BB)) @@ -814,7 +838,7 @@ bool JumpThreading::ProcessBlock(BasicBlock *BB) { // a PHI node in the current block. 
If we can prove that any predecessors // compute a predictable value based on a PHI node, thread those predecessors. // - if (ProcessThreadableEdges(CondInst, BB, Preference)) + if (ProcessThreadableEdges(CondInst, BB, Preference, Terminator)) return true; // If this is an otherwise-unfoldable branch on a phi node in the current @@ -877,6 +901,9 @@ bool JumpThreading::SimplifyPartiallyRedundantLoad(LoadInst *LI) { // If the returned value is the load itself, replace with an undef. This can // only happen in dead loops. if (AvailableVal == LI) AvailableVal = UndefValue::get(LI->getType()); + if (AvailableVal->getType() != LI->getType()) + AvailableVal = + CastInst::CreateBitOrPointerCast(AvailableVal, LI->getType(), "", LI); LI->replaceAllUsesWith(AvailableVal); LI->eraseFromParent(); return true; @@ -888,9 +915,10 @@ bool JumpThreading::SimplifyPartiallyRedundantLoad(LoadInst *LI) { if (BBIt != LoadBB->begin()) return false; - // If all of the loads and stores that feed the value have the same TBAA tag, - // then we can propagate it onto any newly inserted loads. - MDNode *TBAATag = LI->getMetadata(LLVMContext::MD_tbaa); + // If all of the loads and stores that feed the value have the same AA tags, + // then we can propagate them onto any newly inserted loads. + AAMDNodes AATags; + LI->getAAMetadata(AATags); SmallPtrSet<BasicBlock*, 8> PredsScanned; typedef SmallVector<std::pair<BasicBlock*, Value*>, 8> AvailablePredsTy; @@ -904,21 +932,21 @@ bool JumpThreading::SimplifyPartiallyRedundantLoad(LoadInst *LI) { BasicBlock *PredBB = *PI; // If we already scanned this predecessor, skip it. - if (!PredsScanned.insert(PredBB)) + if (!PredsScanned.insert(PredBB).second) continue; // Scan the predecessor to see if the value is available in the pred. BBIt = PredBB->end(); - MDNode *ThisTBAATag = nullptr; + AAMDNodes ThisAATags; Value *PredAvailable = FindAvailableLoadedValue(LoadedPtr, PredBB, BBIt, 6, - nullptr, &ThisTBAATag); + nullptr, &ThisAATags); if (!PredAvailable) { OneUnavailablePred = PredBB; continue; } - // If tbaa tags disagree or are not present, forget about them. - if (TBAATag != ThisTBAATag) TBAATag = nullptr; + // If AA tags disagree or are not present, forget about them. + if (AATags != ThisAATags) AATags = AAMDNodes(); // If so, this load is partially redundant. Remember this info so that we // can create a PHI node. @@ -978,8 +1006,8 @@ bool JumpThreading::SimplifyPartiallyRedundantLoad(LoadInst *LI) { LI->getAlignment(), UnavailablePred->getTerminator()); NewVal->setDebugLoc(LI->getDebugLoc()); - if (TBAATag) - NewVal->setMetadata(LLVMContext::MD_tbaa, TBAATag); + if (AATags) + NewVal->setAAMetadata(AATags); AvailablePreds.push_back(std::make_pair(UnavailablePred, NewVal)); } @@ -1006,7 +1034,16 @@ bool JumpThreading::SimplifyPartiallyRedundantLoad(LoadInst *LI) { assert(I != AvailablePreds.end() && I->first == P && "Didn't find entry for predecessor!"); - PN->addIncoming(I->second, I->first); + // If we have an available predecessor but it requires casting, insert the + // cast in the predecessor and use the cast. Note that we have to update the + // AvailablePreds vector as we go so that all of the PHI entries for this + // predecessor use the same bitcast. 
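// Illustrative sketch, not part of the patch: the added lines below update
// the AvailablePreds entry through a reference (Value *&PredV = I->second),
// so the cast is created once per predecessor and then reused by every later
// PHI entry for that predecessor. A minimal standalone C++ illustration of
// that update-through-reference idiom, with hypothetical names:

#include <iostream>
#include <map>
#include <string>

static int NumCasts = 0;

// Stand-in for "cast the available value once, then reuse it": the slot is
// mutated in place through the reference, so repeated lookups for the same
// key do not re-convert.
int &getCastedValue(std::map<std::string, int> &AvailablePreds,
                    const std::string &Pred) {
  int &Slot = AvailablePreds[Pred]; // reference into the container
  if (Slot == 0)                    // 0 plays the role of "wrong type"
    Slot = ++NumCasts;              // convert once, in place
  return Slot;
}

int main() {
  std::map<std::string, int> AvailablePreds;
  getCastedValue(AvailablePreds, "pred0");
  getCastedValue(AvailablePreds, "pred0"); // reuses the first conversion
  std::cout << NumCasts << "\n";           // prints 1
}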
+ Value *&PredV = I->second; + if (PredV->getType() != LI->getType()) + PredV = CastInst::CreateBitOrPointerCast(PredV, LI->getType(), "", + P->getTerminator()); + + PN->addIncoming(PredV, I->first); } //cerr << "PRE: " << *LI << *PN << "\n"; @@ -1081,14 +1118,15 @@ FindMostPopularDest(BasicBlock *BB, } bool JumpThreading::ProcessThreadableEdges(Value *Cond, BasicBlock *BB, - ConstantPreference Preference) { + ConstantPreference Preference, + Instruction *CxtI) { // If threading this would thread across a loop header, don't even try to // thread the edge. if (LoopHeaders.count(BB)) return false; PredValueInfoTy PredValues; - if (!ComputeValueKnownInPredecessors(Cond, BB, PredValues, Preference)) + if (!ComputeValueKnownInPredecessors(Cond, BB, PredValues, Preference, CxtI)) return false; assert(!PredValues.empty() && @@ -1113,7 +1151,7 @@ bool JumpThreading::ProcessThreadableEdges(Value *Cond, BasicBlock *BB, for (unsigned i = 0, e = PredValues.size(); i != e; ++i) { BasicBlock *Pred = PredValues[i].second; - if (!SeenPreds.insert(Pred)) + if (!SeenPreds.insert(Pred).second) continue; // Duplicate predecessor entry. // If the predecessor ends with an indirect goto, we can't change its @@ -1253,10 +1291,10 @@ bool JumpThreading::ProcessBranchOnXOR(BinaryOperator *BO) { PredValueInfoTy XorOpValues; bool isLHS = true; if (!ComputeValueKnownInPredecessors(BO->getOperand(0), BB, XorOpValues, - WantInteger)) { + WantInteger, BO)) { assert(XorOpValues.empty()); if (!ComputeValueKnownInPredecessors(BO->getOperand(1), BB, XorOpValues, - WantInteger)) + WantInteger, BO)) return false; isLHS = false; } @@ -1366,8 +1404,8 @@ bool JumpThreading::ThreadEdge(BasicBlock *BB, return false; } - unsigned JumpThreadCost = getJumpThreadDuplicationCost(BB, Threshold); - if (JumpThreadCost > Threshold) { + unsigned JumpThreadCost = getJumpThreadDuplicationCost(BB, BBDupThreshold); + if (JumpThreadCost > BBDupThreshold) { DEBUG(dbgs() << " Not threading BB '" << BB->getName() << "' - Cost is too high: " << JumpThreadCost << "\n"); return false; @@ -1509,8 +1547,8 @@ bool JumpThreading::DuplicateCondBranchOnPHIIntoPred(BasicBlock *BB, return false; } - unsigned DuplicationCost = getJumpThreadDuplicationCost(BB, Threshold); - if (DuplicationCost > Threshold) { + unsigned DuplicationCost = getJumpThreadDuplicationCost(BB, BBDupThreshold); + if (DuplicationCost > BBDupThreshold) { DEBUG(dbgs() << " Not duplicating BB '" << BB->getName() << "' - Cost is too high: " << DuplicationCost << "\n"); return false; @@ -1672,10 +1710,10 @@ bool JumpThreading::TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB) { // cases will be threaded in any case. LazyValueInfo::Tristate LHSFolds = LVI->getPredicateOnEdge(CondCmp->getPredicate(), SI->getOperand(1), - CondRHS, Pred, BB); + CondRHS, Pred, BB, CondCmp); LazyValueInfo::Tristate RHSFolds = LVI->getPredicateOnEdge(CondCmp->getPredicate(), SI->getOperand(2), - CondRHS, Pred, BB); + CondRHS, Pred, BB, CondCmp); if ((LHSFolds != LazyValueInfo::Unknown || RHSFolds != LazyValueInfo::Unknown) && LHSFolds != RHSFolds) { diff --git a/lib/Transforms/Scalar/LICM.cpp b/lib/Transforms/Scalar/LICM.cpp index abcceb20050a..e145981846d9 100644 --- a/lib/Transforms/Scalar/LICM.cpp +++ b/lib/Transforms/Scalar/LICM.cpp @@ -120,6 +120,7 @@ namespace { bool MayThrow; // The current loop contains an instruction which // may throw, thus preventing code motion of // instructions with side effects. 
+ bool HeaderMayThrow; // Same as previous, but specific to loop header DenseMap<Loop*, AliasSetTracker*> LoopToAliasSetMap; /// cloneBasicBlockAnalysis - Simple Analysis hook. Clone alias set info. @@ -130,6 +131,9 @@ namespace { /// set. void deleteAnalysisValue(Value *V, Loop *L) override; + /// Simple Analysis hook. Delete loop L from alias set map. + void deleteAnalysisLoop(Loop *L) override; + /// SinkRegion - Walk the specified region of the CFG (defined by all blocks /// dominated by the specified block, and that are in the current loop) in /// reverse depth first order w.r.t the DominatorTree. This allows us to @@ -180,9 +184,9 @@ namespace { /// store into the memory location pointed to by V. /// bool pointerInvalidatedByLoop(Value *V, uint64_t Size, - const MDNode *TBAAInfo) { + const AAMDNodes &AAInfo) { // Check to see if any of the basic blocks in CurLoop invalidate *V. - return CurAST->getAliasSetForPointer(V, Size, TBAAInfo).isMod(); + return CurAST->getAliasSetForPointer(V, Size, AAInfo).isMod(); } bool canSinkOrHoistInst(Instruction &I); @@ -270,7 +274,12 @@ bool LICM::runOnLoop(Loop *L, LPPassManager &LPM) { CurAST->add(*BB); // Incorporate the specified basic block } - MayThrow = false; + HeaderMayThrow = false; + BasicBlock *Header = L->getHeader(); + for (BasicBlock::iterator I = Header->begin(), E = Header->end(); + (I != E) && !HeaderMayThrow; ++I) + HeaderMayThrow |= I->mayThrow(); + MayThrow = HeaderMayThrow; // TODO: We've already searched for instructions which may throw in subloops. // We may want to reuse this information. for (Loop::block_iterator BB = L->block_begin(), BBE = L->block_end(); @@ -313,7 +322,8 @@ bool LICM::runOnLoop(Loop *L, LPPassManager &LPM) { // SSAUpdater strategy during promotion that was LCSSA aware and reformed // it as it went. if (Changed) - formLCSSARecursively(*L, *DT, getAnalysisIfAvailable<ScalarEvolution>()); + formLCSSARecursively(*L, *DT, LI, + getAnalysisIfAvailable<ScalarEvolution>()); } // Check that neither this loop nor its parent have had LCSSA broken. LICM is @@ -441,15 +451,18 @@ bool LICM::canSinkOrHoistInst(Instruction &I) { // in the same alias set as something that ends up being modified. if (AA->pointsToConstantMemory(LI->getOperand(0))) return true; - if (LI->getMetadata("invariant.load")) + if (LI->getMetadata(LLVMContext::MD_invariant_load)) return true; // Don't hoist loads which have may-aliased stores in loop. uint64_t Size = 0; if (LI->getType()->isSized()) Size = AA->getTypeStoreSize(LI->getType()); - return !pointerInvalidatedByLoop(LI->getOperand(0), Size, - LI->getMetadata(LLVMContext::MD_tbaa)); + + AAMDNodes AAInfo; + LI->getAAMetadata(AAInfo); + + return !pointerInvalidatedByLoop(LI->getOperand(0), Size, AAInfo); } else if (CallInst *CI = dyn_cast<CallInst>(&I)) { // Don't sink or hoist dbg info; it's legal, but not useful. if (isa<DbgInfoIntrinsic>(I)) @@ -594,8 +607,13 @@ void LICM::sink(Instruction &I) { // PHI nodes in exit blocks due to LCSSA form. Just RAUW them with clones of // the instruction. while (!I.use_empty()) { + Instruction *User = I.user_back(); + if (!DT->isReachableFromEntry(User->getParent())) { + User->replaceUsesOfWith(&I, UndefValue::get(I.getType())); + continue; + } // The user must be a PHI node. 
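// Illustrative sketch, not part of the patch: why the HeaderMayThrow logic
// added above matters. In this standalone C++ loop, the header calls a
// function that may throw before the store, so the store is not guaranteed
// to execute on every entry to the loop and must not be treated as such by
// LICM-style guaranteed-to-execute reasoning.

#include <stdexcept>

bool flagFor(int I) { // opaque callee that can throw
  if (I < 0)
    throw std::range_error("negative");
  return (I % 2) == 0;
}

void writeFlags(int *Out, int N) {
  for (int I = 0; I < N; ++I) {
    bool B = flagFor(I); // may throw: everything after it in the header
    *Out = B;            // is therefore not guaranteed to execute
  }
}

int main() {
  int Out = 0;
  writeFlags(&Out, 4); // flagFor only throws for negative I, so this is fine
  return Out;
}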
- PHINode *PN = cast<PHINode>(I.user_back()); + PHINode *PN = cast<PHINode>(User); BasicBlock *ExitBlock = PN->getParent(); assert(ExitBlockSet.count(ExitBlock) && @@ -647,12 +665,7 @@ bool LICM::isSafeToExecuteUnconditionally(Instruction &Inst) { bool LICM::isGuaranteedToExecute(Instruction &Inst) { - // Somewhere in this loop there is an instruction which may throw and make us - // exit the loop. - if (MayThrow) - return false; - - // Otherwise we have to check to make sure that the instruction dominates all + // We have to check to make sure that the instruction dominates all // of the exit blocks. If it doesn't, then there is a path out of the loop // which does not execute this instruction, so we can't hoist it. @@ -660,7 +673,14 @@ bool LICM::isGuaranteedToExecute(Instruction &Inst) { // common), it is always guaranteed to dominate the exit blocks. Since this // is a common case, and can save some work, check it now. if (Inst.getParent() == CurLoop->getHeader()) - return true; + // If there's a throw in the header block, we can't guarantee we'll reach + // Inst. + return !HeaderMayThrow; + + // Somewhere in this loop there is an instruction which may throw and make us + // exit the loop. + if (MayThrow) + return false; // Get the exit blocks for the current loop. SmallVector<BasicBlock*, 8> ExitBlocks; @@ -682,7 +702,7 @@ bool LICM::isGuaranteedToExecute(Instruction &Inst) { namespace { class LoopPromoter : public LoadAndStorePromoter { Value *SomePtr; // Designated pointer to store to. - SmallPtrSet<Value*, 4> &PointerMustAliases; + SmallPtrSetImpl<Value*> &PointerMustAliases; SmallVectorImpl<BasicBlock*> &LoopExitBlocks; SmallVectorImpl<Instruction*> &LoopInsertPts; PredIteratorCache &PredCache; @@ -690,7 +710,7 @@ namespace { LoopInfo &LI; DebugLoc DL; int Alignment; - MDNode *TBAATag; + AAMDNodes AATags; Value *maybeInsertLCSSAPHI(Value *V, BasicBlock *BB) const { if (Instruction *I = dyn_cast<Instruction>(V)) @@ -710,14 +730,14 @@ namespace { public: LoopPromoter(Value *SP, const SmallVectorImpl<Instruction *> &Insts, - SSAUpdater &S, SmallPtrSet<Value *, 4> &PMA, + SSAUpdater &S, SmallPtrSetImpl<Value *> &PMA, SmallVectorImpl<BasicBlock *> &LEB, SmallVectorImpl<Instruction *> &LIP, PredIteratorCache &PIC, AliasSetTracker &ast, LoopInfo &li, DebugLoc dl, int alignment, - MDNode *TBAATag) + const AAMDNodes &AATags) : LoadAndStorePromoter(Insts, S), SomePtr(SP), PointerMustAliases(PMA), LoopExitBlocks(LEB), LoopInsertPts(LIP), PredCache(PIC), AST(ast), - LI(li), DL(dl), Alignment(alignment), TBAATag(TBAATag) {} + LI(li), DL(dl), Alignment(alignment), AATags(AATags) {} bool isInstInList(Instruction *I, const SmallVectorImpl<Instruction*> &) const override { @@ -743,7 +763,7 @@ namespace { StoreInst *NewSI = new StoreInst(LiveInValue, Ptr, InsertPos); NewSI->setAlignment(Alignment); NewSI->setDebugLoc(DL); - if (TBAATag) NewSI->setMetadata(LLVMContext::MD_tbaa, TBAATag); + if (AATags) NewSI->setAAMetadata(AATags); } } @@ -798,11 +818,12 @@ void LICM::PromoteAliasSet(AliasSet &AS, // We start with an alignment of one and try to find instructions that allow // us to prove better alignment. unsigned Alignment = 1; - MDNode *TBAATag = nullptr; + AAMDNodes AATags; + bool HasDedicatedExits = CurLoop->hasDedicatedExits(); // Check that all of the pointers in the alias set have the same type. We // cannot (yet) promote a memory location that is loaded and stored in - // different sizes. While we are at it, collect alignment and TBAA info. + // different sizes. 
While we are at it, collect alignment and AA info. for (AliasSet::iterator ASI = AS.begin(), E = AS.end(); ASI != E; ++ASI) { Value *ASIV = ASI->getValue(); PointerMustAliases.insert(ASIV); @@ -833,6 +854,13 @@ void LICM::PromoteAliasSet(AliasSet &AS, assert(!store->isVolatile() && "AST broken"); if (!store->isSimple()) return; + // Don't sink stores from loops without dedicated block exits. Exits + // containing indirect branches are not transformed by loop simplify, + // so make sure we catch that. An additional load may be generated in the + // preheader for SSA updater, so also avoid sinking when no preheader + // is available. + if (!HasDedicatedExits || !Preheader) + return; // Note that we only check GuaranteedToExecute inside the store case // so that we do not introduce stores where they did not exist before @@ -855,13 +883,12 @@ void LICM::PromoteAliasSet(AliasSet &AS, } else return; // Not a load or store. - // Merge the TBAA tags. + // Merge the AA tags. if (LoopUses.empty()) { - // On the first load/store, just take its TBAA tag. - TBAATag = UI->getMetadata(LLVMContext::MD_tbaa); - } else if (TBAATag) { - TBAATag = MDNode::getMostGenericTBAA(TBAATag, - UI->getMetadata(LLVMContext::MD_tbaa)); + // On the first load/store, just take its AA tags. + UI->getAAMetadata(AATags); + } else if (AATags) { + UI->getAAMetadata(AATags, /* Merge = */ true); } LoopUses.push_back(UI); @@ -896,7 +923,7 @@ void LICM::PromoteAliasSet(AliasSet &AS, SmallVector<PHINode*, 16> NewPHIs; SSAUpdater SSA(&NewPHIs); LoopPromoter Promoter(SomePtr, LoopUses, SSA, PointerMustAliases, ExitBlocks, - InsertPts, PIC, *CurAST, *LI, DL, Alignment, TBAATag); + InsertPts, PIC, *CurAST, *LI, DL, Alignment, AATags); // Set up the preheader to have a definition of the value. It is the live-out // value from the preheader that uses in the loop will use. @@ -905,7 +932,7 @@ void LICM::PromoteAliasSet(AliasSet &AS, Preheader->getTerminator()); PreheaderLoad->setAlignment(Alignment); PreheaderLoad->setDebugLoc(DL); - if (TBAATag) PreheaderLoad->setMetadata(LLVMContext::MD_tbaa, TBAATag); + if (AATags) PreheaderLoad->setAAMetadata(AATags); SSA.AddAvailableValue(Preheader, PreheaderLoad); // Rewrite all the loads in the loop and remember all the definitions from @@ -936,3 +963,13 @@ void LICM::deleteAnalysisValue(Value *V, Loop *L) { AST->deleteValue(V); } + +/// Simple Analysis hook. Delete loop L from alias set map.
+void LICM::deleteAnalysisLoop(Loop *L) { + AliasSetTracker *AST = LoopToAliasSetMap.lookup(L); + if (!AST) + return; + + delete AST; + LoopToAliasSetMap.erase(L); +} diff --git a/lib/Transforms/Scalar/LLVMBuild.txt b/lib/Transforms/Scalar/LLVMBuild.txt index 1f6df7dac7ff..2bb49a3026c9 100644 --- a/lib/Transforms/Scalar/LLVMBuild.txt +++ b/lib/Transforms/Scalar/LLVMBuild.txt @@ -20,4 +20,4 @@ type = Library name = Scalar parent = Transforms library_name = ScalarOpts -required_libraries = Analysis Core IPA InstCombine Support Target TransformUtils +required_libraries = Analysis Core InstCombine ProfileData Support Target TransformUtils diff --git a/lib/Transforms/Scalar/LoadCombine.cpp b/lib/Transforms/Scalar/LoadCombine.cpp index 846aa703c9c3..11e4d7606d96 100644 --- a/lib/Transforms/Scalar/LoadCombine.cpp +++ b/lib/Transforms/Scalar/LoadCombine.cpp @@ -15,6 +15,8 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AliasSetTracker.h" #include "llvm/Analysis/TargetFolder.h" #include "llvm/Pass.h" #include "llvm/IR/DataLayout.h" @@ -51,13 +53,16 @@ struct LoadPOPPair { class LoadCombine : public BasicBlockPass { LLVMContext *C; const DataLayout *DL; + AliasAnalysis *AA; public: LoadCombine() : BasicBlockPass(ID), - C(nullptr), DL(nullptr) { + C(nullptr), DL(nullptr), AA(nullptr) { initializeSROAPass(*PassRegistry::getPassRegistry()); } + + using llvm::Pass::doInitialization; bool doInitialization(Function &) override; bool runOnBasicBlock(BasicBlock &BB) override; void getAnalysisUsage(AnalysisUsage &AU) const override; @@ -223,19 +228,23 @@ bool LoadCombine::runOnBasicBlock(BasicBlock &BB) { if (skipOptnoneFunction(BB) || !DL) return false; + AA = &getAnalysis<AliasAnalysis>(); + IRBuilder<true, TargetFolder> TheBuilder(BB.getContext(), TargetFolder(DL)); Builder = &TheBuilder; DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> LoadMap; + AliasSetTracker AST(*AA); bool Combined = false; unsigned Index = 0; for (auto &I : BB) { - if (I.mayWriteToMemory() || I.mayThrow()) { + if (I.mayThrow() || (I.mayWriteToMemory() && AST.containsUnknown(&I))) { if (combineLoads(LoadMap)) Combined = true; LoadMap.clear(); + AST.clear(); continue; } LoadInst *LI = dyn_cast<LoadInst>(&I); @@ -248,6 +257,7 @@ bool LoadCombine::runOnBasicBlock(BasicBlock &BB) { if (!POP.Pointer) continue; LoadMap[POP.Pointer].push_back(LoadPOPPair(LI, POP, Index++)); + AST.add(LI); } if (combineLoads(LoadMap)) Combined = true; @@ -256,6 +266,9 @@ bool LoadCombine::runOnBasicBlock(BasicBlock &BB) { void LoadCombine::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesCFG(); + + AU.addRequired<AliasAnalysis>(); + AU.addPreserved<AliasAnalysis>(); } char LoadCombine::ID = 0; @@ -264,5 +277,9 @@ BasicBlockPass *llvm::createLoadCombinePass() { return new LoadCombine(); } -INITIALIZE_PASS(LoadCombine, "load-combine", "Combine Adjacent Loads", false, - false) +INITIALIZE_PASS_BEGIN(LoadCombine, "load-combine", "Combine Adjacent Loads", + false, false) +INITIALIZE_AG_DEPENDENCY(AliasAnalysis) +INITIALIZE_PASS_END(LoadCombine, "load-combine", "Combine Adjacent Loads", + false, false) + diff --git a/lib/Transforms/Scalar/LoopDeletion.cpp b/lib/Transforms/Scalar/LoopDeletion.cpp index 5ab686aa831a..1d1f33ae6183 100644 --- a/lib/Transforms/Scalar/LoopDeletion.cpp +++ b/lib/Transforms/Scalar/LoopDeletion.cpp @@ -239,9 +239,8 @@ bool LoopDeletion::runOnLoop(Loop *L, LPPassManager &LPM) { LoopInfo &loopInfo = getAnalysis<LoopInfo>(); 
SmallPtrSet<BasicBlock*, 8> blocks; blocks.insert(L->block_begin(), L->block_end()); - for (SmallPtrSet<BasicBlock*,8>::iterator I = blocks.begin(), - E = blocks.end(); I != E; ++I) - loopInfo.removeBlock(*I); + for (BasicBlock *BB : blocks) + loopInfo.removeBlock(BB); // The last step is to inform the loop pass manager that we've // eliminated this loop. diff --git a/lib/Transforms/Scalar/LoopInstSimplify.cpp b/lib/Transforms/Scalar/LoopInstSimplify.cpp index ab1a9393c526..1ac38e0f52a5 100644 --- a/lib/Transforms/Scalar/LoopInstSimplify.cpp +++ b/lib/Transforms/Scalar/LoopInstSimplify.cpp @@ -14,6 +14,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" @@ -41,6 +42,7 @@ namespace { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); + AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<LoopInfo>(); AU.addRequiredID(LoopSimplifyID); AU.addPreservedID(LoopSimplifyID); @@ -54,6 +56,7 @@ namespace { char LoopInstSimplify::ID = 0; INITIALIZE_PASS_BEGIN(LoopInstSimplify, "loop-instsimplify", "Simplify instructions in loops", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfo) @@ -76,6 +79,8 @@ bool LoopInstSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); const DataLayout *DL = DLP ? &DLP->getDataLayout() : nullptr; const TargetLibraryInfo *TLI = &getAnalysis<TargetLibraryInfo>(); + auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache( + *L->getHeader()->getParent()); SmallVector<BasicBlock*, 8> ExitBlocks; L->getUniqueExitBlocks(ExitBlocks); @@ -116,7 +121,7 @@ bool LoopInstSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { // Don't bother simplifying unused instructions. if (!I->use_empty()) { - Value *V = SimplifyInstruction(I, DL, TLI, DT); + Value *V = SimplifyInstruction(I, DL, TLI, DT, &AC); if (V && LI->replacementPreservesLCSSAForm(I, V)) { // Mark all uses for resimplification next time round the loop. for (User *U : I->users()) @@ -148,7 +153,7 @@ bool LoopInstSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE; ++SI) { BasicBlock *SuccBB = *SI; - if (!Visited.insert(SuccBB)) + if (!Visited.insert(SuccBB).second) continue; const Loop *SuccLoop = LI->getLoopFor(SuccBB); @@ -161,7 +166,7 @@ bool LoopInstSimplify::runOnLoop(Loop *L, LPPassManager &LPM) { for (unsigned i = 0; i < SubLoopExitBlocks.size(); ++i) { BasicBlock *ExitBB = SubLoopExitBlocks[i]; - if (LI->getLoopFor(ExitBB) == L && Visited.insert(ExitBB)) + if (LI->getLoopFor(ExitBB) == L && Visited.insert(ExitBB).second) VisitStack.push_back(WorklistItem(ExitBB, false)); } diff --git a/lib/Transforms/Scalar/LoopRerollPass.cpp b/lib/Transforms/Scalar/LoopRerollPass.cpp index b6fbb16166dd..8f122041c248 100644 --- a/lib/Transforms/Scalar/LoopRerollPass.cpp +++ b/lib/Transforms/Scalar/LoopRerollPass.cpp @@ -215,9 +215,7 @@ protected: typedef SmallVector<SimpleLoopReduction, 16> SmallReductionVector; // Add a new possible reduction. 
- void addSLR(SimpleLoopReduction &SLR) { - PossibleReds.push_back(SLR); - } + void addSLR(SimpleLoopReduction &SLR) { PossibleReds.push_back(SLR); } // Setup to track possible reductions corresponding to the provided // rerolling scale. Only reductions with a number of non-PHI instructions @@ -225,7 +223,8 @@ protected: // are filled in: // - A set of all possible instructions in eligible reductions. // - A set of all PHIs in eligible reductions - // - A set of all reduced values (last instructions) in eligible reductions. + // - A set of all reduced values (last instructions) in eligible + // reductions. void restrictToScale(uint64_t Scale, SmallInstructionSet &PossibleRedSet, SmallInstructionSet &PossibleRedPHISet, @@ -238,13 +237,12 @@ protected: if (PossibleReds[i].size() % Scale == 0) { PossibleRedLastSet.insert(PossibleReds[i].getReducedValue()); PossibleRedPHISet.insert(PossibleReds[i].getPHI()); - + PossibleRedSet.insert(PossibleReds[i].getPHI()); PossibleRedIdx[PossibleReds[i].getPHI()] = i; - for (SimpleLoopReduction::iterator J = PossibleReds[i].begin(), - JE = PossibleReds[i].end(); J != JE; ++J) { - PossibleRedSet.insert(*J); - PossibleRedIdx[*J] = i; + for (Instruction *J : PossibleReds[i]) { + PossibleRedSet.insert(J); + PossibleRedIdx[J] = i; } } } @@ -487,7 +485,7 @@ void LoopReroll::collectInLoopUserSet(Loop *L, if (PN->getIncomingBlock(U) == L->getHeader()) continue; } - + if (L->contains(User) && !Exclude.count(User)) { Queue.push_back(User); } @@ -659,16 +657,15 @@ bool LoopReroll::ReductionTracker::validateSelected() { RI != RIE; ++RI) { int i = *RI; int PrevIter = 0, BaseCount = 0, Count = 0; - for (SimpleLoopReduction::iterator J = PossibleReds[i].begin(), - JE = PossibleReds[i].end(); J != JE; ++J) { - // Note that all instructions in the chain must have been found because - // all instructions in the function must have been assigned to some - // iteration. - int Iter = PossibleRedIter[*J]; + for (Instruction *J : PossibleReds[i]) { + // Note that all instructions in the chain must have been found because + // all instructions in the function must have been assigned to some + // iteration. + int Iter = PossibleRedIter[J]; if (Iter != PrevIter && Iter != PrevIter + 1 && !PossibleReds[i].getReducedValue()->isAssociative()) { DEBUG(dbgs() << "LRR: Out-of-order non-associative reduction: " << - *J << "\n"); + J << "\n"); return false; } @@ -881,7 +878,7 @@ bool LoopReroll::reroll(Instruction *IV, Loop *L, BasicBlock *Header, // needed because otherwise isSafeToSpeculativelyExecute returns // false on PHI nodes. if (!isSimpleLoadStore(J2) && !isSafeToSpeculativelyExecute(J2, DL)) - FutureSideEffects = true; + FutureSideEffects = true; } ++J2; @@ -952,9 +949,9 @@ bool LoopReroll::reroll(Instruction *IV, Loop *L, BasicBlock *Header, for (unsigned j = 0; j < J1->getNumOperands() && !MatchFailed; ++j) { Value *Op2 = J2->getOperand(j); - // If this is part of a reduction (and the operation is not - // associatve), then we match all operands, but not those that are - // part of the reduction. + // If this is part of a reduction (and the operation is not + // associative), then we match all operands, but not those that are + // part of the reduction. if (InReduction) if (Instruction *Op2I = dyn_cast<Instruction>(Op2)) if (Reductions.isPairInSame(J2, Op2I)) @@ -968,11 +965,11 @@ bool LoopReroll::reroll(Instruction *IV, Loop *L, BasicBlock *Header, Op2 = IV; if (J1->getOperand(Swapped ?
unsigned(!j) : j) != Op2) { - // If we've not already decided to swap the matched operands, and - // we've not already matched our first operand (note that we could - // have skipped matching the first operand because it is part of a - // reduction above), and the instruction is commutative, then try - // the swapped match. + // If we've not already decided to swap the matched operands, and + // we've not already matched our first operand (note that we could + // have skipped matching the first operand because it is part of a + // reduction above), and the instruction is commutative, then try + // the swapped match. if (!Swapped && J1->isCommutative() && !SomeOpMatched && J1->getOperand(!j) == Op2) { Swapped = true; @@ -1069,7 +1066,7 @@ bool LoopReroll::reroll(Instruction *IV, Loop *L, BasicBlock *Header, continue; } - ++J; + ++J; } // Insert the new induction variable. @@ -1110,9 +1107,9 @@ bool LoopReroll::reroll(Instruction *IV, Loop *L, BasicBlock *Header, ICMinus1 = Expander.expandCodeFor(ICMinus1SCEV, NewIV->getType(), Preheader->getTerminator()); } - - Value *Cond = new ICmpInst(BI, CmpInst::ICMP_EQ, NewIV, ICMinus1, - "exitcond"); + + Value *Cond = + new ICmpInst(BI, CmpInst::ICMP_EQ, NewIV, ICMinus1, "exitcond"); BI->setCondition(Cond); if (BI->getSuccessor(1) != Header) @@ -1182,4 +1179,3 @@ bool LoopReroll::runOnLoop(Loop *L, LPPassManager &LPM) { return Changed; } - diff --git a/lib/Transforms/Scalar/LoopRotation.cpp b/lib/Transforms/Scalar/LoopRotation.cpp index 2ce58314f8ef..9164be224654 100644 --- a/lib/Transforms/Scalar/LoopRotation.cpp +++ b/lib/Transforms/Scalar/LoopRotation.cpp @@ -13,6 +13,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopPass.h" @@ -53,6 +54,7 @@ namespace { // LCSSA form makes instruction renaming easier. void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); AU.addPreserved<DominatorTreeWrapperPass>(); AU.addRequired<LoopInfo>(); AU.addPreserved<LoopInfo>(); @@ -72,12 +74,14 @@ namespace { unsigned MaxHeaderSize; LoopInfo *LI; const TargetTransformInfo *TTI; + AssumptionCache *AC; }; } char LoopRotate::ID = 0; INITIALIZE_PASS_BEGIN(LoopRotate, "loop-rotate", "Rotate Loops", false, false) INITIALIZE_AG_DEPENDENCY(TargetTransformInfo) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(LoopInfo) INITIALIZE_PASS_DEPENDENCY(LoopSimplify) INITIALIZE_PASS_DEPENDENCY(LCSSA) @@ -98,6 +102,8 @@ bool LoopRotate::runOnLoop(Loop *L, LPPassManager &LPM) { LI = &getAnalysis<LoopInfo>(); TTI = &getAnalysis<TargetTransformInfo>(); + AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache( + *L->getHeader()->getParent()); // Simplify the loop latch before attempting to rotate the header // upward. Rotation may not be needed if the loop tail can be folded into the @@ -184,13 +190,18 @@ static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader, } } -/// Determine whether the instructions in this range my be safely and cheaply +/// Determine whether the instructions in this range may be safely and cheaply /// speculated. This is not an important enough situation to develop complex /// heuristics. We handle a single arithmetic instruction along with any type /// conversions. 
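// Illustrative sketch, not part of the patch: the shape of loop that
// simplifyLoopLatch targets. The latch only post-increments the induction
// variable, so folding the latch into the exit amounts to speculating one
// extra "I += Step" past the final iteration, which is safe for
// side-effect-free arithmetic like this. The multi-exit check added above
// declines when the increment's operand also has uses outside the loop,
// since speculation could then extend a live range.

int sumEveryStep(const int *A, int N, int Step) {
  int S = 0;
  int I = 0;
  while (I < N) { // header: compare and body
    S += A[I];
    I += Step;    // latch: a single increment, cheap to speculate
  }
  return S;
}

int main() {
  int A[] = {1, 2, 3, 4};
  return sumEveryStep(A, 4, 2); // visits A[0] and A[2], returns 4
}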
static bool shouldSpeculateInstrs(BasicBlock::iterator Begin, - BasicBlock::iterator End) { + BasicBlock::iterator End, Loop *L) { bool seenIncrement = false; + bool MultiExitLoop = false; + + if (!L->getExitingBlock()) + MultiExitLoop = true; + for (BasicBlock::iterator I = Begin; I != End; ++I) { if (!isSafeToSpeculativelyExecute(I)) @@ -214,11 +225,33 @@ static bool shouldSpeculateInstrs(BasicBlock::iterator Begin, case Instruction::Xor: case Instruction::Shl: case Instruction::LShr: - case Instruction::AShr: + case Instruction::AShr: { + Value *IVOpnd = nullptr; + if (isa<ConstantInt>(I->getOperand(0))) + IVOpnd = I->getOperand(1); + + if (isa<ConstantInt>(I->getOperand(1))) { + if (IVOpnd) + return false; + + IVOpnd = I->getOperand(0); + } + + // If increment operand is used outside of the loop, this speculation + // could cause extra live range interference. + if (MultiExitLoop && IVOpnd) { + for (User *UseI : IVOpnd->users()) { + auto *UserInst = cast<Instruction>(UseI); + if (!L->contains(UserInst)) + return false; + } + } + if (seenIncrement) return false; seenIncrement = true; break; + } case Instruction::Trunc: case Instruction::ZExt: case Instruction::SExt: @@ -232,7 +265,7 @@ static bool shouldSpeculateInstrs(BasicBlock::iterator Begin, /// Fold the loop tail into the loop exit by speculating the loop tail /// instructions. Typically, this is a single post-increment. In the case of a /// simple 2-block loop, hoisting the increment can be much better than -/// duplicating the entire loop header. In the cast of loops with early exits, +/// duplicating the entire loop header. In the case of loops with early exits, /// rotation will not work anyway, but simplifyLoopLatch will put the loop in /// canonical form so downstream passes can handle it. /// @@ -254,7 +287,7 @@ bool LoopRotate::simplifyLoopLatch(Loop *L) { if (!BI) return false; - if (!shouldSpeculateInstrs(Latch->begin(), Jmp)) + if (!shouldSpeculateInstrs(Latch->begin(), Jmp, L)) return false; DEBUG(dbgs() << "Folding loop latch " << Latch->getName() << " into " @@ -323,8 +356,11 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // Check size of original header and reject loop if it is very big or we can't // duplicate blocks inside it. { + SmallPtrSet<const Value *, 32> EphValues; + CodeMetrics::collectEphemeralValues(L, AC, EphValues); + CodeMetrics Metrics; - Metrics.analyzeBasicBlock(OrigHeader, *TTI); + Metrics.analyzeBasicBlock(OrigHeader, *TTI, EphValues); if (Metrics.notDuplicatable) { DEBUG(dbgs() << "LoopRotation: NOT rotating - contains non-duplicatable" << " instructions: "; L->dump()); @@ -406,6 +442,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // With the operands remapped, see if the instruction constant folds or is // otherwise simplifyable. This commonly occurs because the entry from PHI // nodes allows icmps and other instructions to fold. + // FIXME: Provide DL, TLI, DT, AC to SimplifyInstruction. 
Value *V = SimplifyInstruction(C); if (V && LI->replacementPreservesLCSSAForm(C, V)) { // If so, then delete the temporary instruction and stick the folded value diff --git a/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 914b56aa8167..7b60373dc508 100644 --- a/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -744,7 +744,7 @@ static bool isExistingPhi(const SCEVAddRecExpr *AR, ScalarEvolution &SE) { /// TODO: Allow UDivExpr if we can find an existing IV increment that is an /// obvious multiple of the UDivExpr. static bool isHighCostExpansion(const SCEV *S, - SmallPtrSet<const SCEV*, 8> &Processed, + SmallPtrSetImpl<const SCEV*> &Processed, ScalarEvolution &SE) { // Zero/One operand expressions switch (S->getSCEVType()) { @@ -762,7 +762,7 @@ static bool isHighCostExpansion(const SCEV *S, Processed, SE); } - if (!Processed.insert(S)) + if (!Processed.insert(S).second) return false; if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) { @@ -892,34 +892,34 @@ public: void RateFormula(const TargetTransformInfo &TTI, const Formula &F, - SmallPtrSet<const SCEV *, 16> &Regs, + SmallPtrSetImpl<const SCEV *> &Regs, const DenseSet<const SCEV *> &VisitedRegs, const Loop *L, const SmallVectorImpl<int64_t> &Offsets, ScalarEvolution &SE, DominatorTree &DT, const LSRUse &LU, - SmallPtrSet<const SCEV *, 16> *LoserRegs = nullptr); + SmallPtrSetImpl<const SCEV *> *LoserRegs = nullptr); void print(raw_ostream &OS) const; void dump() const; private: void RateRegister(const SCEV *Reg, - SmallPtrSet<const SCEV *, 16> &Regs, + SmallPtrSetImpl<const SCEV *> &Regs, const Loop *L, ScalarEvolution &SE, DominatorTree &DT); void RatePrimaryRegister(const SCEV *Reg, - SmallPtrSet<const SCEV *, 16> &Regs, + SmallPtrSetImpl<const SCEV *> &Regs, const Loop *L, ScalarEvolution &SE, DominatorTree &DT, - SmallPtrSet<const SCEV *, 16> *LoserRegs); + SmallPtrSetImpl<const SCEV *> *LoserRegs); }; } /// RateRegister - Tally up interesting quantities from the given register. void Cost::RateRegister(const SCEV *Reg, - SmallPtrSet<const SCEV *, 16> &Regs, + SmallPtrSetImpl<const SCEV *> &Regs, const Loop *L, ScalarEvolution &SE, DominatorTree &DT) { if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Reg)) { @@ -967,15 +967,15 @@ void Cost::RateRegister(const SCEV *Reg, /// before, rate it. Optional LoserRegs provides a way to declare any formula /// that refers to one of those regs an instant loser. 
void Cost::RatePrimaryRegister(const SCEV *Reg, - SmallPtrSet<const SCEV *, 16> &Regs, + SmallPtrSetImpl<const SCEV *> &Regs, const Loop *L, ScalarEvolution &SE, DominatorTree &DT, - SmallPtrSet<const SCEV *, 16> *LoserRegs) { + SmallPtrSetImpl<const SCEV *> *LoserRegs) { if (LoserRegs && LoserRegs->count(Reg)) { Lose(); return; } - if (Regs.insert(Reg)) { + if (Regs.insert(Reg).second) { RateRegister(Reg, Regs, L, SE, DT); if (LoserRegs && isLoser()) LoserRegs->insert(Reg); @@ -984,13 +984,13 @@ void Cost::RatePrimaryRegister(const SCEV *Reg, void Cost::RateFormula(const TargetTransformInfo &TTI, const Formula &F, - SmallPtrSet<const SCEV *, 16> &Regs, + SmallPtrSetImpl<const SCEV *> &Regs, const DenseSet<const SCEV *> &VisitedRegs, const Loop *L, const SmallVectorImpl<int64_t> &Offsets, ScalarEvolution &SE, DominatorTree &DT, const LSRUse &LU, - SmallPtrSet<const SCEV *, 16> *LoserRegs) { + SmallPtrSetImpl<const SCEV *> *LoserRegs) { assert(F.isCanonical() && "Cost is accurate only for canonical formula"); // Tally up the registers. if (const SCEV *ScaledReg = F.ScaledReg) { @@ -1337,10 +1337,9 @@ void LSRUse::RecomputeRegs(size_t LUIdx, RegUseTracker &RegUses) { } // Update the RegTracker. - for (SmallPtrSet<const SCEV *, 4>::iterator I = OldRegs.begin(), - E = OldRegs.end(); I != E; ++I) - if (!Regs.count(*I)) - RegUses.DropRegister(*I, LUIdx); + for (const SCEV *S : OldRegs) + if (!Regs.count(S)) + RegUses.DropRegister(S, LUIdx); } void LSRUse::print(raw_ostream &OS) const { @@ -2226,13 +2225,12 @@ LSRInstance::OptimizeLoopTermCond() { // must dominate all the post-inc comparisons we just set up, and it must // dominate the loop latch edge. IVIncInsertPos = L->getLoopLatch()->getTerminator(); - for (SmallPtrSet<Instruction *, 4>::const_iterator I = PostIncs.begin(), - E = PostIncs.end(); I != E; ++I) { + for (Instruction *Inst : PostIncs) { BasicBlock *BB = DT.findNearestCommonDominator(IVIncInsertPos->getParent(), - (*I)->getParent()); - if (BB == (*I)->getParent()) - IVIncInsertPos = *I; + Inst->getParent()); + if (BB == Inst->getParent()) + IVIncInsertPos = Inst; else if (BB != IVIncInsertPos->getParent()) IVIncInsertPos = BB->getTerminator(); } @@ -2557,7 +2555,7 @@ bool IVChain::isProfitableIncrement(const SCEV *OperExpr, /// /// TODO: Consider IVInc free if it's already used in another chains. 
static bool -isProfitableChain(IVChain &Chain, SmallPtrSet<Instruction*, 4> &Users, +isProfitableChain(IVChain &Chain, SmallPtrSetImpl<Instruction*> &Users, ScalarEvolution &SE, const TargetTransformInfo &TTI) { if (StressIVChain) return true; @@ -2567,9 +2565,8 @@ isProfitableChain(IVChain &Chain, SmallPtrSet<Instruction*, 4> &Users, if (!Users.empty()) { DEBUG(dbgs() << "Chain: " << *Chain.Incs[0].UserInst << " users:\n"; - for (SmallPtrSet<Instruction*, 4>::const_iterator I = Users.begin(), - E = Users.end(); I != E; ++I) { - dbgs() << " " << **I << "\n"; + for (Instruction *Inst : Users) { + dbgs() << " " << *Inst << "\n"; }); return false; } @@ -2805,7 +2802,7 @@ void LSRInstance::CollectChains() { User::op_iterator IVOpIter = findIVOperand(I->op_begin(), IVOpEnd, L, SE); while (IVOpIter != IVOpEnd) { Instruction *IVOpInst = cast<Instruction>(*IVOpIter); - if (UniqueOperands.insert(IVOpInst)) + if (UniqueOperands.insert(IVOpInst).second) ChainInstruction(I, IVOpInst, ChainUsersVec); IVOpIter = findIVOperand(std::next(IVOpIter), IVOpEnd, L, SE); } @@ -3119,11 +3116,15 @@ bool LSRInstance::InsertFormula(LSRUse &LU, unsigned LUIdx, const Formula &F) { void LSRInstance::CollectLoopInvariantFixupsAndFormulae() { SmallVector<const SCEV *, 8> Worklist(RegUses.begin(), RegUses.end()); - SmallPtrSet<const SCEV *, 8> Inserted; + SmallPtrSet<const SCEV *, 32> Visited; while (!Worklist.empty()) { const SCEV *S = Worklist.pop_back_val(); + // Don't process the same SCEV twice + if (!Visited.insert(S).second) + continue; + if (const SCEVNAryExpr *N = dyn_cast<SCEVNAryExpr>(S)) Worklist.append(N->op_begin(), N->op_end()); else if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(S)) @@ -3132,7 +3133,6 @@ LSRInstance::CollectLoopInvariantFixupsAndFormulae() { Worklist.push_back(D->getLHS()); Worklist.push_back(D->getRHS()); } else if (const SCEVUnknown *US = dyn_cast<SCEVUnknown>(S)) { - if (!Inserted.insert(US)) continue; const Value *V = US->getValue(); if (const Instruction *Inst = dyn_cast<Instruction>(V)) { // Look for instructions defined outside the loop. @@ -3774,7 +3774,7 @@ void LSRInstance::GenerateCrossUseConstantOffsets() { for (int LUIdx = UsedByIndices.find_first(); LUIdx != -1; LUIdx = UsedByIndices.find_next(LUIdx)) // Make a memo of this use, offset, and register tuple. - if (UniqueItems.insert(std::make_pair(LUIdx, Imm))) + if (UniqueItems.insert(std::make_pair(LUIdx, Imm)).second) WorkItems.push_back(WorkItem(LUIdx, Imm, OrigReg)); } } @@ -4302,10 +4302,9 @@ void LSRInstance::SolveRecurse(SmallVectorImpl<const Formula *> &Solution, // reference that register in order to be considered. This prunes out // unprofitable searching. 
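// Illustrative sketch, not part of the patch: the two idioms this change
// applies across many hunks. Helpers now take SmallPtrSetImpl<T> & so callers
// may pass a SmallPtrSet of any inline size, and insert() now returns a
// std::pair<iterator, bool>, so the old "if (!Set.insert(X))" becomes
// "if (!Set.insert(X).second)". Assumes an LLVM checkout on the include path
// for the ADT header.

#include "llvm/ADT/SmallPtrSet.h"

bool insertOnce(llvm::SmallPtrSetImpl<int *> &Seen, int *P) {
  return Seen.insert(P).second; // true only on the first insertion
}

int main() {
  llvm::SmallPtrSet<int *, 8> Seen; // any inline size works for insertOnce
  int X = 0;
  bool First = insertOnce(Seen, &X); // true
  bool Again = insertOnce(Seen, &X); // false
  return (First && !Again) ? 0 : 1;
}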
SmallSetVector<const SCEV *, 4> ReqRegs; - for (SmallPtrSet<const SCEV *, 16>::const_iterator I = CurRegs.begin(), - E = CurRegs.end(); I != E; ++I) - if (LU.Regs.count(*I)) - ReqRegs.insert(*I); + for (const SCEV *S : CurRegs) + if (LU.Regs.count(S)) + ReqRegs.insert(S); SmallPtrSet<const SCEV *, 16> NewRegs; Cost NewCost; @@ -4350,9 +4349,8 @@ void LSRInstance::SolveRecurse(SmallVectorImpl<const Formula *> &Solution, } else { DEBUG(dbgs() << "New best at "; NewCost.print(dbgs()); dbgs() << ".\n Regs:"; - for (SmallPtrSet<const SCEV *, 16>::const_iterator - I = NewRegs.begin(), E = NewRegs.end(); I != E; ++I) - dbgs() << ' ' << **I; + for (const SCEV *S : NewRegs) + dbgs() << ' ' << *S; dbgs() << '\n'); SolutionCost = NewCost; diff --git a/lib/Transforms/Scalar/LoopUnrollPass.cpp b/lib/Transforms/Scalar/LoopUnrollPass.cpp index 935f289f040f..fef52107f623 100644 --- a/lib/Transforms/Scalar/LoopUnrollPass.cpp +++ b/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -13,7 +13,9 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/FunctionTargetTransformInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -52,7 +54,7 @@ UnrollRuntime("unroll-runtime", cl::ZeroOrMore, cl::init(false), cl::Hidden, static cl::opt<unsigned> PragmaUnrollThreshold("pragma-unroll-threshold", cl::init(16 * 1024), cl::Hidden, - cl::desc("Unrolled size limit for loops with an unroll(enable) or " + cl::desc("Unrolled size limit for loops with an unroll(full) or " "unroll_count pragma.")); namespace { @@ -102,6 +104,7 @@ namespace { /// loop preheaders be inserted into the CFG... /// void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<LoopInfo>(); AU.addPreserved<LoopInfo>(); AU.addRequiredID(LoopSimplifyID); @@ -111,6 +114,7 @@ namespace { AU.addRequired<ScalarEvolution>(); AU.addPreserved<ScalarEvolution>(); AU.addRequired<TargetTransformInfo>(); + AU.addRequired<FunctionTargetTransformInfo>(); // FIXME: Loop unroll requires LCSSA. And LCSSA requires dom info. // If loop unroll does not preserve dom info then LCSSA pass on next // loop will receive invalid dom info. @@ -120,7 +124,7 @@ namespace { // Fill in the UnrollingPreferences parameter with values from the // TargetTransformationInfo. - void getUnrollingPreferences(Loop *L, const TargetTransformInfo &TTI, + void getUnrollingPreferences(Loop *L, const FunctionTargetTransformInfo &FTTI, TargetTransformInfo::UnrollingPreferences &UP) { UP.Threshold = CurrentThreshold; UP.OptSizeThreshold = OptSizeUnrollThreshold; @@ -130,7 +134,7 @@ namespace { UP.MaxCount = UINT_MAX; UP.Partial = CurrentAllowPartial; UP.Runtime = CurrentRuntime; - TTI.getUnrollingPreferences(L, UP); + FTTI.getUnrollingPreferences(L, UP); } // Select and return an unroll count based on parameters from @@ -138,12 +142,11 @@ namespace { // SetExplicitly is set to true if the unroll count is is set by // the user or a pragma rather than selected heuristically. 
unsigned - selectUnrollCount(const Loop *L, unsigned TripCount, bool HasEnablePragma, + selectUnrollCount(const Loop *L, unsigned TripCount, bool PragmaFullUnroll, unsigned PragmaCount, const TargetTransformInfo::UnrollingPreferences &UP, bool &SetExplicitly); - // Select threshold values used to limit unrolling based on a // total unrolled size. Parameters Threshold and PartialThreshold // are set to the maximum unrolled size for fully and partially @@ -183,6 +186,8 @@ namespace { char LoopUnroll::ID = 0; INITIALIZE_PASS_BEGIN(LoopUnroll, "loop-unroll", "Unroll loops", false, false) INITIALIZE_AG_DEPENDENCY(TargetTransformInfo) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(FunctionTargetTransformInfo) INITIALIZE_PASS_DEPENDENCY(LoopInfo) INITIALIZE_PASS_DEPENDENCY(LoopSimplify) INITIALIZE_PASS_DEPENDENCY(LCSSA) @@ -201,11 +206,15 @@ Pass *llvm::createSimpleLoopUnrollPass() { /// ApproximateLoopSize - Approximate the size of the loop. static unsigned ApproximateLoopSize(const Loop *L, unsigned &NumCalls, bool &NotDuplicatable, - const TargetTransformInfo &TTI) { + const TargetTransformInfo &TTI, + AssumptionCache *AC) { + SmallPtrSet<const Value *, 32> EphValues; + CodeMetrics::collectEphemeralValues(L, AC, EphValues); + CodeMetrics Metrics; for (Loop::block_iterator I = L->block_begin(), E = L->block_end(); I != E; ++I) - Metrics.analyzeBasicBlock(*I, TTI); + Metrics.analyzeBasicBlock(*I, TTI, EphValues); NumCalls = Metrics.NumInlineCandidates; NotDuplicatable = Metrics.notDuplicatable; @@ -213,19 +222,22 @@ static unsigned ApproximateLoopSize(const Loop *L, unsigned &NumCalls, // Don't allow an estimate of size zero. This would allows unrolling of loops // with huge iteration counts, which is a compile time problem even if it's - // not a problem for code quality. - if (LoopSize == 0) LoopSize = 1; + // not a problem for code quality. Also, the code using this size may assume + // that each loop has at least three instructions (likely a conditional + // branch, a comparison feeding that branch, and some kind of loop increment + // feeding that comparison instruction). + LoopSize = std::max(LoopSize, 3u); return LoopSize; } -// Returns the value associated with the given metadata node name (for -// example, "llvm.loop.unroll.count"). If no such named metadata node -// exists, then nullptr is returned. -static const ConstantInt *GetUnrollMetadataValue(const Loop *L, - StringRef Name) { +// Returns the loop hint metadata node with the given name (for example, +// "llvm.loop.unroll.count"). If no such metadata node exists, then nullptr is +// returned. +static const MDNode *GetUnrollMetadata(const Loop *L, StringRef Name) { MDNode *LoopID = L->getLoopID(); - if (!LoopID) return nullptr; + if (!LoopID) + return nullptr; // First operand should refer to the loop id itself. 
assert(LoopID->getNumOperands() > 0 && "requires at least one operand"); @@ -233,41 +245,38 @@ static const ConstantInt *GetUnrollMetadataValue(const Loop *L, for (unsigned i = 1, e = LoopID->getNumOperands(); i < e; ++i) { const MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i)); - if (!MD) continue; + if (!MD) + continue; const MDString *S = dyn_cast<MDString>(MD->getOperand(0)); - if (!S) continue; + if (!S) + continue; - if (Name.equals(S->getString())) { - assert(MD->getNumOperands() == 2 && - "Unroll hint metadata should have two operands."); - return cast<ConstantInt>(MD->getOperand(1)); - } + if (Name.equals(S->getString())) + return MD; } return nullptr; } -// Returns true if the loop has an unroll(enable) pragma. -static bool HasUnrollEnablePragma(const Loop *L) { - const ConstantInt *EnableValue = - GetUnrollMetadataValue(L, "llvm.loop.unroll.enable"); - return (EnableValue && EnableValue->getZExtValue()); +// Returns true if the loop has an unroll(full) pragma. +static bool HasUnrollFullPragma(const Loop *L) { + return GetUnrollMetadata(L, "llvm.loop.unroll.full"); } // Returns true if the loop has an unroll(disable) pragma. static bool HasUnrollDisablePragma(const Loop *L) { - const ConstantInt *EnableValue = - GetUnrollMetadataValue(L, "llvm.loop.unroll.enable"); - return (EnableValue && !EnableValue->getZExtValue()); + return GetUnrollMetadata(L, "llvm.loop.unroll.disable"); } // If loop has an unroll_count pragma return the (necessarily // positive) value from the pragma. Otherwise return 0. static unsigned UnrollCountPragmaValue(const Loop *L) { - const ConstantInt *CountValue = - GetUnrollMetadataValue(L, "llvm.loop.unroll.count"); - if (CountValue) { - unsigned Count = CountValue->getZExtValue(); + const MDNode *MD = GetUnrollMetadata(L, "llvm.loop.unroll.count"); + if (MD) { + assert(MD->getNumOperands() == 2 && + "Unroll count hint metadata should have two operands."); + unsigned Count = + mdconst::extract<ConstantInt>(MD->getOperand(1))->getZExtValue(); assert(Count >= 1 && "Unroll count must be positive."); return Count; } @@ -283,9 +292,9 @@ static void SetLoopAlreadyUnrolled(Loop *L) { if (!LoopID) return; // First remove any existing loop unrolling metadata. - SmallVector<Value *, 4> Vals; + SmallVector<Metadata *, 4> MDs; // Reserve first location for self reference to the LoopID metadata node. - Vals.push_back(nullptr); + MDs.push_back(nullptr); for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { bool IsUnrollMetadata = false; MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i)); @@ -293,26 +302,25 @@ static void SetLoopAlreadyUnrolled(Loop *L) { const MDString *S = dyn_cast<MDString>(MD->getOperand(0)); IsUnrollMetadata = S && S->getString().startswith("llvm.loop.unroll."); } - if (!IsUnrollMetadata) Vals.push_back(LoopID->getOperand(i)); + if (!IsUnrollMetadata) + MDs.push_back(LoopID->getOperand(i)); } // Add unroll(disable) metadata to disable future unrolling. 
LLVMContext &Context = L->getHeader()->getContext(); - SmallVector<Value *, 2> DisableOperands; - DisableOperands.push_back(MDString::get(Context, "llvm.loop.unroll.enable")); - DisableOperands.push_back(ConstantInt::get(Type::getInt1Ty(Context), 0)); + SmallVector<Metadata *, 1> DisableOperands; + DisableOperands.push_back(MDString::get(Context, "llvm.loop.unroll.disable")); MDNode *DisableNode = MDNode::get(Context, DisableOperands); - Vals.push_back(DisableNode); + MDs.push_back(DisableNode); - MDNode *NewLoopID = MDNode::get(Context, Vals); + MDNode *NewLoopID = MDNode::get(Context, MDs); // Set operand 0 to refer to the loop id itself. NewLoopID->replaceOperandWith(0, NewLoopID); L->setLoopID(NewLoopID); - LoopID->replaceAllUsesWith(NewLoopID); } unsigned LoopUnroll::selectUnrollCount( - const Loop *L, unsigned TripCount, bool HasEnablePragma, + const Loop *L, unsigned TripCount, bool PragmaFullUnroll, unsigned PragmaCount, const TargetTransformInfo::UnrollingPreferences &UP, bool &SetExplicitly) { SetExplicitly = true; @@ -326,9 +334,7 @@ unsigned LoopUnroll::selectUnrollCount( if (Count == 0) { if (PragmaCount) { Count = PragmaCount; - } else if (HasEnablePragma) { - // unroll(enable) pragma without an unroll_count pragma - // indicates to unroll loop fully. + } else if (PragmaFullUnroll) { Count = TripCount; } } @@ -360,6 +366,10 @@ bool LoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) { LoopInfo *LI = &getAnalysis<LoopInfo>(); ScalarEvolution *SE = &getAnalysis<ScalarEvolution>(); const TargetTransformInfo &TTI = getAnalysis<TargetTransformInfo>(); + const FunctionTargetTransformInfo &FTTI = + getAnalysis<FunctionTargetTransformInfo>(); + auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache( + *L->getHeader()->getParent()); BasicBlock *Header = L->getHeader(); DEBUG(dbgs() << "Loop Unroll: F[" << Header->getParent()->getName() @@ -368,37 +378,43 @@ bool LoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) { if (HasUnrollDisablePragma(L)) { return false; } - bool HasEnablePragma = HasUnrollEnablePragma(L); + bool PragmaFullUnroll = HasUnrollFullPragma(L); unsigned PragmaCount = UnrollCountPragmaValue(L); - bool HasPragma = HasEnablePragma || PragmaCount > 0; + bool HasPragma = PragmaFullUnroll || PragmaCount > 0; TargetTransformInfo::UnrollingPreferences UP; - getUnrollingPreferences(L, TTI, UP); + getUnrollingPreferences(L, FTTI, UP); // Find trip count and trip multiple if count is not available unsigned TripCount = 0; unsigned TripMultiple = 1; - // Find "latch trip count". UnrollLoop assumes that control cannot exit - // via the loop latch on any iteration prior to TripCount. The loop may exit - // early via an earlier branch. - BasicBlock *LatchBlock = L->getLoopLatch(); - if (LatchBlock) { - TripCount = SE->getSmallConstantTripCount(L, LatchBlock); - TripMultiple = SE->getSmallConstantTripMultiple(L, LatchBlock); + // If there are multiple exiting blocks but one of them is the latch, use the + // latch for the trip count estimation. Otherwise insist on a single exiting + // block for the trip count estimation. + BasicBlock *ExitingBlock = L->getLoopLatch(); + if (!ExitingBlock || !L->isLoopExiting(ExitingBlock)) + ExitingBlock = L->getExitingBlock(); + if (ExitingBlock) { + TripCount = SE->getSmallConstantTripCount(L, ExitingBlock); + TripMultiple = SE->getSmallConstantTripMultiple(L, ExitingBlock); } // Select an initial unroll count. This may be reduced later based // on size thresholds. 
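(Illustrative aside, not part of the commit; variable names are made up.) The exiting-block change matters for loops whose latch is not the exit test, e.g. an unrotated while loop:

    while (i != 16) {   // exit test in the header: the single exiting block
      sum += a[i];
      ++i;              // latch just branches back; it is not an exiting block
    }

Previously the trip count was queried only against the latch, which yields nothing here. Now, when the latch does not exit, the code falls back to getExitingBlock(), which succeeds exactly when there is one exiting block, so this loop recovers TripCount = 16. Add an early break and there are two exiting blocks; if the latch is still not one of them, TripCount stays 0 and only runtime unrolling remains possible.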
bool CountSetExplicitly; - unsigned Count = selectUnrollCount(L, TripCount, HasEnablePragma, PragmaCount, - UP, CountSetExplicitly); + unsigned Count = selectUnrollCount(L, TripCount, PragmaFullUnroll, + PragmaCount, UP, CountSetExplicitly); unsigned NumInlineCandidates; bool notDuplicatable; unsigned LoopSize = - ApproximateLoopSize(L, NumInlineCandidates, notDuplicatable, TTI); + ApproximateLoopSize(L, NumInlineCandidates, notDuplicatable, TTI, &AC); DEBUG(dbgs() << " Loop Size = " << LoopSize << "\n"); - uint64_t UnrolledSize = (uint64_t)LoopSize * Count; + + // When computing the unrolled size, note that the conditional branch on the + // backedge and the comparison feeding it are not replicated like the rest of + // the loop body (which is why 2 is subtracted). + uint64_t UnrolledSize = (uint64_t)(LoopSize-2) * Count + 2; if (notDuplicatable) { DEBUG(dbgs() << " Not unrolling loop which contains non-duplicatable" << " instructions.\n"); @@ -443,7 +459,7 @@ bool LoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) { } if (PartialThreshold != NoThreshold && UnrolledSize > PartialThreshold) { // Reduce unroll count to be modulo of TripCount for partial unrolling. - Count = PartialThreshold / LoopSize; + Count = (std::max(PartialThreshold, 3u)-2) / (LoopSize-2); while (Count != 0 && TripCount % Count != 0) Count--; } @@ -457,7 +473,7 @@ bool LoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) { // the original count which satisfies the threshold limit. while (Count != 0 && UnrolledSize > PartialThreshold) { Count >>= 1; - UnrolledSize = LoopSize * Count; + UnrolledSize = (LoopSize-2) * Count + 2; } if (Count > UP.MaxCount) Count = UP.MaxCount; @@ -465,25 +481,26 @@ bool LoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) { } if (HasPragma) { - // Mark loop as unrolled to prevent unrolling beyond that - // requested by the pragma. - SetLoopAlreadyUnrolled(L); + if (PragmaCount != 0) + // If loop has an unroll count pragma mark loop as unrolled to prevent + // unrolling beyond that requested by the pragma. + SetLoopAlreadyUnrolled(L); // Emit optimization remarks if we are unable to unroll the loop // as directed by a pragma. DebugLoc LoopLoc = L->getStartLoc(); Function *F = Header->getParent(); LLVMContext &Ctx = F->getContext(); - if (HasEnablePragma && PragmaCount == 0) { + if (PragmaFullUnroll && PragmaCount == 0) { if (TripCount && Count != TripCount) { emitOptimizationRemarkMissed( Ctx, DEBUG_TYPE, *F, LoopLoc, - "Unable to fully unroll loop as directed by unroll(enable) pragma " + "Unable to fully unroll loop as directed by unroll(full) pragma " "because unrolled size is too large."); } else if (!TripCount) { emitOptimizationRemarkMissed( Ctx, DEBUG_TYPE, *F, LoopLoc, - "Unable to fully unroll loop as directed by unroll(enable) pragma " + "Unable to fully unroll loop as directed by unroll(full) pragma " "because loop has a runtime trip count."); } } else if (PragmaCount > 0 && Count != OriginalCount) { @@ -501,7 +518,8 @@ bool LoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) { } // Unroll the loop. 
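(Illustrative aside, not part of the commit.) Worked numbers for the new size model: with LoopSize = 12 and Count = 4, the old estimate was 12 * 4 = 48, while the new one is (12 - 2) * 4 + 2 = 42, because the backedge compare-and-branch pair exists once in the unrolled loop rather than once per copy. The partial-unroll count inverts the same formula: with PartialThreshold = 150, Count = (150 - 2) / (12 - 2) = 14, before being rounded down to a divisor of TripCount. The std::max(LoopSize, 3u) floor in ApproximateLoopSize guarantees the LoopSize - 2 terms can never reach zero.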
- if (!UnrollLoop(L, Count, TripCount, AllowRuntime, TripMultiple, LI, this, &LPM)) + if (!UnrollLoop(L, Count, TripCount, AllowRuntime, TripMultiple, LI, this, + &LPM, &AC)) return false; return true; diff --git a/lib/Transforms/Scalar/LoopUnswitch.cpp b/lib/Transforms/Scalar/LoopUnswitch.cpp index 977c53a3bc63..9f4c12270d76 100644 --- a/lib/Transforms/Scalar/LoopUnswitch.cpp +++ b/lib/Transforms/Scalar/LoopUnswitch.cpp @@ -30,6 +30,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" @@ -103,7 +104,8 @@ namespace { // Analyze loop. Check its size, calculate is it possible to unswitch // it. Returns true if we can unswitch this loop. - bool countLoop(const Loop *L, const TargetTransformInfo &TTI); + bool countLoop(const Loop *L, const TargetTransformInfo &TTI, + AssumptionCache *AC); // Clean all data related to given loop. void forgetLoop(const Loop *L); @@ -126,6 +128,7 @@ namespace { class LoopUnswitch : public LoopPass { LoopInfo *LI; // Loop information LPPassManager *LPM; + AssumptionCache *AC; // LoopProcessWorklist - Used to check if second loop needs processing // after RewriteLoopBodyWithConditionConstant rewrites first loop. @@ -164,6 +167,7 @@ namespace { /// loop preheaders be inserted into the CFG. /// void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); AU.addRequiredID(LoopSimplifyID); AU.addPreservedID(LoopSimplifyID); AU.addRequired<LoopInfo>(); @@ -212,7 +216,8 @@ namespace { // Analyze loop. Check its size, calculate is it possible to unswitch // it. Returns true if we can unswitch this loop. -bool LUAnalysisCache::countLoop(const Loop *L, const TargetTransformInfo &TTI) { +bool LUAnalysisCache::countLoop(const Loop *L, const TargetTransformInfo &TTI, + AssumptionCache *AC) { LoopPropsMapIt PropsIt; bool Inserted; @@ -229,13 +234,16 @@ bool LUAnalysisCache::countLoop(const Loop *L, const TargetTransformInfo &TTI) { // large numbers of branches which cause loop unswitching to go crazy. // This is a very ad-hoc heuristic. + SmallPtrSet<const Value *, 32> EphValues; + CodeMetrics::collectEphemeralValues(L, AC, EphValues); + // FIXME: This is overly conservative because it does not take into // consideration code simplification opportunities and code that can // be shared by the resultant unswitched loops. 
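(Illustrative aside, not part of the commit.) Ephemeral values are instructions that exist only to feed @llvm.assume and generate no machine code, so counting them would bias the size heuristics against annotated code. A minimal IR sketch of what collectEphemeralValues gathers:

    %cmp = icmp ult i32 %i, %n        ; used only by the assume below: ephemeral
    call void @llvm.assume(i1 %cmp)   ; ephemeral as well; no code is emitted for it

Passing the EphValues set into analyzeBasicBlock lets both the unroller and (below) the unswitcher skip these instructions when estimating loop size.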
CodeMetrics Metrics; for (Loop::block_iterator I = L->block_begin(), E = L->block_end(); I != E; ++I) - Metrics.analyzeBasicBlock(*I, TTI); + Metrics.analyzeBasicBlock(*I, TTI, EphValues); Props.SizeEstimation = std::min(Metrics.NumInsts, Metrics.NumBlocks * 5); Props.CanBeUnswitchedCount = MaxSize / (Props.SizeEstimation); @@ -326,6 +334,7 @@ char LoopUnswitch::ID = 0; INITIALIZE_PASS_BEGIN(LoopUnswitch, "loop-unswitch", "Unswitch loops", false, false) INITIALIZE_AG_DEPENDENCY(TargetTransformInfo) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(LoopSimplify) INITIALIZE_PASS_DEPENDENCY(LoopInfo) INITIALIZE_PASS_DEPENDENCY(LCSSA) @@ -376,6 +385,8 @@ bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) { if (skipOptnoneFunction(L)) return false; + AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache( + *L->getHeader()->getParent()); LI = &getAnalysis<LoopInfo>(); LPM = &LPM_Ref; DominatorTreeWrapperPass *DTWP = @@ -421,7 +432,8 @@ bool LoopUnswitch::processCurrentLoop() { // Probably we reach the quota of branches for this loop. If so // stop unswitching. - if (!BranchesInfo.countLoop(currentLoop, getAnalysis<TargetTransformInfo>())) + if (!BranchesInfo.countLoop(currentLoop, getAnalysis<TargetTransformInfo>(), + AC)) return false; // Loop over all of the basic blocks in the loop. If we find an interior @@ -823,6 +835,10 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, F->getBasicBlockList().splice(NewPreheader, F->getBasicBlockList(), NewBlocks[0], F->end()); + // FIXME: We could register any cloned assumptions instead of clearing the + // whole function's cache. + AC->clear(); + // Now we create the new Loop object for the versioned loop. Loop *NewLoop = CloneLoop(L, L->getParentLoop(), VMap, LI, LPM); diff --git a/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/lib/Transforms/Scalar/MemCpyOptimizer.cpp index 7c184a4ad2c3..33b5f9df5a27 100644 --- a/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/MemoryDependenceAnalysis.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/DataLayout.h" @@ -329,6 +330,7 @@ namespace { // This transformation requires dominator postdominator info void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); + AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<MemoryDependenceAnalysis>(); AU.addRequired<AliasAnalysis>(); @@ -361,6 +363,7 @@ FunctionPass *llvm::createMemCpyOptPass() { return new MemCpyOpt(); } INITIALIZE_PASS_BEGIN(MemCpyOpt, "memcpyopt", "MemCpy Optimization", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(MemoryDependenceAnalysis) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) @@ -631,22 +634,24 @@ bool MemCpyOpt::performCallSlotOptzn(Instruction *cpy, if (destSize < srcSize) return false; } else if (Argument *A = dyn_cast<Argument>(cpyDest)) { - // If the destination is an sret parameter then only accesses that are - // outside of the returned struct type can trap. 
- if (!A->hasStructRetAttr()) - return false; + if (A->getDereferenceableBytes() < srcSize) { + // If the destination is an sret parameter then only accesses that are + // outside of the returned struct type can trap. + if (!A->hasStructRetAttr()) + return false; - Type *StructTy = cast<PointerType>(A->getType())->getElementType(); - if (!StructTy->isSized()) { - // The call may never return and hence the copy-instruction may never - // be executed, and therefore it's not safe to say "the destination - // has at least <cpyLen> bytes, as implied by the copy-instruction", - return false; - } + Type *StructTy = cast<PointerType>(A->getType())->getElementType(); + if (!StructTy->isSized()) { + // The call may never return and hence the copy-instruction may never + // be executed, and therefore it's not safe to say "the destination + // has at least <cpyLen> bytes, as implied by the copy-instruction", + return false; + } - uint64_t destSize = DL->getTypeAllocSize(StructTy); - if (destSize < srcSize) - return false; + uint64_t destSize = DL->getTypeAllocSize(StructTy); + if (destSize < srcSize) + return false; + } } else { return false; } @@ -673,15 +678,23 @@ bool MemCpyOpt::performCallSlotOptzn(Instruction *cpy, if (isa<BitCastInst>(U) || isa<AddrSpaceCastInst>(U)) { for (User *UU : U->users()) srcUseList.push_back(UU); - } else if (GetElementPtrInst *G = dyn_cast<GetElementPtrInst>(U)) { - if (G->hasAllZeroIndices()) - for (User *UU : U->users()) - srcUseList.push_back(UU); - else + continue; + } + if (GetElementPtrInst *G = dyn_cast<GetElementPtrInst>(U)) { + if (!G->hasAllZeroIndices()) return false; - } else if (U != C && U != cpy) { - return false; + + for (User *UU : U->users()) + srcUseList.push_back(UU); + continue; } + if (const IntrinsicInst *IT = dyn_cast<IntrinsicInst>(U)) + if (IT->getIntrinsicID() == Intrinsic::lifetime_start || + IT->getIntrinsicID() == Intrinsic::lifetime_end) + continue; + + if (U != C && U != cpy) + return false; } // Check that src isn't captured by the called function since the @@ -969,8 +982,13 @@ bool MemCpyOpt::processByValArgument(CallSite CS, unsigned ArgNo) { // If it is greater than the memcpy, then we check to see if we can force the // source of the memcpy to the alignment we need. If we fail, we bail out. 
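(Illustrative aside, not part of the commit; function and type names are made up.) The call-slot optimization this guards rewrites a call that fills a temporary, followed by a memcpy out of it, to make the call write the final destination directly. Roughly, at the C level:

    #include <string.h>

    struct S { int v[4]; };
    void producer(struct S *out);

    void caller(struct S *dest) {
      struct S tmp;
      producer(&tmp);
      memcpy(dest, &tmp, sizeof tmp);  /* candidate: have producer write *dest */
    }

The rewrite is only safe if storing through dest cannot trap where the memcpy sat. Previously the destination had to be an sret argument of a sized struct type large enough for the copy; with this change, an argument known dereferenceable(N) with N >= srcSize also qualifies, and the sret reasoning becomes the fallback for arguments without enough dereferenceable bytes.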
+ AssumptionCache &AC = + getAnalysis<AssumptionCacheTracker>().getAssumptionCache( + *CS->getParent()->getParent()); + DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); if (MDep->getAlignment() < ByValAlign && - getOrEnforceKnownAlignment(MDep->getSource(),ByValAlign, DL) < ByValAlign) + getOrEnforceKnownAlignment(MDep->getSource(), ByValAlign, DL, &AC, + CS.getInstruction(), &DT) < ByValAlign) return false; // Verify that the copied-from memory doesn't change in between the memcpy and diff --git a/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp index c2467fecb5eb..8509713b3367 100644 --- a/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp +++ b/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp @@ -97,9 +97,6 @@ using namespace llvm; //===----------------------------------------------------------------------===// // MergedLoadStoreMotion Pass //===----------------------------------------------------------------------===// -static cl::opt<bool> -EnableMLSM("mlsm", cl::desc("Enable motion of merged load and store"), - cl::init(true)); namespace { class MergedLoadStoreMotion : public FunctionPass { @@ -134,7 +131,9 @@ private: BasicBlock *getDiamondTail(BasicBlock *BB); bool isDiamondHead(BasicBlock *BB); // Routines for hoisting loads - bool isLoadHoistBarrier(Instruction *Inst); + bool isLoadHoistBarrierInRange(const Instruction& Start, + const Instruction& End, + LoadInst* LI); LoadInst *canHoistFromBlock(BasicBlock *BB, LoadInst *LI); void hoistInstruction(BasicBlock *BB, Instruction *HoistCand, Instruction *ElseInst); @@ -144,7 +143,9 @@ private: // Routines for sinking stores StoreInst *canSinkFromBlock(BasicBlock *BB, StoreInst *SI); PHINode *getPHIOperand(BasicBlock *BB, StoreInst *S0, StoreInst *S1); - bool isStoreSinkBarrier(Instruction *Inst); + bool isStoreSinkBarrierInRange(const Instruction& Start, + const Instruction& End, + AliasAnalysis::Location Loc); bool sinkStore(BasicBlock *BB, StoreInst *SinkCand, StoreInst *ElseInst); bool mergeStores(BasicBlock *BB); // The mergeLoad/Store algorithms could have Size0 * Size1 complexity, @@ -235,27 +236,12 @@ bool MergedLoadStoreMotion::isDiamondHead(BasicBlock *BB) { /// being loaded or protect against the load from happening /// it is considered a hoist barrier. /// -bool MergedLoadStoreMotion::isLoadHoistBarrier(Instruction *Inst) { - // FIXME: A call with no side effects should not be a barrier. - // Aren't all such calls covered by mayHaveSideEffects() below? - // Then this check can be removed. - if (isa<CallInst>(Inst)) - return true; - if (isa<TerminatorInst>(Inst)) - return true; - // FIXME: Conservatively let a store instruction block the load. - // Use alias analysis instead. - if (isa<StoreInst>(Inst)) - return true; - // Note: mayHaveSideEffects covers all instructions that could - // trigger a change to state. Eg. in-flight stores have to be executed - // before ordered loads or fences, calls could invoke functions that store - // data to memory etc. 
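(Illustrative aside, not part of the commit.) Threading the AssumptionCache and DominatorTree into getOrEnforceKnownAlignment lets it derive alignment from assume-based hints that dominate the call site. The usual pattern, which is what __builtin_assume_aligned lowers to, looks like:

    %ptrint = ptrtoint i8* %src to i64
    %masked = and i64 %ptrint, 15
    %iszero = icmp eq i64 %masked, 0
    call void @llvm.assume(i1 %iszero)   ; %src is 16-byte aligned below this point

With such an assumption in scope, a byval forwarding that needs, say, 16-byte alignment of the memcpy source no longer has to bail out just because the alignment is not structurally evident.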
- if (Inst->mayHaveSideEffects()) { - return true; - } - DEBUG(dbgs() << "No Hoist Barrier\n"); - return false; + +bool MergedLoadStoreMotion::isLoadHoistBarrierInRange(const Instruction& Start, + const Instruction& End, + LoadInst* LI) { + AliasAnalysis::Location Loc = AA->getLocation(LI); + return AA->canInstructionRangeModRef(Start, End, Loc, AliasAnalysis::Mod); } /// @@ -265,33 +251,29 @@ bool MergedLoadStoreMotion::isLoadHoistBarrier(Instruction *Inst) { /// and it can be hoisted from \p BB, return that load. /// Otherwise return Null. /// -LoadInst *MergedLoadStoreMotion::canHoistFromBlock(BasicBlock *BB, - LoadInst *LI) { - LoadInst *I = nullptr; - assert(isa<LoadInst>(LI)); - if (LI->isUsedOutsideOfBlock(LI->getParent())) - return nullptr; +LoadInst *MergedLoadStoreMotion::canHoistFromBlock(BasicBlock *BB1, + LoadInst *Load0) { - for (BasicBlock::iterator BBI = BB->begin(), BBE = BB->end(); BBI != BBE; + for (BasicBlock::iterator BBI = BB1->begin(), BBE = BB1->end(); BBI != BBE; ++BBI) { Instruction *Inst = BBI; // Only merge and hoist loads when their result in used only in BB - if (isLoadHoistBarrier(Inst)) - break; - if (!isa<LoadInst>(Inst)) - continue; - if (Inst->isUsedOutsideOfBlock(Inst->getParent())) + if (!isa<LoadInst>(Inst) || Inst->isUsedOutsideOfBlock(BB1)) continue; - AliasAnalysis::Location LocLI = AA->getLocation(LI); - AliasAnalysis::Location LocInst = AA->getLocation((LoadInst *)Inst); - if (AA->isMustAlias(LocLI, LocInst) && LI->getType() == Inst->getType()) { - I = (LoadInst *)Inst; - break; + LoadInst *Load1 = dyn_cast<LoadInst>(Inst); + BasicBlock *BB0 = Load0->getParent(); + + AliasAnalysis::Location Loc0 = AA->getLocation(Load0); + AliasAnalysis::Location Loc1 = AA->getLocation(Load1); + if (AA->isMustAlias(Loc0, Loc1) && Load0->isSameOperationAs(Load1) && + !isLoadHoistBarrierInRange(BB1->front(), *Load1, Load1) && + !isLoadHoistBarrierInRange(BB0->front(), *Load0, Load0)) { + return Load1; } } - return I; + return nullptr; } /// @@ -388,15 +370,10 @@ bool MergedLoadStoreMotion::mergeLoads(BasicBlock *BB) { Instruction *I = BBI; ++BBI; - if (isLoadHoistBarrier(I)) - break; // Only move non-simple (atomic, volatile) loads. - if (!isa<LoadInst>(I)) - continue; - - LoadInst *L0 = (LoadInst *)I; - if (!L0->isSimple()) + LoadInst *L0 = dyn_cast<LoadInst>(I); + if (!L0 || !L0->isSimple() || L0->isUsedOutsideOfBlock(Succ0)) continue; ++NLoads; @@ -414,26 +391,19 @@ bool MergedLoadStoreMotion::mergeLoads(BasicBlock *BB) { } /// -/// \brief True when instruction is sink barrier for a store -/// -bool MergedLoadStoreMotion::isStoreSinkBarrier(Instruction *Inst) { - // FIXME: Conservatively let a load instruction block the store. - // Use alias analysis instead. - if (isa<LoadInst>(Inst)) - return true; - if (isa<CallInst>(Inst)) - return true; - if (isa<TerminatorInst>(Inst) && !isa<BranchInst>(Inst)) - return true; - // Note: mayHaveSideEffects covers all instructions that could - // trigger a change to state. Eg. in-flight stores have to be executed - // before ordered loads or fences, calls could invoke functions that store - // data to memory etc. - if (!isa<StoreInst>(Inst) && Inst->mayHaveSideEffects()) { - return true; - } - DEBUG(dbgs() << "No Sink Barrier\n"); - return false; +/// \brief True when instruction is a sink barrier for a store +/// located in Loc +/// +/// Whenever an instruction could possibly read or modify the +/// value being stored or protect against the store from +/// happening it is considered a sink barrier. 
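(Illustrative aside, not part of the commit; names are made up.) The net effect of the *InRange rewrite is that merging can step over unrelated instructions instead of stopping at the first call or store. A C-level sketch of the diamond being handled:

    if (c)
      *q = f(*p);
    else
      *q = g(*p);

If alias analysis shows that nothing in either arm before the load may modify *p, the two loads of *p merge and hoist into the head block; symmetrically, if nothing after the stores may touch *q, the two stores to *q sink into the join block, storing a phi of the two values. canInstructionRangeModRef queries exactly those ranges, with AliasAnalysis::Mod for the hoisted load and AliasAnalysis::Ref for the sunk store.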
+/// + +bool MergedLoadStoreMotion::isStoreSinkBarrierInRange(const Instruction& Start, + const Instruction& End, + AliasAnalysis::Location + Loc) { + return AA->canInstructionRangeModRef(Start, End, Loc, AliasAnalysis::Ref); } /// @@ -441,27 +411,28 @@ bool MergedLoadStoreMotion::isStoreSinkBarrier(Instruction *Inst) { /// /// \return The store in \p when it is safe to sink. Otherwise return Null. /// -StoreInst *MergedLoadStoreMotion::canSinkFromBlock(BasicBlock *BB, - StoreInst *SI) { - StoreInst *I = 0; - DEBUG(dbgs() << "can Sink? : "; SI->dump(); dbgs() << "\n"); - for (BasicBlock::reverse_iterator RBI = BB->rbegin(), RBE = BB->rend(); +StoreInst *MergedLoadStoreMotion::canSinkFromBlock(BasicBlock *BB1, + StoreInst *Store0) { + DEBUG(dbgs() << "can Sink? : "; Store0->dump(); dbgs() << "\n"); + for (BasicBlock::reverse_iterator RBI = BB1->rbegin(), RBE = BB1->rend(); RBI != RBE; ++RBI) { Instruction *Inst = &*RBI; - // Only move loads if they are used in the block. - if (isStoreSinkBarrier(Inst)) - break; - if (isa<StoreInst>(Inst)) { - AliasAnalysis::Location LocSI = AA->getLocation(SI); - AliasAnalysis::Location LocInst = AA->getLocation((StoreInst *)Inst); - if (AA->isMustAlias(LocSI, LocInst)) { - I = (StoreInst *)Inst; - break; - } + if (!isa<StoreInst>(Inst)) + continue; + + StoreInst *Store1 = cast<StoreInst>(Inst); + BasicBlock *BB0 = Store0->getParent(); + + AliasAnalysis::Location Loc0 = AA->getLocation(Store0); + AliasAnalysis::Location Loc1 = AA->getLocation(Store1); + if (AA->isMustAlias(Loc0, Loc1) && Store0->isSameOperationAs(Store1) && + !isStoreSinkBarrierInRange(*Store1, BB1->back(), Loc1) && + !isStoreSinkBarrierInRange(*Store0, BB0->back(), Loc0)) { + return Store1; } } - return I; + return nullptr; } /// @@ -573,8 +544,7 @@ bool MergedLoadStoreMotion::mergeStores(BasicBlock *T) { Instruction *I = &*RBI; ++RBI; - if (isStoreSinkBarrier(I)) - break; + // Sink move non-simple (atomic, volatile) stores if (!isa<StoreInst>(I)) continue; @@ -611,8 +581,6 @@ bool MergedLoadStoreMotion::runOnFunction(Function &F) { AA = &getAnalysis<AliasAnalysis>(); bool Changed = false; - if (!EnableMLSM) - return false; DEBUG(dbgs() << "Instruction Merger\n"); // Merge unconditional branches, allowing PRE to catch more @@ -622,7 +590,6 @@ bool MergedLoadStoreMotion::runOnFunction(Function &F) { // Hoist equivalent loads and sink stores // outside diamonds when possible - // Run outside core GVN if (isDiamondHead(BB)) { Changed |= mergeLoads(BB); Changed |= mergeStores(getDiamondTail(BB)); diff --git a/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp b/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp index 7cce89e0627e..5c8bed585b64 100644 --- a/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp +++ b/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp @@ -108,6 +108,10 @@ bool PartiallyInlineLibCalls::optimizeSQRT(CallInst *Call, if (Call->onlyReadsMemory()) return false; + // The call must have the expected result type. 
+ if (!Call->getType()->isFloatingPointTy()) + return false; + // Do the following transformation: // // (before) diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp index ea2cf7cf9b5f..4e022556f9cc 100644 --- a/lib/Transforms/Scalar/Reassociate.cpp +++ b/lib/Transforms/Scalar/Reassociate.cpp @@ -176,6 +176,7 @@ namespace { private: void BuildRankMap(Function &F); unsigned getRank(Value *V); + void canonicalizeOperands(Instruction *I); void ReassociateExpression(BinaryOperator *I); void RewriteExprTree(BinaryOperator *I, SmallVectorImpl<ValueEntry> &Ops); Value *OptimizeExpression(BinaryOperator *I, @@ -194,6 +195,7 @@ namespace { Value *RemoveFactorFromExpression(Value *V, Value *Factor); void EraseInst(Instruction *I); void OptimizeInst(Instruction *I); + Instruction *canonicalizeNegConstExpr(Instruction *I); }; } @@ -235,7 +237,20 @@ FunctionPass *llvm::createReassociatePass() { return new Reassociate(); } /// opcode and if it only has one use. static BinaryOperator *isReassociableOp(Value *V, unsigned Opcode) { if (V->hasOneUse() && isa<Instruction>(V) && - cast<Instruction>(V)->getOpcode() == Opcode) + cast<Instruction>(V)->getOpcode() == Opcode && + (!isa<FPMathOperator>(V) || + cast<Instruction>(V)->hasUnsafeAlgebra())) + return cast<BinaryOperator>(V); + return nullptr; +} + +static BinaryOperator *isReassociableOp(Value *V, unsigned Opcode1, + unsigned Opcode2) { + if (V->hasOneUse() && isa<Instruction>(V) && + (cast<Instruction>(V)->getOpcode() == Opcode1 || + cast<Instruction>(V)->getOpcode() == Opcode2) && + (!isa<FPMathOperator>(V) || + cast<Instruction>(V)->hasUnsafeAlgebra())) return cast<BinaryOperator>(V); return nullptr; } @@ -264,9 +279,11 @@ static bool isUnmovableInstruction(Instruction *I) { void Reassociate::BuildRankMap(Function &F) { unsigned i = 2; - // Assign distinct ranks to function arguments - for (Function::arg_iterator I = F.arg_begin(), E = F.arg_end(); I != E; ++I) + // Assign distinct ranks to function arguments. + for (Function::arg_iterator I = F.arg_begin(), E = F.arg_end(); I != E; ++I) { ValueRankMap[&*I] = ++i; + DEBUG(dbgs() << "Calculated Rank[" << I->getName() << "] = " << i << "\n"); + } ReversePostOrderTraversal<Function*> RPOT(&F); for (ReversePostOrderTraversal<Function*>::rpo_iterator I = RPOT.begin(), @@ -304,24 +321,78 @@ unsigned Reassociate::getRank(Value *V) { // If this is a not or neg instruction, do not count it for rank. This // assures us that X and ~X will have the same rank. - if (!I->getType()->isIntegerTy() || - (!BinaryOperator::isNot(I) && !BinaryOperator::isNeg(I))) + Type *Ty = V->getType(); + if ((!Ty->isIntegerTy() && !Ty->isFloatingPointTy()) || + (!BinaryOperator::isNot(I) && !BinaryOperator::isNeg(I) && + !BinaryOperator::isFNeg(I))) ++Rank; - //DEBUG(dbgs() << "Calculated Rank[" << V->getName() << "] = " - // << Rank << "\n"); + DEBUG(dbgs() << "Calculated Rank[" << V->getName() << "] = " << Rank << "\n"); return ValueRankMap[I] = Rank; } +// Canonicalize constants to RHS. Otherwise, sort the operands by rank. 
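(Illustrative aside, not part of the commit.) Most of the Reassociate changes extend the integer machinery to floating point, but only where the instruction carries the unsafe-algebra fast-math flag, since FP add and mul are not associative in general. A minimal IR contrast:

    %t1 = fadd fast float %a, %b   ; hasUnsafeAlgebra(): may join an expression tree
    %t2 = fadd float %t1, %c       ; strict FP: isReassociableOp() rejects it

The added (!isa<FPMathOperator>(V) || hasUnsafeAlgebra()) clause in both isReassociableOp overloads is what enforces this throughout the pass.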
+void Reassociate::canonicalizeOperands(Instruction *I) { + assert(isa<BinaryOperator>(I) && "Expected binary operator."); + assert(I->isCommutative() && "Expected commutative operator."); + + Value *LHS = I->getOperand(0); + Value *RHS = I->getOperand(1); + unsigned LHSRank = getRank(LHS); + unsigned RHSRank = getRank(RHS); + + if (isa<Constant>(RHS)) + return; + + if (isa<Constant>(LHS) || RHSRank < LHSRank) + cast<BinaryOperator>(I)->swapOperands(); +} + +static BinaryOperator *CreateAdd(Value *S1, Value *S2, const Twine &Name, + Instruction *InsertBefore, Value *FlagsOp) { + if (S1->getType()->isIntegerTy()) + return BinaryOperator::CreateAdd(S1, S2, Name, InsertBefore); + else { + BinaryOperator *Res = + BinaryOperator::CreateFAdd(S1, S2, Name, InsertBefore); + Res->setFastMathFlags(cast<FPMathOperator>(FlagsOp)->getFastMathFlags()); + return Res; + } +} + +static BinaryOperator *CreateMul(Value *S1, Value *S2, const Twine &Name, + Instruction *InsertBefore, Value *FlagsOp) { + if (S1->getType()->isIntegerTy()) + return BinaryOperator::CreateMul(S1, S2, Name, InsertBefore); + else { + BinaryOperator *Res = + BinaryOperator::CreateFMul(S1, S2, Name, InsertBefore); + Res->setFastMathFlags(cast<FPMathOperator>(FlagsOp)->getFastMathFlags()); + return Res; + } +} + +static BinaryOperator *CreateNeg(Value *S1, const Twine &Name, + Instruction *InsertBefore, Value *FlagsOp) { + if (S1->getType()->isIntegerTy()) + return BinaryOperator::CreateNeg(S1, Name, InsertBefore); + else { + BinaryOperator *Res = BinaryOperator::CreateFNeg(S1, Name, InsertBefore); + Res->setFastMathFlags(cast<FPMathOperator>(FlagsOp)->getFastMathFlags()); + return Res; + } +} + /// LowerNegateToMultiply - Replace 0-X with X*-1. /// static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) { - Constant *Cst = Constant::getAllOnesValue(Neg->getType()); + Type *Ty = Neg->getType(); + Constant *NegOne = Ty->isIntegerTy() ? ConstantInt::getAllOnesValue(Ty) + : ConstantFP::get(Ty, -1.0); - BinaryOperator *Res = - BinaryOperator::CreateMul(Neg->getOperand(1), Cst, "",Neg); - Neg->setOperand(1, Constant::getNullValue(Neg->getType())); // Drop use of op. + BinaryOperator *Res = CreateMul(Neg->getOperand(1), NegOne, "", Neg, Neg); + Neg->setOperand(1, Constant::getNullValue(Ty)); // Drop use of op. Res->takeName(Neg); Neg->replaceAllUsesWith(Res); Res->setDebugLoc(Neg->getDebugLoc()); @@ -377,13 +448,14 @@ static void IncorporateWeight(APInt &LHS, const APInt &RHS, unsigned Opcode) { LHS = 0; // 1 + 1 === 0 modulo 2. return; } - if (Opcode == Instruction::Add) { + if (Opcode == Instruction::Add || Opcode == Instruction::FAdd) { // TODO: Reduce the weight by exploiting nsw/nuw? LHS += RHS; return; } - assert(Opcode == Instruction::Mul && "Unknown associative operation!"); + assert((Opcode == Instruction::Mul || Opcode == Instruction::FMul) && + "Unknown associative operation!"); unsigned Bitwidth = LHS.getBitWidth(); // If CM is the Carmichael number then a weight W satisfying W >= CM+Bitwidth // can be replaced with W-CM. 
That's because x^W=x^(W-CM) for every Bitwidth @@ -499,8 +571,7 @@ static bool LinearizeExprTree(BinaryOperator *I, DEBUG(dbgs() << "LINEARIZE: " << *I << '\n'); unsigned Bitwidth = I->getType()->getScalarType()->getPrimitiveSizeInBits(); unsigned Opcode = I->getOpcode(); - assert(Instruction::isAssociative(Opcode) && - Instruction::isCommutative(Opcode) && + assert(I->isAssociative() && I->isCommutative() && "Expected an associative and commutative operation!"); // Visit all operands of the expression, keeping track of their weight (the @@ -515,7 +586,7 @@ static bool LinearizeExprTree(BinaryOperator *I, // ways to get to it. SmallVector<std::pair<BinaryOperator*, APInt>, 8> Worklist; // (Op, Weight) Worklist.push_back(std::make_pair(I, APInt(Bitwidth, 1))); - bool MadeChange = false; + bool Changed = false; // Leaves of the expression are values that either aren't the right kind of // operation (eg: a constant, or a multiply in an add tree), or are, but have @@ -552,7 +623,7 @@ static bool LinearizeExprTree(BinaryOperator *I, // If this is a binary operation of the right kind with only one use then // add its operands to the expression. if (BinaryOperator *BO = isReassociableOp(Op, Opcode)) { - assert(Visited.insert(Op) && "Not first visit!"); + assert(Visited.insert(Op).second && "Not first visit!"); DEBUG(dbgs() << "DIRECT ADD: " << *Op << " (" << Weight << ")\n"); Worklist.push_back(std::make_pair(BO, Weight)); continue; @@ -562,7 +633,7 @@ static bool LinearizeExprTree(BinaryOperator *I, LeafMap::iterator It = Leaves.find(Op); if (It == Leaves.end()) { // Not in the leaf map. Must be the first time we saw this operand. - assert(Visited.insert(Op) && "Not first visit!"); + assert(Visited.insert(Op).second && "Not first visit!"); if (!Op->hasOneUse()) { // This value has uses not accounted for by the expression, so it is // not safe to modify. Mark it as being a leaf. @@ -584,7 +655,7 @@ static bool LinearizeExprTree(BinaryOperator *I, // exactly one such use, drop this new use of the leaf. assert(!Op->hasOneUse() && "Only one use, but we got here twice!"); I->setOperand(OpIdx, UndefValue::get(I->getType())); - MadeChange = true; + Changed = true; // If the leaf is a binary operation of the right kind and we now see // that its multiple original uses were in fact all by nodes belonging @@ -613,21 +684,24 @@ static bool LinearizeExprTree(BinaryOperator *I, // expression. This means that it can safely be modified. See if we // can usefully morph it into an expression of the right kind. assert((!isa<Instruction>(Op) || - cast<Instruction>(Op)->getOpcode() != Opcode) && + cast<Instruction>(Op)->getOpcode() != Opcode + || (isa<FPMathOperator>(Op) && + !cast<Instruction>(Op)->hasUnsafeAlgebra())) && "Should have been handled above!"); assert(Op->hasOneUse() && "Has uses outside the expression tree!"); // If this is a multiply expression, turn any internal negations into // multiplies by -1 so they can be reassociated. 
- BinaryOperator *BO = dyn_cast<BinaryOperator>(Op); - if (Opcode == Instruction::Mul && BO && BinaryOperator::isNeg(BO)) { - DEBUG(dbgs() << "MORPH LEAF: " << *Op << " (" << Weight << ") TO "); - BO = LowerNegateToMultiply(BO); - DEBUG(dbgs() << *BO << 'n'); - Worklist.push_back(std::make_pair(BO, Weight)); - MadeChange = true; - continue; - } + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op)) + if ((Opcode == Instruction::Mul && BinaryOperator::isNeg(BO)) || + (Opcode == Instruction::FMul && BinaryOperator::isFNeg(BO))) { + DEBUG(dbgs() << "MORPH LEAF: " << *Op << " (" << Weight << ") TO "); + BO = LowerNegateToMultiply(BO); + DEBUG(dbgs() << *BO << '\n'); + Worklist.push_back(std::make_pair(BO, Weight)); + Changed = true; + continue; + } // Failed to morph into an expression of the right type. This really is // a leaf. @@ -665,7 +739,7 @@ static bool LinearizeExprTree(BinaryOperator *I, Ops.push_back(std::make_pair(Identity, APInt(Bitwidth, 1))); } - return MadeChange; + return Changed; } // RewriteExprTree - Now that the operands for this expression tree are @@ -798,6 +872,8 @@ void Reassociate::RewriteExprTree(BinaryOperator *I, Constant *Undef = UndefValue::get(I->getType()); NewOp = BinaryOperator::Create(Instruction::BinaryOps(Opcode), Undef, Undef, "", I); + if (NewOp->getType()->isFloatingPointTy()) + NewOp->setFastMathFlags(I->getFastMathFlags()); } else { NewOp = NodesToRewrite.pop_back_val(); } @@ -817,7 +893,14 @@ void Reassociate::RewriteExprTree(BinaryOperator *I, // expression tree is dominated by all of Ops. if (ExpressionChanged) do { - ExpressionChanged->clearSubclassOptionalData(); + // Preserve FastMathFlags. + if (isa<FPMathOperator>(I)) { + FastMathFlags Flags = I->getFastMathFlags(); + ExpressionChanged->clearSubclassOptionalData(); + ExpressionChanged->setFastMathFlags(Flags); + } else + ExpressionChanged->clearSubclassOptionalData(); + if (ExpressionChanged == I) break; ExpressionChanged->moveBefore(I); @@ -834,6 +917,8 @@ void Reassociate::RewriteExprTree(BinaryOperator *I, /// version of the value is returned, and BI is left pointing at the instruction /// that should be processed next by the reassociation pass. static Value *NegateValue(Value *V, Instruction *BI) { + if (ConstantFP *C = dyn_cast<ConstantFP>(V)) + return ConstantExpr::getFNeg(C); if (Constant *C = dyn_cast<Constant>(V)) return ConstantExpr::getNeg(C); @@ -846,7 +931,8 @@ static Value *NegateValue(Value *V, Instruction *BI) { // the constants. We assume that instcombine will clean up the mess later if // we introduce tons of unnecessary negation instructions. // - if (BinaryOperator *I = isReassociableOp(V, Instruction::Add)) { + if (BinaryOperator *I = + isReassociableOp(V, Instruction::Add, Instruction::FAdd)) { // Push the negates through the add. I->setOperand(0, NegateValue(I->getOperand(0), BI)); I->setOperand(1, NegateValue(I->getOperand(1), BI)); @@ -864,7 +950,8 @@ static Value *NegateValue(Value *V, Instruction *BI) { // Okay, we need to materialize a negated version of V with an instruction. // Scan the use lists of V to see if we have one already. for (User *U : V->users()) { - if (!BinaryOperator::isNeg(U)) continue; + if (!BinaryOperator::isNeg(U) && !BinaryOperator::isFNeg(U)) + continue; // We found one! Now we have to make sure that the definition dominates // this use. 
We do this by moving it to the entry block (if it is a @@ -894,27 +981,34 @@ static Value *NegateValue(Value *V, Instruction *BI) { // Insert a 'neg' instruction that subtracts the value from zero to get the // negation. - return BinaryOperator::CreateNeg(V, V->getName() + ".neg", BI); + return CreateNeg(V, V->getName() + ".neg", BI, BI); } /// ShouldBreakUpSubtract - Return true if we should break up this subtract of /// X-Y into (X + -Y). static bool ShouldBreakUpSubtract(Instruction *Sub) { // If this is a negation, we can't split it up! - if (BinaryOperator::isNeg(Sub)) + if (BinaryOperator::isNeg(Sub) || BinaryOperator::isFNeg(Sub)) + return false; + + // Don't breakup X - undef. + if (isa<UndefValue>(Sub->getOperand(1))) return false; // Don't bother to break this up unless either the LHS is an associable add or // subtract or if this is only used by one. - if (isReassociableOp(Sub->getOperand(0), Instruction::Add) || - isReassociableOp(Sub->getOperand(0), Instruction::Sub)) + Value *V0 = Sub->getOperand(0); + if (isReassociableOp(V0, Instruction::Add, Instruction::FAdd) || + isReassociableOp(V0, Instruction::Sub, Instruction::FSub)) return true; - if (isReassociableOp(Sub->getOperand(1), Instruction::Add) || - isReassociableOp(Sub->getOperand(1), Instruction::Sub)) + Value *V1 = Sub->getOperand(1); + if (isReassociableOp(V1, Instruction::Add, Instruction::FAdd) || + isReassociableOp(V1, Instruction::Sub, Instruction::FSub)) return true; + Value *VB = Sub->user_back(); if (Sub->hasOneUse() && - (isReassociableOp(Sub->user_back(), Instruction::Add) || - isReassociableOp(Sub->user_back(), Instruction::Sub))) + (isReassociableOp(VB, Instruction::Add, Instruction::FAdd) || + isReassociableOp(VB, Instruction::Sub, Instruction::FSub))) return true; return false; @@ -931,8 +1025,7 @@ static BinaryOperator *BreakUpSubtract(Instruction *Sub) { // and set it as the RHS of the add instruction we just made. // Value *NegVal = NegateValue(Sub->getOperand(1), Sub); - BinaryOperator *New = - BinaryOperator::CreateAdd(Sub->getOperand(0), NegVal, "", Sub); + BinaryOperator *New = CreateAdd(Sub->getOperand(0), NegVal, "", Sub, Sub); Sub->setOperand(0, Constant::getNullValue(Sub->getType())); // Drop use of op. Sub->setOperand(1, Constant::getNullValue(Sub->getType())); // Drop use of op. New->takeName(Sub); @@ -956,8 +1049,19 @@ static BinaryOperator *ConvertShiftToMul(Instruction *Shl) { BinaryOperator::CreateMul(Shl->getOperand(0), MulCst, "", Shl); Shl->setOperand(0, UndefValue::get(Shl->getType())); // Drop use of op. Mul->takeName(Shl); + + // Everyone now refers to the mul instruction. Shl->replaceAllUsesWith(Mul); Mul->setDebugLoc(Shl->getDebugLoc()); + + // We can safely preserve the nuw flag in all cases. It's also safe to turn a + // nuw nsw shl into a nuw nsw mul. However, nsw in isolation requires special + // handling. 
+ bool NSW = cast<BinaryOperator>(Shl)->hasNoSignedWrap(); + bool NUW = cast<BinaryOperator>(Shl)->hasNoUnsignedWrap(); + if (NSW && NUW) + Mul->setHasNoSignedWrap(true); + Mul->setHasNoUnsignedWrap(NUW); return Mul; } @@ -969,13 +1073,23 @@ static unsigned FindInOperandList(SmallVectorImpl<ValueEntry> &Ops, unsigned i, Value *X) { unsigned XRank = Ops[i].Rank; unsigned e = Ops.size(); - for (unsigned j = i+1; j != e && Ops[j].Rank == XRank; ++j) + for (unsigned j = i+1; j != e && Ops[j].Rank == XRank; ++j) { if (Ops[j].Op == X) return j; + if (Instruction *I1 = dyn_cast<Instruction>(Ops[j].Op)) + if (Instruction *I2 = dyn_cast<Instruction>(X)) + if (I1->isIdenticalTo(I2)) + return j; + } // Scan backwards. - for (unsigned j = i-1; j != ~0U && Ops[j].Rank == XRank; --j) + for (unsigned j = i-1; j != ~0U && Ops[j].Rank == XRank; --j) { if (Ops[j].Op == X) return j; + if (Instruction *I1 = dyn_cast<Instruction>(Ops[j].Op)) + if (Instruction *I2 = dyn_cast<Instruction>(X)) + if (I1->isIdenticalTo(I2)) + return j; + } return i; } @@ -988,15 +1102,16 @@ static Value *EmitAddTreeOfValues(Instruction *I, Value *V1 = Ops.back(); Ops.pop_back(); Value *V2 = EmitAddTreeOfValues(I, Ops); - return BinaryOperator::CreateAdd(V2, V1, "tmp", I); + return CreateAdd(V2, V1, "tmp", I, I); } /// RemoveFactorFromExpression - If V is an expression tree that is a /// multiplication sequence, and if this sequence contains a multiply by Factor, /// remove Factor from the tree and return the new tree. Value *Reassociate::RemoveFactorFromExpression(Value *V, Value *Factor) { - BinaryOperator *BO = isReassociableOp(V, Instruction::Mul); - if (!BO) return nullptr; + BinaryOperator *BO = isReassociableOp(V, Instruction::Mul, Instruction::FMul); + if (!BO) + return nullptr; SmallVector<RepeatedValue, 8> Tree; MadeChange |= LinearizeExprTree(BO, Tree); @@ -1018,13 +1133,25 @@ Value *Reassociate::RemoveFactorFromExpression(Value *V, Value *Factor) { } // If this is a negative version of this factor, remove it. 
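(Illustrative aside, not part of the commit.) The wrap-flag rule above, applied to ConvertShiftToMul:

    shl nuw nsw i32 %x, 2   ->   mul nuw nsw i32 %x, 4   ; both flags kept
    shl nuw i32 %x, 2       ->   mul nuw i32 %x, 4       ; nuw alone is always safe
    shl nsw i32 %x, 2       ->   mul i32 %x, 4           ; nsw alone is dropped

nsw survives only in combination with nuw, per the comment's caveat that nsw in isolation requires special handling.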
- if (ConstantInt *FC1 = dyn_cast<ConstantInt>(Factor)) + if (ConstantInt *FC1 = dyn_cast<ConstantInt>(Factor)) { if (ConstantInt *FC2 = dyn_cast<ConstantInt>(Factors[i].Op)) if (FC1->getValue() == -FC2->getValue()) { FoundFactor = NeedsNegate = true; Factors.erase(Factors.begin()+i); break; } + } else if (ConstantFP *FC1 = dyn_cast<ConstantFP>(Factor)) { + if (ConstantFP *FC2 = dyn_cast<ConstantFP>(Factors[i].Op)) { + APFloat F1(FC1->getValueAPF()); + APFloat F2(FC2->getValueAPF()); + F2.changeSign(); + if (F1.compare(F2) == APFloat::cmpEqual) { + FoundFactor = NeedsNegate = true; + Factors.erase(Factors.begin() + i); + break; + } + } + } } if (!FoundFactor) { @@ -1046,7 +1173,7 @@ Value *Reassociate::RemoveFactorFromExpression(Value *V, Value *Factor) { } if (NeedsNegate) - V = BinaryOperator::CreateNeg(V, "neg", InsertPt); + V = CreateNeg(V, "neg", InsertPt, BO); return V; } @@ -1058,7 +1185,7 @@ Value *Reassociate::RemoveFactorFromExpression(Value *V, Value *Factor) { static void FindSingleUseMultiplyFactors(Value *V, SmallVectorImpl<Value*> &Factors, const SmallVectorImpl<ValueEntry> &Ops) { - BinaryOperator *BO = isReassociableOp(V, Instruction::Mul); + BinaryOperator *BO = isReassociableOp(V, Instruction::Mul, Instruction::FMul); if (!BO) { Factors.push_back(V); return; @@ -1385,17 +1512,19 @@ Value *Reassociate::OptimizeAdd(Instruction *I, ++NumFound; } while (i != Ops.size() && Ops[i].Op == TheOp); - DEBUG(errs() << "\nFACTORING [" << NumFound << "]: " << *TheOp << '\n'); + DEBUG(dbgs() << "\nFACTORING [" << NumFound << "]: " << *TheOp << '\n'); ++NumFactor; // Insert a new multiply. - Value *Mul = ConstantInt::get(cast<IntegerType>(I->getType()), NumFound); - Mul = BinaryOperator::CreateMul(TheOp, Mul, "factor", I); + Type *Ty = TheOp->getType(); + Constant *C = Ty->isIntegerTy() ? ConstantInt::get(Ty, NumFound) + : ConstantFP::get(Ty, NumFound); + Instruction *Mul = CreateMul(TheOp, C, "factor", I, I); // Now that we have inserted a multiply, optimize it. This allows us to // handle cases that require multiple factoring steps, such as this: // (X*2) + (X*2) + (X*2) -> (X*2)*3 -> X*6 - RedoInsts.insert(cast<Instruction>(Mul)); + RedoInsts.insert(Mul); // If every add operand was a duplicate, return the multiply. if (Ops.empty()) @@ -1412,11 +1541,12 @@ Value *Reassociate::OptimizeAdd(Instruction *I, } // Check for X and -X or X and ~X in the operand list. - if (!BinaryOperator::isNeg(TheOp) && !BinaryOperator::isNot(TheOp)) + if (!BinaryOperator::isNeg(TheOp) && !BinaryOperator::isFNeg(TheOp) && + !BinaryOperator::isNot(TheOp)) continue; Value *X = nullptr; - if (BinaryOperator::isNeg(TheOp)) + if (BinaryOperator::isNeg(TheOp) || BinaryOperator::isFNeg(TheOp)) X = BinaryOperator::getNegArgument(TheOp); else if (BinaryOperator::isNot(TheOp)) X = BinaryOperator::getNotArgument(TheOp); @@ -1426,7 +1556,8 @@ Value *Reassociate::OptimizeAdd(Instruction *I, continue; // Remove X and -X from the operand list. - if (Ops.size() == 2 && BinaryOperator::isNeg(TheOp)) + if (Ops.size() == 2 && + (BinaryOperator::isNeg(TheOp) || BinaryOperator::isFNeg(TheOp))) return Constant::getNullValue(X->getType()); // Remove X and ~X from the operand list. 
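(Illustrative aside, not part of the commit.) With the FP path in place, the repeated-operand factoring described in the surrounding comments now fires for fast-math float trees too. Sketch:

    %s = fadd fast float %x, %x        ; x + x + x
    %t = fadd fast float %s, %x
    ; OptimizeAdd collapses the duplicates into a single multiply:
    %t = fmul fast float %x, 3.000000e+00

The negated-factor match mirrors the integer FC1 == -FC2 test: it flips the sign of one APFloat and compares with APFloat::cmpEqual, avoiding any host floating-point arithmetic.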
@@ -1463,7 +1594,8 @@ Value *Reassociate::OptimizeAdd(Instruction *I, unsigned MaxOcc = 0; Value *MaxOccVal = nullptr; for (unsigned i = 0, e = Ops.size(); i != e; ++i) { - BinaryOperator *BOp = isReassociableOp(Ops[i].Op, Instruction::Mul); + BinaryOperator *BOp = + isReassociableOp(Ops[i].Op, Instruction::Mul, Instruction::FMul); if (!BOp) continue; @@ -1476,40 +1608,65 @@ Value *Reassociate::OptimizeAdd(Instruction *I, SmallPtrSet<Value*, 8> Duplicates; for (unsigned i = 0, e = Factors.size(); i != e; ++i) { Value *Factor = Factors[i]; - if (!Duplicates.insert(Factor)) continue; + if (!Duplicates.insert(Factor).second) + continue; unsigned Occ = ++FactorOccurrences[Factor]; - if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factor; } + if (Occ > MaxOcc) { + MaxOcc = Occ; + MaxOccVal = Factor; + } // If Factor is a negative constant, add the negated value as a factor // because we can percolate the negate out. Watch for minint, which // cannot be positivified. - if (ConstantInt *CI = dyn_cast<ConstantInt>(Factor)) + if (ConstantInt *CI = dyn_cast<ConstantInt>(Factor)) { if (CI->isNegative() && !CI->isMinValue(true)) { Factor = ConstantInt::get(CI->getContext(), -CI->getValue()); assert(!Duplicates.count(Factor) && "Shouldn't have two constant factors, missed a canonicalize"); - unsigned Occ = ++FactorOccurrences[Factor]; - if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factor; } + if (Occ > MaxOcc) { + MaxOcc = Occ; + MaxOccVal = Factor; + } + } + } else if (ConstantFP *CF = dyn_cast<ConstantFP>(Factor)) { + if (CF->isNegative()) { + APFloat F(CF->getValueAPF()); + F.changeSign(); + Factor = ConstantFP::get(CF->getContext(), F); + assert(!Duplicates.count(Factor) && + "Shouldn't have two constant factors, missed a canonicalize"); + unsigned Occ = ++FactorOccurrences[Factor]; + if (Occ > MaxOcc) { + MaxOcc = Occ; + MaxOccVal = Factor; + } } + } } } // If any factor occurred more than one time, we can pull it out. if (MaxOcc > 1) { - DEBUG(errs() << "\nFACTORING [" << MaxOcc << "]: " << *MaxOccVal << '\n'); + DEBUG(dbgs() << "\nFACTORING [" << MaxOcc << "]: " << *MaxOccVal << '\n'); ++NumFactor; // Create a new instruction that uses the MaxOccVal twice. If we don't do // this, we could otherwise run into situations where removing a factor // from an expression will drop a use of maxocc, and this can cause // RemoveFactorFromExpression on successive values to behave differently. - Instruction *DummyInst = BinaryOperator::CreateAdd(MaxOccVal, MaxOccVal); + Instruction *DummyInst = + I->getType()->isIntegerTy() + ? BinaryOperator::CreateAdd(MaxOccVal, MaxOccVal) + : BinaryOperator::CreateFAdd(MaxOccVal, MaxOccVal); + SmallVector<WeakVH, 4> NewMulOps; for (unsigned i = 0; i != Ops.size(); ++i) { // Only try to remove factors from expressions we're allowed to. - BinaryOperator *BOp = isReassociableOp(Ops[i].Op, Instruction::Mul); + BinaryOperator *BOp = + isReassociableOp(Ops[i].Op, Instruction::Mul, Instruction::FMul); if (!BOp) continue; @@ -1542,7 +1699,7 @@ Value *Reassociate::OptimizeAdd(Instruction *I, RedoInsts.insert(VI); // Create the multiply. - Instruction *V2 = BinaryOperator::CreateMul(V, MaxOccVal, "tmp", I); + Instruction *V2 = CreateMul(V, MaxOccVal, "tmp", I, I); // Rerun associate on the multiply in case the inner expression turned into // a multiply. We want to make sure that we keep things in canonical form. 
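(Illustrative aside, not part of the commit.) The recurring .insert(X).second edits in this patch track an ADT API change: SmallPtrSet::insert now returns std::pair<iterator, bool>, like std::set::insert, instead of a plain bool. Usage sketch:

    SmallPtrSet<Value *, 8> Duplicates;
    // Before: if (!Duplicates.insert(Factor)) continue;
    if (!Duplicates.insert(Factor).second)
      continue;   // .second is false when Factor was already present

The same pattern shows up in LinearizeExprTree's Visited set above and in SCCP's MarkBlockExecutable below.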
@@ -1632,7 +1789,10 @@ static Value *buildMultiplyTree(IRBuilder<> &Builder, Value *LHS = Ops.pop_back_val(); do { - LHS = Builder.CreateMul(LHS, Ops.pop_back_val()); + if (LHS->getType()->isIntegerTy()) + LHS = Builder.CreateMul(LHS, Ops.pop_back_val()); + else + LHS = Builder.CreateFMul(LHS, Ops.pop_back_val()); } while (!Ops.empty()); return LHS; @@ -1765,11 +1925,13 @@ Value *Reassociate::OptimizeExpression(BinaryOperator *I, break; case Instruction::Add: + case Instruction::FAdd: if (Value *Result = OptimizeAdd(I, Ops)) return Result; break; case Instruction::Mul: + case Instruction::FMul: if (Value *Result = OptimizeMul(I, Ops)) return Result; break; @@ -1797,12 +1959,104 @@ void Reassociate::EraseInst(Instruction *I) { // and add that since that's where optimization actually happens. unsigned Opcode = Op->getOpcode(); while (Op->hasOneUse() && Op->user_back()->getOpcode() == Opcode && - Visited.insert(Op)) + Visited.insert(Op).second) Op = Op->user_back(); RedoInsts.insert(Op); } } +// Canonicalize expressions of the following form: +// x + (-Constant * y) -> x - (Constant * y) +// x - (-Constant * y) -> x + (Constant * y) +Instruction *Reassociate::canonicalizeNegConstExpr(Instruction *I) { + if (!I->hasOneUse() || I->getType()->isVectorTy()) + return nullptr; + + // Must be a mul, fmul, or fdiv instruction. + unsigned Opcode = I->getOpcode(); + if (Opcode != Instruction::Mul && Opcode != Instruction::FMul && + Opcode != Instruction::FDiv) + return nullptr; + + // Must have at least one constant operand. + Constant *C0 = dyn_cast<Constant>(I->getOperand(0)); + Constant *C1 = dyn_cast<Constant>(I->getOperand(1)); + if (!C0 && !C1) + return nullptr; + + // Must be a negative ConstantInt or ConstantFP. + Constant *C = C0 ? C0 : C1; + unsigned ConstIdx = C0 ? 0 : 1; + if (auto *CI = dyn_cast<ConstantInt>(C)) { + if (!CI->isNegative()) + return nullptr; + } else if (auto *CF = dyn_cast<ConstantFP>(C)) { + if (!CF->isNegative()) + return nullptr; + } else + return nullptr; + + // User must be a binary operator with one or more uses. + Instruction *User = I->user_back(); + if (!isa<BinaryOperator>(User) || !User->getNumUses()) + return nullptr; + + unsigned UserOpcode = User->getOpcode(); + if (UserOpcode != Instruction::Add && UserOpcode != Instruction::FAdd && + UserOpcode != Instruction::Sub && UserOpcode != Instruction::FSub) + return nullptr; + + // Subtraction is not commutative. Explicitly, the following transform is + // not valid: (-Constant * y) - x -> x + (Constant * y) + if (!User->isCommutative() && User->getOperand(1) != I) + return nullptr; + + // Change the sign of the constant. + if (ConstantInt *CI = dyn_cast<ConstantInt>(C)) + I->setOperand(ConstIdx, ConstantInt::get(CI->getContext(), -CI->getValue())); + else { + ConstantFP *CF = cast<ConstantFP>(C); + APFloat Val = CF->getValueAPF(); + Val.changeSign(); + I->setOperand(ConstIdx, ConstantFP::get(CF->getContext(), Val)); + } + + // Canonicalize I to RHS to simplify the next bit of logic. E.g., + // ((-Const*y) + x) -> (x + (-Const*y)). 
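(Illustrative aside, not part of the commit.) canonicalizeNegConstExpr in action, on the x + (-Constant * y) form from its header comment:

    %m = fmul fast float %y, -3.000000e+00
    %a = fadd fast float %x, %m
    ; becomes
    %m = fmul fast float %y, 3.000000e+00
    %a = fsub fast float %x, %m

Adds of a negated-constant product become subs and vice versa (likewise the integer and FDiv variants), which keeps constants positive and exposes the expression to the rest of the pass.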
+ if (User->getOperand(0) == I && User->isCommutative()) + cast<BinaryOperator>(User)->swapOperands(); + + Value *Op0 = User->getOperand(0); + Value *Op1 = User->getOperand(1); + BinaryOperator *NI; + switch(UserOpcode) { + default: + llvm_unreachable("Unexpected Opcode!"); + case Instruction::Add: + NI = BinaryOperator::CreateSub(Op0, Op1); + break; + case Instruction::Sub: + NI = BinaryOperator::CreateAdd(Op0, Op1); + break; + case Instruction::FAdd: + NI = BinaryOperator::CreateFSub(Op0, Op1); + NI->setFastMathFlags(cast<FPMathOperator>(User)->getFastMathFlags()); + break; + case Instruction::FSub: + NI = BinaryOperator::CreateFAdd(Op0, Op1); + NI->setFastMathFlags(cast<FPMathOperator>(User)->getFastMathFlags()); + break; + } + + NI->insertBefore(User); + NI->setName(User->getName()); + User->replaceAllUsesWith(NI); + NI->setDebugLoc(I->getDebugLoc()); + RedoInsts.insert(I); + MadeChange = true; + return NI; +} + /// OptimizeInst - Inspect and optimize the given instruction. Note that erasing /// instructions is not allowed. void Reassociate::OptimizeInst(Instruction *I) { @@ -1810,8 +2064,7 @@ void Reassociate::OptimizeInst(Instruction *I) { if (!isa<BinaryOperator>(I)) return; - if (I->getOpcode() == Instruction::Shl && - isa<ConstantInt>(I->getOperand(1))) + if (I->getOpcode() == Instruction::Shl && isa<ConstantInt>(I->getOperand(1))) // If an operand of this shift is a reassociable multiply, or if the shift // is used by a reassociable multiply or add, turn into a multiply. if (isReassociableOp(I->getOperand(0), Instruction::Mul) || @@ -1824,29 +2077,23 @@ void Reassociate::OptimizeInst(Instruction *I) { I = NI; } - // Floating point binary operators are not associative, but we can still - // commute (some) of them, to canonicalize the order of their operands. - // This can potentially expose more CSE opportunities, and makes writing - // other transformations simpler. - if ((I->getType()->isFloatingPointTy() || I->getType()->isVectorTy())) { - // FAdd and FMul can be commuted. - if (I->getOpcode() != Instruction::FMul && - I->getOpcode() != Instruction::FAdd) - return; + // Canonicalize negative constants out of expressions. + if (Instruction *Res = canonicalizeNegConstExpr(I)) + I = Res; - Value *LHS = I->getOperand(0); - Value *RHS = I->getOperand(1); - unsigned LHSRank = getRank(LHS); - unsigned RHSRank = getRank(RHS); + // Commute binary operators, to canonicalize the order of their operands. + // This can potentially expose more CSE opportunities, and makes writing other + // transformations simpler. + if (I->isCommutative()) + canonicalizeOperands(I); - // Sort the operands by rank. - if (RHSRank < LHSRank) { - I->setOperand(0, RHS); - I->setOperand(1, LHS); - } + // Don't optimize vector instructions. + if (I->getType()->isVectorTy()) + return; + // Don't optimize floating point instructions that don't have unsafe algebra. + if (I->getType()->isFloatingPointTy() && !I->hasUnsafeAlgebra()) return; - } // Do not reassociate boolean (i1) expressions. We want to preserve the // original order of evaluation for short-circuited comparisons that @@ -1877,6 +2124,24 @@ void Reassociate::OptimizeInst(Instruction *I) { I = NI; } } + } else if (I->getOpcode() == Instruction::FSub) { + if (ShouldBreakUpSubtract(I)) { + Instruction *NI = BreakUpSubtract(I); + RedoInsts.insert(I); + MadeChange = true; + I = NI; + } else if (BinaryOperator::isFNeg(I)) { + // Otherwise, this is a negation. See if the operand is a multiply tree + // and if this is not an inner node of a multiply tree. 
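(Illustrative aside, not part of the commit.) The FSub handling mirrors the integer path: a fast-math negation whose operand is a multiply tree is turned into a multiply by -1.0 so the constant can migrate through the tree. isFNeg matches the fsub-from-negative-zero idiom:

    %t   = fmul fast float %a, %b
    %neg = fsub fast float -0.000000e+00, %t   ; recognized by BinaryOperator::isFNeg
    ; LowerNegateToMultiply rewrites the negation as
    %neg = fmul fast float %t, -1.000000e+00

after which %neg simply extends the fmul expression tree instead of capping it.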
+ if (isReassociableOp(I->getOperand(1), Instruction::FMul) && + (!I->hasOneUse() || + !isReassociableOp(I->user_back(), Instruction::FMul))) { + Instruction *NI = LowerNegateToMultiply(I); + RedoInsts.insert(I); + MadeChange = true; + I = NI; + } + } } // If this instruction is an associative binary operator, process it. @@ -1894,11 +2159,16 @@ void Reassociate::OptimizeInst(Instruction *I) { if (BO->hasOneUse() && BO->getOpcode() == Instruction::Add && cast<Instruction>(BO->user_back())->getOpcode() == Instruction::Sub) return; + if (BO->hasOneUse() && BO->getOpcode() == Instruction::FAdd && + cast<Instruction>(BO->user_back())->getOpcode() == Instruction::FSub) + return; ReassociateExpression(BO); } void Reassociate::ReassociateExpression(BinaryOperator *I) { + assert(!I->getType()->isVectorTy() && + "Reassociation of vector instructions is not supported."); // First, walk the expression tree, linearizing the tree, collecting the // operand information. @@ -1943,12 +2213,21 @@ void Reassociate::ReassociateExpression(BinaryOperator *I) { // this is a multiply tree used only by an add, and the immediate is a -1. // In this case we reassociate to put the negation on the outside so that we // can fold the negation into the add: (-X)*Y + Z -> Z-X*Y - if (I->getOpcode() == Instruction::Mul && I->hasOneUse() && - cast<Instruction>(I->user_back())->getOpcode() == Instruction::Add && - isa<ConstantInt>(Ops.back().Op) && - cast<ConstantInt>(Ops.back().Op)->isAllOnesValue()) { - ValueEntry Tmp = Ops.pop_back_val(); - Ops.insert(Ops.begin(), Tmp); + if (I->hasOneUse()) { + if (I->getOpcode() == Instruction::Mul && + cast<Instruction>(I->user_back())->getOpcode() == Instruction::Add && + isa<ConstantInt>(Ops.back().Op) && + cast<ConstantInt>(Ops.back().Op)->isAllOnesValue()) { + ValueEntry Tmp = Ops.pop_back_val(); + Ops.insert(Ops.begin(), Tmp); + } else if (I->getOpcode() == Instruction::FMul && + cast<Instruction>(I->user_back())->getOpcode() == + Instruction::FAdd && + isa<ConstantFP>(Ops.back().Op) && + cast<ConstantFP>(Ops.back().Op)->isExactlyValue(-1.0)) { + ValueEntry Tmp = Ops.pop_back_val(); + Ops.insert(Ops.begin(), Tmp); + } } DEBUG(dbgs() << "RAOut:\t"; PrintOps(I, Ops); dbgs() << '\n'); diff --git a/lib/Transforms/Scalar/Reg2Mem.cpp b/lib/Transforms/Scalar/Reg2Mem.cpp index b6023e2ce789..1b46727c17bb 100644 --- a/lib/Transforms/Scalar/Reg2Mem.cpp +++ b/lib/Transforms/Scalar/Reg2Mem.cpp @@ -73,7 +73,7 @@ bool RegToMem::runOnFunction(Function &F) { // Insert all new allocas into entry block. BasicBlock *BBEntry = &F.getEntryBlock(); - assert(pred_begin(BBEntry) == pred_end(BBEntry) && + assert(pred_empty(BBEntry) && "Entry block to function must not have predecessors!"); // Find first non-alloca instruction and create insertion point. This is diff --git a/lib/Transforms/Scalar/SCCP.cpp b/lib/Transforms/Scalar/SCCP.cpp index 90c3520c8323..cfc9a8e89fa0 100644 --- a/lib/Transforms/Scalar/SCCP.cpp +++ b/lib/Transforms/Scalar/SCCP.cpp @@ -214,7 +214,8 @@ public: /// /// This returns true if the block was not considered live before. bool MarkBlockExecutable(BasicBlock *BB) { - if (!BBExecutable.insert(BB)) return false; + if (!BBExecutable.insert(BB).second) + return false; DEBUG(dbgs() << "Marking Block Executable: " << BB->getName() << '\n'); BBWorkList.push_back(BB); // Add the block to the work list! 
return true; @@ -1010,7 +1011,7 @@ void SCCPSolver::visitGetElementPtrInst(GetElementPtrInst &I) { } Constant *Ptr = Operands[0]; - ArrayRef<Constant *> Indices(Operands.begin() + 1, Operands.end()); + auto Indices = makeArrayRef(Operands.begin() + 1, Operands.end()); markConstant(&I, ConstantExpr::getGetElementPtr(Ptr, Indices)); } @@ -1107,6 +1108,9 @@ CallOverdefined: Operands.push_back(State.getConstant()); } + if (getValueState(I).isOverdefined()) + return; + // If we can constant fold this, mark the result of the call as a // constant. if (Constant *C = ConstantFoldCall(F, Operands, TLI)) diff --git a/lib/Transforms/Scalar/SROA.cpp b/lib/Transforms/Scalar/SROA.cpp index f902eb23cbcf..ed161fd4af3e 100644 --- a/lib/Transforms/Scalar/SROA.cpp +++ b/lib/Transforms/Scalar/SROA.cpp @@ -28,6 +28,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/PtrUseVisitor.h" #include "llvm/Analysis/ValueTracking.h" @@ -78,8 +79,8 @@ STATISTIC(NumVectorized, "Number of vectorized aggregates"); /// Hidden option to force the pass to not use DomTree and mem2reg, instead /// forming SSA values through the SSAUpdater infrastructure. -static cl::opt<bool> -ForceSSAUpdater("force-ssa-updater", cl::init(false), cl::Hidden); +static cl::opt<bool> ForceSSAUpdater("force-ssa-updater", cl::init(false), + cl::Hidden); /// Hidden option to enable randomly shuffling the slices to help uncover /// instability in their order. @@ -88,15 +89,15 @@ static cl::opt<bool> SROARandomShuffleSlices("sroa-random-shuffle-slices", /// Hidden option to experiment with completely strict handling of inbounds /// GEPs. -static cl::opt<bool> SROAStrictInbounds("sroa-strict-inbounds", - cl::init(false), cl::Hidden); +static cl::opt<bool> SROAStrictInbounds("sroa-strict-inbounds", cl::init(false), + cl::Hidden); namespace { /// \brief A custom IRBuilder inserter which prefixes all names if they are /// preserved. template <bool preserveNames = true> -class IRBuilderPrefixedInserter : - public IRBuilderDefaultInserter<preserveNames> { +class IRBuilderPrefixedInserter + : public IRBuilderDefaultInserter<preserveNames> { std::string Prefix; public: @@ -112,19 +113,19 @@ protected: // Specialization for not preserving the name is trivial. template <> -class IRBuilderPrefixedInserter<false> : - public IRBuilderDefaultInserter<false> { +class IRBuilderPrefixedInserter<false> + : public IRBuilderDefaultInserter<false> { public: void SetNamePrefix(const Twine &P) {} }; /// \brief Provide a typedef for IRBuilder that drops names in release builds. #ifndef NDEBUG -typedef llvm::IRBuilder<true, ConstantFolder, - IRBuilderPrefixedInserter<true> > IRBuilderTy; +typedef llvm::IRBuilder<true, ConstantFolder, IRBuilderPrefixedInserter<true>> + IRBuilderTy; #else -typedef llvm::IRBuilder<false, ConstantFolder, - IRBuilderPrefixedInserter<false> > IRBuilderTy; +typedef llvm::IRBuilder<false, ConstantFolder, IRBuilderPrefixedInserter<false>> + IRBuilderTy; #endif } @@ -170,10 +171,14 @@ public: /// decreasing. Thus the spanning range comes first in a cluster with the /// same start position. 
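/// (Illustrative: with equal splittability the wider slice sorts first, so
/// [0,16) precedes [0,8); at the same start offset an unsplittable slice
/// sorts before a splittable one.)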
bool operator<(const Slice &RHS) const {
- if (beginOffset() < RHS.beginOffset()) return true;
- if (beginOffset() > RHS.beginOffset()) return false;
- if (isSplittable() != RHS.isSplittable()) return !isSplittable();
- if (endOffset() > RHS.endOffset()) return true;
+ if (beginOffset() < RHS.beginOffset())
+ return true;
+ if (beginOffset() > RHS.beginOffset())
+ return false;
+ if (isSplittable() != RHS.isSplittable())
+ return !isSplittable();
+ if (endOffset() > RHS.endOffset())
+ return true;
 return false;
 }
@@ -197,9 +202,7 @@ public:
 namespace llvm {
 template <typename T> struct isPodLike;
-template <> struct isPodLike<Slice> {
- static const bool value = true;
-};
+template <> struct isPodLike<Slice> { static const bool value = true; };
}
 namespace {
@@ -224,36 +227,318 @@ public:
 /// \brief Support for iterating over the slices.
 /// @{
 typedef SmallVectorImpl<Slice>::iterator iterator;
+ typedef iterator_range<iterator> range;
 iterator begin() { return Slices.begin(); }
 iterator end() { return Slices.end(); }
 typedef SmallVectorImpl<Slice>::const_iterator const_iterator;
+ typedef iterator_range<const_iterator> const_range;
 const_iterator begin() const { return Slices.begin(); }
 const_iterator end() const { return Slices.end(); }
 /// @}
- /// \brief Allow iterating the dead users for this alloca.
+ /// \brief Erase a range of slices.
+ void erase(iterator Start, iterator Stop) { Slices.erase(Start, Stop); }
+
+ /// \brief Insert new slices for this alloca.
 ///
- /// These are instructions which will never actually use the alloca as they
- /// are outside the allocated range. They are safe to replace with undef and
- /// delete.
- /// @{
- typedef SmallVectorImpl<Instruction *>::const_iterator dead_user_iterator;
- dead_user_iterator dead_user_begin() const { return DeadUsers.begin(); }
- dead_user_iterator dead_user_end() const { return DeadUsers.end(); }
- /// @}
+ /// This moves the slices into the alloca's slices collection, and re-sorts
+ /// everything so that the usual ordering properties of the alloca's slices
+ /// hold.
+ void insert(ArrayRef<Slice> NewSlices) {
+ int OldSize = Slices.size();
+ std::move(NewSlices.begin(), NewSlices.end(), std::back_inserter(Slices));
+ auto SliceI = Slices.begin() + OldSize;
+ std::sort(SliceI, Slices.end());
+ std::inplace_merge(Slices.begin(), SliceI, Slices.end());
+ }
+
+ // Forward declare an iterator to befriend it.
+ class partition_iterator;
+
+ /// \brief A partition of the slices.
+ ///
+ /// An ephemeral representation for a range of slices which can be viewed as
+ /// a partition of the alloca. This range represents a span of the alloca's
+ /// memory which cannot be split, and provides access to all of the slices
+ /// overlapping some part of the partition.
+ ///
+ /// Objects of this type are produced by traversing the alloca's slices, but
+ /// are only ephemeral and not persistent.
+ class Partition {
+ private:
+ friend class AllocaSlices;
+ friend class AllocaSlices::partition_iterator;
+
+ /// \brief The beginning and ending offsets of the alloca for this partition.
+ uint64_t BeginOffset, EndOffset;
+
+ /// \brief The start and end iterators of this partition.
+ iterator SI, SJ;
+
+ /// \brief A collection of split slice tails overlapping the partition.
+ SmallVector<Slice *, 4> SplitTails;
+
+ /// \brief Raw constructor builds an empty partition starting and ending at
+ /// the given iterator.
+ Partition(iterator SI) : SI(SI), SJ(SI) {}
+
+ public:
+ /// \brief The start offset of this partition.
+ ///
+ /// All of the contained slices start at or after this offset.
+ uint64_t beginOffset() const { return BeginOffset; }
- /// \brief Allow iterating the dead expressions referring to this alloca.
+ /// \brief The end offset of this partition.
+ ///
+ /// All of the contained slices end at or before this offset.
+ uint64_t endOffset() const { return EndOffset; }
+
+ /// \brief The size of the partition.
+ ///
+ /// Note that this can never be zero.
+ uint64_t size() const {
+ assert(BeginOffset < EndOffset && "Partitions must span some bytes!");
+ return EndOffset - BeginOffset;
+ }
+
+ /// \brief Test whether this partition contains no slices, and merely spans
+ /// a region occupied by split slices.
+ bool empty() const { return SI == SJ; }
+
+ /// \name Iterate slices that start within the partition.
+ /// These may be splittable or unsplittable. They have a begin offset >= the
+ /// partition begin offset.
+ /// @{
+ // FIXME: We should probably define a "concat_iterator" helper and use that
+ // to stitch together pointee_iterators over the split tails and the
+ // contiguous iterators of the partition. That would give a much nicer
+ // interface here. We could then additionally expose filtered iterators for
+ // split, unsplit, and unsplittable slices based on the usage patterns.
+ iterator begin() const { return SI; }
+ iterator end() const { return SJ; }
+ /// @}
+
+ /// \brief Get the sequence of split slice tails.
+ ///
+ /// These tails are of slices which start before this partition but are
+ /// split and overlap into the partition. We accumulate these while forming
+ /// partitions.
+ ArrayRef<Slice *> splitSliceTails() const { return SplitTails; }
+ };
+
+ /// \brief An iterator over partitions of the alloca's slices.
+ ///
+ /// This iterator implements the core algorithm for partitioning the alloca's
+ /// slices. It is a forward iterator as we don't support backtracking for
+ /// efficiency reasons, and re-use a single storage area to maintain the
+ /// current set of split slices.
+ ///
+ /// It is templated on the slice iterator type to use so that it can operate
+ /// with either const or non-const slice iterators.
+ class partition_iterator
+ : public iterator_facade_base<partition_iterator,
+ std::forward_iterator_tag, Partition> {
+ friend class AllocaSlices;
+
+ /// \brief Most of the state for walking the partitions is held in a class
+ /// with a nice interface for examining them.
+ Partition P;
+
+ /// \brief We need to keep the end of the slices to know when to stop.
+ AllocaSlices::iterator SE;
+
+ /// \brief We also need to keep track of the maximum split end offset seen.
+ /// FIXME: Do we really?
+ uint64_t MaxSplitSliceEndOffset;
+
+ /// \brief Sets the partition to be empty at the given iterator, and sets the
+ /// end iterator.
+ partition_iterator(AllocaSlices::iterator SI, AllocaSlices::iterator SE)
+ : P(SI), SE(SE), MaxSplitSliceEndOffset(0) {
+ // If not already at the end, advance our state to form the initial
+ // partition.
+ if (SI != SE)
+ advance();
+ }
+
+ /// \brief Advance the iterator to the next partition.
+ ///
+ /// Requires that the iterator not be at the end of the slices.
+ void advance() {
+ assert((P.SI != SE || !P.SplitTails.empty()) &&
+ "Cannot advance past the end of the slices!");
+
+ // Clear out any split uses which have ended.
+ if (!P.SplitTails.empty()) {
+ if (P.EndOffset >= MaxSplitSliceEndOffset) {
+ // If we've finished all splits, this is easy.
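+ // (Every recorded tail ends at or before MaxSplitSliceEndOffset, which
+ // the partition end has now reached, so all of them can be dropped.)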
+ P.SplitTails.clear();
+ MaxSplitSliceEndOffset = 0;
+ } else {
+ // Remove the uses which have ended in the prior partition. This
+ // cannot change the max split slice end because we just checked that
+ // the prior partition ended prior to that max.
+ P.SplitTails.erase(
+ std::remove_if(
+ P.SplitTails.begin(), P.SplitTails.end(),
+ [&](Slice *S) { return S->endOffset() <= P.EndOffset; }),
+ P.SplitTails.end());
+ assert(std::any_of(P.SplitTails.begin(), P.SplitTails.end(),
+ [&](Slice *S) {
+ return S->endOffset() == MaxSplitSliceEndOffset;
+ }) &&
+ "Could not find the current max split slice offset!");
+ assert(std::all_of(P.SplitTails.begin(), P.SplitTails.end(),
+ [&](Slice *S) {
+ return S->endOffset() <= MaxSplitSliceEndOffset;
+ }) &&
+ "Max split slice end offset is not actually the max!");
+ }
+ }
+
+ // If P.SI is already at the end, then we've cleared the split tail and
+ // now have an end iterator.
+ if (P.SI == SE) {
+ assert(P.SplitTails.empty() && "Failed to clear the split slices!");
+ return;
+ }
+
+ // If we had a non-empty partition previously, set up the state for
+ // subsequent partitions.
+ if (P.SI != P.SJ) {
+ // Accumulate all the splittable slices which started in the old
+ // partition into the split list.
+ for (Slice &S : P)
+ if (S.isSplittable() && S.endOffset() > P.EndOffset) {
+ P.SplitTails.push_back(&S);
+ MaxSplitSliceEndOffset =
+ std::max(S.endOffset(), MaxSplitSliceEndOffset);
+ }
+
+ // Start from the end of the previous partition.
+ P.SI = P.SJ;
+
+ // If P.SI is now at the end, we at most have a tail of split slices.
+ if (P.SI == SE) {
+ P.BeginOffset = P.EndOffset;
+ P.EndOffset = MaxSplitSliceEndOffset;
+ return;
+ }
+
+ // If we have split slices and the next slice is after a gap and is
+ // not splittable, immediately form an empty partition for the split
+ // slices up until the next slice begins.
+ if (!P.SplitTails.empty() && P.SI->beginOffset() != P.EndOffset &&
+ !P.SI->isSplittable()) {
+ P.BeginOffset = P.EndOffset;
+ P.EndOffset = P.SI->beginOffset();
+ return;
+ }
+ }
+
+ // OK, we need to consume new slices. Set the end offset based on the
+ // current slice, and step SJ past it. The beginning offset of the
+ // partition is the beginning offset of the next slice unless we have
+ // pre-existing split slices that are continuing, in which case we begin
+ // at the prior end offset.
+ P.BeginOffset = P.SplitTails.empty() ? P.SI->beginOffset() : P.EndOffset;
+ P.EndOffset = P.SI->endOffset();
+ ++P.SJ;
+
+ // There are two strategies to form a partition based on whether the
+ // partition starts with an unsplittable slice or a splittable slice.
+ if (!P.SI->isSplittable()) {
+ // When we're forming an unsplittable region, it must always start at
+ // the first slice and will extend through its end.
+ assert(P.BeginOffset == P.SI->beginOffset());
+
+ // Form a partition including all of the overlapping slices with this
+ // unsplittable slice.
+ while (P.SJ != SE && P.SJ->beginOffset() < P.EndOffset) {
+ if (!P.SJ->isSplittable())
+ P.EndOffset = std::max(P.EndOffset, P.SJ->endOffset());
+ ++P.SJ;
+ }
+
+ // We have a partition across a set of overlapping unsplittable
+ // slices.
+ return;
+ }
+
+ // If we're starting with a splittable slice, then we need to form
+ // a synthetic partition spanning it and any other overlapping splittable
+ // slices.
+ assert(P.SI->isSplittable() && "Forming a splittable partition!");
+
+ // Collect all of the overlapping splittable slices.
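+ // Illustrative example: splittable slices [0,8) and [4,12) overlap and
+ // so merge into one partition [0,12); an unsplittable slice [16,20)
+ // after the gap then forms its own [16,20) partition.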
+ while (P.SJ != SE && P.SJ->beginOffset() < P.EndOffset &&
+ P.SJ->isSplittable()) {
+ P.EndOffset = std::max(P.EndOffset, P.SJ->endOffset());
+ ++P.SJ;
+ }
+
+ // Back up P.EndOffset if we ended the span early when encountering an
+ // unsplittable slice. This synthesizes the early end offset of
+ // a partition spanning only splittable slices.
+ if (P.SJ != SE && P.SJ->beginOffset() < P.EndOffset) {
+ assert(!P.SJ->isSplittable());
+ P.EndOffset = P.SJ->beginOffset();
+ }
+ }
+
+ public:
+ bool operator==(const partition_iterator &RHS) const {
+ assert(SE == RHS.SE &&
+ "End iterators don't match between compared partition iterators!");
+
+ // The observed positions of partitions are marked by the P.SI iterator and
+ // the emptiness of the split slices. The latter is only relevant when
+ // P.SI == SE, as the end iterator will additionally have an empty split
+ // slices list, but the prior may have the same P.SI and a tail of split
+ // slices.
+ if (P.SI == RHS.P.SI &&
+ P.SplitTails.empty() == RHS.P.SplitTails.empty()) {
+ assert(P.SJ == RHS.P.SJ &&
+ "Same set of slices formed two different sized partitions!");
+ assert(P.SplitTails.size() == RHS.P.SplitTails.size() &&
+ "Same slice position with differently sized non-empty split "
+ "slice tails!");
+ return true;
+ }
+ return false;
+ }
+
+ partition_iterator &operator++() {
+ advance();
+ return *this;
+ }
+
+ Partition &operator*() { return P; }
+ };
+
+ /// \brief A forward range over the partitions of the alloca's slices.
+ ///
+ /// This accesses an iterator range over the partitions of the alloca's
+ /// slices. It computes these partitions on the fly based on the overlapping
+ /// offsets of the slices and the ability to split them. It will visit "empty"
+ /// partitions to cover regions of the alloca only accessed via split
+ /// slices.
+ iterator_range<partition_iterator> partitions() {
+ return make_range(partition_iterator(begin(), end()),
+ partition_iterator(end(), end()));
+ }
+
+ /// \brief Access the dead users for this alloca.
+ ArrayRef<Instruction *> getDeadUsers() const { return DeadUsers; }
+
+ /// \brief Access the dead operands referring to this alloca.
 ///
 /// These are operands which cannot actually be used to refer to the
 /// alloca as they are outside its range and the user doesn't correct for
 /// that. These mostly consist of PHI node inputs and the like which we just
 /// need to replace with undef.
- /// @{
- typedef SmallVectorImpl<Use *>::const_iterator dead_op_iterator;
- dead_op_iterator dead_op_begin() const { return DeadOperands.begin(); }
- dead_op_iterator dead_op_end() const { return DeadOperands.end(); }
- /// @}
+ ArrayRef<Use *> getDeadOperands() const { return DeadOperands; }
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void print(raw_ostream &OS, const_iterator I, StringRef Indent = " ") const;
@@ -317,13 +602,22 @@ static Value *foldSelectInst(SelectInst &SI) {
 // being selected between, fold the select. Yes this does (rarely) happen
 // early on.
 if (ConstantInt *CI = dyn_cast<ConstantInt>(SI.getCondition()))
- return SI.getOperand(1+CI->isZero());
+ return SI.getOperand(1 + CI->isZero());
 if (SI.getOperand(1) == SI.getOperand(2))
 return SI.getOperand(1);
 return nullptr;
}
+/// \brief A helper that folds a PHI node or a select.
+static Value *foldPHINodeOrSelectInst(Instruction &I) {
+ if (PHINode *PN = dyn_cast<PHINode>(&I)) {
+ // If PN merges together the same value, return that value.
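+ // (hasConstantValue returns the single value the PHI always merges
+ // together, or null when the incoming values differ.)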
+ return PN->hasConstantValue(); + } + return foldSelectInst(cast<SelectInst>(I)); +} + /// \brief Builder for the alloca slices. /// /// This class builds a set of alloca slices by recursively visiting the uses @@ -334,7 +628,7 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> { typedef PtrUseVisitor<SliceBuilder> Base; const uint64_t AllocSize; - AllocaSlices &S; + AllocaSlices &AS; SmallDenseMap<Instruction *, unsigned> MemTransferSliceMap; SmallDenseMap<Instruction *, uint64_t> PHIOrSelectSizes; @@ -343,14 +637,14 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> { SmallPtrSet<Instruction *, 4> VisitedDeadInsts; public: - SliceBuilder(const DataLayout &DL, AllocaInst &AI, AllocaSlices &S) + SliceBuilder(const DataLayout &DL, AllocaInst &AI, AllocaSlices &AS) : PtrUseVisitor<SliceBuilder>(DL), - AllocSize(DL.getTypeAllocSize(AI.getAllocatedType())), S(S) {} + AllocSize(DL.getTypeAllocSize(AI.getAllocatedType())), AS(AS) {} private: void markAsDead(Instruction &I) { - if (VisitedDeadInsts.insert(&I)) - S.DeadUsers.push_back(&I); + if (VisitedDeadInsts.insert(&I).second) + AS.DeadUsers.push_back(&I); } void insertUse(Instruction &I, const APInt &Offset, uint64_t Size, @@ -361,7 +655,7 @@ private: DEBUG(dbgs() << "WARNING: Ignoring " << Size << " byte use @" << Offset << " which has zero size or starts outside of the " << AllocSize << " byte alloca:\n" - << " alloca: " << S.AI << "\n" + << " alloca: " << AS.AI << "\n" << " use: " << I << "\n"); return markAsDead(I); } @@ -379,12 +673,12 @@ private: if (Size > AllocSize - BeginOffset) { DEBUG(dbgs() << "WARNING: Clamping a " << Size << " byte use @" << Offset << " to remain within the " << AllocSize << " byte alloca:\n" - << " alloca: " << S.AI << "\n" + << " alloca: " << AS.AI << "\n" << " use: " << I << "\n"); EndOffset = AllocSize; } - S.Slices.push_back(Slice(BeginOffset, EndOffset, U, IsSplittable)); + AS.Slices.push_back(Slice(BeginOffset, EndOffset, U, IsSplittable)); } void visitBitCastInst(BitCastInst &BC) { @@ -421,7 +715,8 @@ private: GEPOffset += APInt(Offset.getBitWidth(), SL->getElementOffset(ElementIdx)); } else { - // For array or vector indices, scale the index by the size of the type. + // For array or vector indices, scale the index by the size of the + // type. APInt Index = OpC->getValue().sextOrTrunc(Offset.getBitWidth()); GEPOffset += Index * APInt(Offset.getBitWidth(), DL.getTypeAllocSize(GTI.getIndexedType())); @@ -440,16 +735,10 @@ private: void handleLoadOrStore(Type *Ty, Instruction &I, const APInt &Offset, uint64_t Size, bool IsVolatile) { - // We allow splitting of loads and stores where the type is an integer type - // and cover the entire alloca. This prevents us from splitting over - // eagerly. - // FIXME: In the great blue eventually, we should eagerly split all integer - // loads and stores, and then have a separate step that merges adjacent - // alloca partitions into a single partition suitable for integer widening. - // Or we should skip the merge step and rely on GVN and other passes to - // merge adjacent loads and stores that survive mem2reg. - bool IsSplittable = - Ty->isIntegerTy() && !IsVolatile && Offset == 0 && Size >= AllocSize; + // We allow splitting of non-volatile loads and stores where the type is an + // integer type. These may be used to implement 'memcpy' or other "transfer + // of bits" patterns. 
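+ // For example, a frontend may lower a small struct copy to an i64 load
+ // feeding an i64 store; treating those as splittable lets SROA carve
+ // them up along partition boundaries.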
+ bool IsSplittable = Ty->isIntegerTy() && !IsVolatile; insertUse(I, Offset, Size, IsSplittable); } @@ -485,7 +774,7 @@ private: DEBUG(dbgs() << "WARNING: Ignoring " << Size << " byte store @" << Offset << " which extends past the end of the " << AllocSize << " byte alloca:\n" - << " alloca: " << S.AI << "\n" + << " alloca: " << AS.AI << "\n" << " use: " << SI << "\n"); return markAsDead(SI); } @@ -495,7 +784,6 @@ private: handleLoadOrStore(ValOp->getType(), SI, Offset, Size, SI.isVolatile()); } - void visitMemSetInst(MemSetInst &II) { assert(II.getRawDest() == *U && "Pointer use is not the destination?"); ConstantInt *Length = dyn_cast<ConstantInt>(II.getLength()); @@ -507,9 +795,8 @@ private: if (!IsOffsetKnown) return PI.setAborted(&II); - insertUse(II, Offset, - Length ? Length->getLimitedValue() - : AllocSize - Offset.getLimitedValue(), + insertUse(II, Offset, Length ? Length->getLimitedValue() + : AllocSize - Offset.getLimitedValue(), (bool)Length); } @@ -533,15 +820,15 @@ private: // FIXME: Yet another place we really should bypass this when // instrumenting for ASan. if (Offset.uge(AllocSize)) { - SmallDenseMap<Instruction *, unsigned>::iterator MTPI = MemTransferSliceMap.find(&II); + SmallDenseMap<Instruction *, unsigned>::iterator MTPI = + MemTransferSliceMap.find(&II); if (MTPI != MemTransferSliceMap.end()) - S.Slices[MTPI->second].kill(); + AS.Slices[MTPI->second].kill(); return markAsDead(II); } uint64_t RawOffset = Offset.getLimitedValue(); - uint64_t Size = Length ? Length->getLimitedValue() - : AllocSize - RawOffset; + uint64_t Size = Length ? Length->getLimitedValue() : AllocSize - RawOffset; // Check for the special case where the same exact value is used for both // source and dest. @@ -558,10 +845,10 @@ private: bool Inserted; SmallDenseMap<Instruction *, unsigned>::iterator MTPI; std::tie(MTPI, Inserted) = - MemTransferSliceMap.insert(std::make_pair(&II, S.Slices.size())); + MemTransferSliceMap.insert(std::make_pair(&II, AS.Slices.size())); unsigned PrevIdx = MTPI->second; if (!Inserted) { - Slice &PrevP = S.Slices[PrevIdx]; + Slice &PrevP = AS.Slices[PrevIdx]; // Check if the begin offsets match and this is a non-volatile transfer. // In that case, we can completely elide the transfer. @@ -579,7 +866,7 @@ private: insertUse(II, Offset, Size, /*IsSplittable=*/Inserted && Length); // Check that we ended up with a valid index in the map. - assert(S.Slices[PrevIdx].getUse()->getUser() == &II && + assert(AS.Slices[PrevIdx].getUse()->getUser() == &II && "Map index doesn't point back to a slice with this user."); } @@ -639,64 +926,47 @@ private: } for (User *U : I->users()) - if (Visited.insert(cast<Instruction>(U))) + if (Visited.insert(cast<Instruction>(U)).second) Uses.push_back(std::make_pair(I, cast<Instruction>(U))); } while (!Uses.empty()); return nullptr; } - void visitPHINode(PHINode &PN) { - if (PN.use_empty()) - return markAsDead(PN); - if (!IsOffsetKnown) - return PI.setAborted(&PN); - - // See if we already have computed info on this node. - uint64_t &PHISize = PHIOrSelectSizes[&PN]; - if (!PHISize) { - // This is a new PHI node, check for an unsafe use of the PHI node. - if (Instruction *UnsafeI = hasUnsafePHIOrSelectUse(&PN, PHISize)) - return PI.setAborted(UnsafeI); - } - - // For PHI and select operands outside the alloca, we can't nuke the entire - // phi or select -- the other side might still be relevant, so we special - // case them here and use a separate structure to track the operands - // themselves which should be replaced with undef. 
- // FIXME: This should instead be escaped in the event we're instrumenting
- // for address sanitization.
- if (Offset.uge(AllocSize)) {
- S.DeadOperands.push_back(U);
- return;
- }
-
- insertUse(PN, Offset, PHISize);
- }
+ void visitPHINodeOrSelectInst(Instruction &I) {
+ assert(isa<PHINode>(I) || isa<SelectInst>(I));
+ if (I.use_empty())
+ return markAsDead(I);
- void visitSelectInst(SelectInst &SI) {
- if (SI.use_empty())
- return markAsDead(SI);
- if (Value *Result = foldSelectInst(SI)) {
+ // TODO: We could use SimplifyInstruction here to fold PHINodes and
+ // SelectInsts. However, doing so requires changing the current
+ // dead-operand-tracking mechanism. For instance, suppose neither loading
+ // from %U nor %other traps. Then "load (select undef, %U, %other)" does not
+ // trap either. However, if we simply replace %U with undef using the
+ // current dead-operand-tracking mechanism, "load (select undef, undef,
+ // %other)" may trap because the select may return the first operand
+ // "undef".
+ if (Value *Result = foldPHINodeOrSelectInst(I)) {
 if (Result == *U)
 // If the result of the constant fold will be the pointer, recurse
- // through the select as if we had RAUW'ed it.
- enqueueUsers(SI);
+ // through the PHI/select as if we had RAUW'ed it.
+ enqueueUsers(I);
 else
- // Otherwise the operand to the select is dead, and we can replace it
- // with undef.
- S.DeadOperands.push_back(U);
+ // Otherwise the operand to the PHI/select is dead, and we can replace
+ // it with undef.
+ AS.DeadOperands.push_back(U);
 return;
 }
+
 if (!IsOffsetKnown)
- return PI.setAborted(&SI);
+ return PI.setAborted(&I);
 // See if we already have computed info on this node.
- uint64_t &SelectSize = PHIOrSelectSizes[&SI];
- if (!SelectSize) {
- // This is a new Select, check for an unsafe use of it.
- if (Instruction *UnsafeI = hasUnsafePHIOrSelectUse(&SI, SelectSize))
+ uint64_t &Size = PHIOrSelectSizes[&I];
+ if (!Size) {
+ // This is a new PHI/Select, check for an unsafe use of it.
+ if (Instruction *UnsafeI = hasUnsafePHIOrSelectUse(&I, Size))
 return PI.setAborted(UnsafeI);
 }
@@ -707,17 +977,19 @@ private:
 // FIXME: This should instead be escaped in the event we're instrumenting
 // for address sanitization.
 if (Offset.uge(AllocSize)) {
- S.DeadOperands.push_back(U);
+ AS.DeadOperands.push_back(U);
 return;
 }
- insertUse(SI, Offset, SelectSize);
+ insertUse(I, Offset, Size);
 }
+ void visitPHINode(PHINode &PN) { visitPHINodeOrSelectInst(PN); }
+
+ void visitSelectInst(SelectInst &SI) { visitPHINodeOrSelectInst(SI); }
+
 /// \brief Disable SROA entirely if there are unhandled users of the alloca.
- void visitInstruction(Instruction &I) { - PI.setAborted(&I); - } + void visitInstruction(Instruction &I) { PI.setAborted(&I); } }; AllocaSlices::AllocaSlices(const DataLayout &DL, AllocaInst &AI) @@ -738,7 +1010,9 @@ AllocaSlices::AllocaSlices(const DataLayout &DL, AllocaInst &AI) } Slices.erase(std::remove_if(Slices.begin(), Slices.end(), - std::mem_fun_ref(&Slice::isDead)), + [](const Slice &S) { + return S.isDead(); + }), Slices.end()); #if __cplusplus >= 201103L && !defined(NDEBUG) @@ -758,6 +1032,7 @@ AllocaSlices::AllocaSlices(const DataLayout &DL, AllocaInst &AI) void AllocaSlices::print(raw_ostream &OS, const_iterator I, StringRef Indent) const { printSlice(OS, I, Indent); + OS << "\n"; printUse(OS, I, Indent); } @@ -765,7 +1040,7 @@ void AllocaSlices::printSlice(raw_ostream &OS, const_iterator I, StringRef Indent) const { OS << Indent << "[" << I->beginOffset() << "," << I->endOffset() << ")" << " slice #" << (I - begin()) - << (I->isSplittable() ? " (splittable)" : "") << "\n"; + << (I->isSplittable() ? " (splittable)" : ""); } void AllocaSlices::printUse(raw_ostream &OS, const_iterator I, @@ -813,15 +1088,17 @@ public: AllocaInst &AI, DIBuilder &DIB) : LoadAndStorePromoter(Insts, S), AI(AI), DIB(DIB) {} - void run(const SmallVectorImpl<Instruction*> &Insts) { + void run(const SmallVectorImpl<Instruction *> &Insts) { // Retain the debug information attached to the alloca for use when // rewriting loads and stores. - if (MDNode *DebugNode = MDNode::getIfExists(AI.getContext(), &AI)) { - for (User *U : DebugNode->users()) - if (DbgDeclareInst *DDI = dyn_cast<DbgDeclareInst>(U)) - DDIs.push_back(DDI); - else if (DbgValueInst *DVI = dyn_cast<DbgValueInst>(U)) - DVIs.push_back(DVI); + if (auto *L = LocalAsMetadata::getIfExists(&AI)) { + if (auto *DebugNode = MetadataAsValue::getIfExists(AI.getContext(), L)) { + for (User *U : DebugNode->users()) + if (DbgDeclareInst *DDI = dyn_cast<DbgDeclareInst>(U)) + DDIs.push_back(DDI); + else if (DbgValueInst *DVI = dyn_cast<DbgValueInst>(U)) + DVIs.push_back(DVI); + } } LoadAndStorePromoter::run(Insts); @@ -834,8 +1111,9 @@ public: DVIs.pop_back_val()->eraseFromParent(); } - bool isInstInList(Instruction *I, - const SmallVectorImpl<Instruction*> &Insts) const override { + bool + isInstInList(Instruction *I, + const SmallVectorImpl<Instruction *> &Insts) const override { Value *Ptr; if (LoadInst *LI = dyn_cast<LoadInst>(I)) Ptr = LI->getOperand(0); @@ -857,23 +1135,18 @@ public: else return false; - } while (Visited.insert(Ptr)); + } while (Visited.insert(Ptr).second); return false; } void updateDebugInfo(Instruction *Inst) const override { - for (SmallVectorImpl<DbgDeclareInst *>::const_iterator I = DDIs.begin(), - E = DDIs.end(); I != E; ++I) { - DbgDeclareInst *DDI = *I; + for (DbgDeclareInst *DDI : DDIs) if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) ConvertDebugDeclareToDebugValue(DDI, SI, DIB); else if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) ConvertDebugDeclareToDebugValue(DDI, LI, DIB); - } - for (SmallVectorImpl<DbgValueInst *>::const_iterator I = DVIs.begin(), - E = DVIs.end(); I != E; ++I) { - DbgValueInst *DVI = *I; + for (DbgValueInst *DVI : DVIs) { Value *Arg = nullptr; if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { // If an argument is zero extended then use argument directly. 
The ZExt
@@ -890,15 +1163,14 @@ public:
 continue;
 }
 Instruction *DbgVal =
- DIB.insertDbgValueIntrinsic(Arg, 0, DIVariable(DVI->getVariable()),
- Inst);
+ DIB.insertDbgValueIntrinsic(Arg, 0, DIVariable(DVI->getVariable()),
+ DIExpression(DVI->getExpression()), Inst);
 DbgVal->setDebugLoc(DVI->getDebugLoc());
 }
 }
};
} // end anon namespace
-
 namespace {
 /// \brief An optimization pass providing Scalar Replacement of Aggregates.
 ///
@@ -924,6 +1196,7 @@ class SROA : public FunctionPass {
 LLVMContext *C;
 const DataLayout *DL;
 DominatorTree *DT;
+ AssumptionCache *AC;
 /// \brief Worklist of alloca instructions to simplify.
 ///
@@ -932,12 +1205,12 @@ class SROA : public FunctionPass {
 /// directly promoted. Finally, each time we rewrite a use of an alloca other
 /// than the one being actively rewritten, we add it back onto the list if not
 /// already present to ensure it is re-visited.
- SetVector<AllocaInst *, SmallVector<AllocaInst *, 16> > Worklist;
+ SetVector<AllocaInst *, SmallVector<AllocaInst *, 16>> Worklist;
 /// \brief A collection of instructions to delete.
 /// We try to batch deletions to simplify code and make things a bit more
 /// efficient.
- SetVector<Instruction *, SmallVector<Instruction *, 8> > DeadInsts;
+ SetVector<Instruction *, SmallVector<Instruction *, 8>> DeadInsts;
 /// \brief Post-promotion worklist.
 ///
@@ -947,7 +1220,7 @@ class SROA : public FunctionPass {
 ///
 /// Note that we have to be very careful to clear allocas out of this list in
 /// the event they are deleted.
- SetVector<AllocaInst *, SmallVector<AllocaInst *, 16> > PostPromotionWorklist;
+ SetVector<AllocaInst *, SmallVector<AllocaInst *, 16>> PostPromotionWorklist;
 /// \brief A collection of alloca instructions we can directly promote.
 std::vector<AllocaInst *> PromotableAllocas;
@@ -957,7 +1230,7 @@ class SROA : public FunctionPass {
 /// All of these PHIs have been checked for the safety of speculation and by
 /// being speculated will allow promoting allocas currently in the promotable
 /// queue.
- SetVector<PHINode *, SmallVector<PHINode *, 2> > SpeculatablePHIs;
+ SetVector<PHINode *, SmallVector<PHINode *, 2>> SpeculatablePHIs;
 /// \brief A worklist of select instructions to speculate prior to promoting
 /// allocas.
 ///
 /// All of these select instructions have been checked for the safety of
 /// speculation and by being speculated will allow promoting allocas
 /// currently in the promotable queue.
- SetVector<SelectInst *, SmallVector<SelectInst *, 2> > SpeculatableSelects; + SetVector<SelectInst *, SmallVector<SelectInst *, 2>> SpeculatableSelects; public: SROA(bool RequiresDomTree = true) - : FunctionPass(ID), RequiresDomTree(RequiresDomTree), - C(nullptr), DL(nullptr), DT(nullptr) { + : FunctionPass(ID), RequiresDomTree(RequiresDomTree), C(nullptr), + DL(nullptr), DT(nullptr) { initializeSROAPass(*PassRegistry::getPassRegistry()); } bool runOnFunction(Function &F) override; @@ -983,14 +1256,13 @@ private: friend class PHIOrSelectSpeculator; friend class AllocaSliceRewriter; - bool rewritePartition(AllocaInst &AI, AllocaSlices &S, - AllocaSlices::iterator B, AllocaSlices::iterator E, - int64_t BeginOffset, int64_t EndOffset, - ArrayRef<AllocaSlices::iterator> SplitUses); - bool splitAlloca(AllocaInst &AI, AllocaSlices &S); + bool presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS); + bool rewritePartition(AllocaInst &AI, AllocaSlices &AS, + AllocaSlices::Partition &P); + bool splitAlloca(AllocaInst &AI, AllocaSlices &AS); bool runOnAlloca(AllocaInst &AI); void clobberUse(Use &U); - void deleteDeadInstructions(SmallPtrSet<AllocaInst *, 4> &DeletedAllocas); + void deleteDeadInstructions(SmallPtrSetImpl<AllocaInst *> &DeletedAllocas); bool promoteAllocas(Function &F); }; } @@ -1001,11 +1273,12 @@ FunctionPass *llvm::createSROAPass(bool RequiresDomTree) { return new SROA(RequiresDomTree); } -INITIALIZE_PASS_BEGIN(SROA, "sroa", "Scalar Replacement Of Aggregates", - false, false) +INITIALIZE_PASS_BEGIN(SROA, "sroa", "Scalar Replacement Of Aggregates", false, + false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_END(SROA, "sroa", "Scalar Replacement Of Aggregates", - false, false) +INITIALIZE_PASS_END(SROA, "sroa", "Scalar Replacement Of Aggregates", false, + false) /// Walk the range of a partitioning looking for a common type to cover this /// sequence of slices. @@ -1076,8 +1349,7 @@ static Type *findCommonType(AllocaSlices::const_iterator B, /// /// FIXME: This should be hoisted into a generic utility, likely in /// Transforms/Util/Local.h -static bool isSafePHIToSpeculate(PHINode &PN, - const DataLayout *DL = nullptr) { +static bool isSafePHIToSpeculate(PHINode &PN, const DataLayout *DL = nullptr) { // For now, we can only do this promotion if the load is in the same block // as the PHI, and if there are no stores between the phi and load. // TODO: Allow recursive phi users. @@ -1148,10 +1420,12 @@ static void speculatePHINodeLoads(PHINode &PN) { PHINode *NewPN = PHIBuilder.CreatePHI(LoadTy, PN.getNumIncomingValues(), PN.getName() + ".sroa.speculated"); - // Get the TBAA tag and alignment to use from one of the loads. It doesn't + // Get the AA tags and alignment to use from one of the loads. It doesn't // matter which one we get and if any differ. LoadInst *SomeLoad = cast<LoadInst>(PN.user_back()); - MDNode *TBAATag = SomeLoad->getMetadata(LLVMContext::MD_tbaa); + + AAMDNodes AATags; + SomeLoad->getAAMetadata(AATags); unsigned Align = SomeLoad->getAlignment(); // Rewrite all loads of the PN to use the new PHI. @@ -1172,8 +1446,8 @@ static void speculatePHINodeLoads(PHINode &PN) { InVal, (PN.getName() + ".sroa.speculate.load." 
+ Pred->getName()));
 ++NumLoadsSpeculated;
 Load->setAlignment(Align);
- if (TBAATag)
- Load->setMetadata(LLVMContext::MD_tbaa, TBAATag);
+ if (AATags)
+ Load->setAAMetadata(AATags);
 NewPN->addIncoming(Load, Pred);
 }
@@ -1238,12 +1512,15 @@ static void speculateSelectInstLoads(SelectInst &SI) {
 IRB.CreateLoad(FV, LI->getName() + ".sroa.speculate.load.false");
 NumLoadsSpeculated += 2;
- // Transfer alignment and TBAA info if present.
+ // Transfer alignment and AA info if present.
 TL->setAlignment(LI->getAlignment());
 FL->setAlignment(LI->getAlignment());
- if (MDNode *Tag = LI->getMetadata(LLVMContext::MD_tbaa)) {
- TL->setMetadata(LLVMContext::MD_tbaa, Tag);
- FL->setMetadata(LLVMContext::MD_tbaa, Tag);
+
+ AAMDNodes Tags;
+ LI->getAAMetadata(Tags);
+ if (Tags) {
+ TL->setAAMetadata(Tags);
+ FL->setAAMetadata(Tags);
 }
 Value *V = IRB.CreateSelect(SI.getCondition(), TL, FL,
@@ -1332,7 +1609,8 @@ static Value *getNaturalGEPRecursively(IRBuilderTy &IRB, const DataLayout &DL,
 SmallVectorImpl<Value *> &Indices,
 Twine NamePrefix) {
 if (Offset == 0)
- return getNaturalGEPWithType(IRB, DL, Ptr, Ty, TargetTy, Indices, NamePrefix);
+ return getNaturalGEPWithType(IRB, DL, Ptr, Ty, TargetTy, Indices,
+ NamePrefix);
 // We can't recurse through pointer types.
 if (Ty->isPointerTy())
@@ -1440,8 +1718,7 @@ static Value *getNaturalGEPWithOffset(IRBuilderTy &IRB, const DataLayout &DL,
 /// a single GEP as possible, thus making each GEP more independent of the
 /// surrounding code.
 static Value *getAdjustedPtr(IRBuilderTy &IRB, const DataLayout &DL, Value *Ptr,
- APInt Offset, Type *PointerTy,
- Twine NamePrefix) {
+ APInt Offset, Type *PointerTy, Twine NamePrefix) {
 // Even though we don't look through PHI nodes, we could be called on an
 // instruction in an unreachable block, which may be on a cycle.
 SmallPtrSet<Value *, 4> Visited;
@@ -1450,8 +1727,9 @@ static Value *getAdjustedPtr(IRBuilderTy &IRB, const DataLayout &DL, Value *Ptr,
 // We may end up computing an offset pointer that has the wrong type. If we
 // never are able to compute one directly that has the correct type, we'll
- // fall back to it, so keep it around here.
+ // fall back to it, so keep it and the base it was computed from around here.
 Value *OffsetPtr = nullptr;
+ Value *OffsetBasePtr;
 // Remember any i8 pointer we come across to re-use if we need to do a raw
 // byte offset.
@@ -1468,7 +1746,7 @@ static Value *getAdjustedPtr(IRBuilderTy &IRB, const DataLayout &DL, Value *Ptr,
 break;
 Offset += GEPOffset;
 Ptr = GEP->getPointerOperand();
- if (!Visited.insert(Ptr))
+ if (!Visited.insert(Ptr).second)
 break;
 }
@@ -1476,16 +1754,19 @@ static Value *getAdjustedPtr(IRBuilderTy &IRB, const DataLayout &DL, Value *Ptr,
 Indices.clear();
 if (Value *P = getNaturalGEPWithOffset(IRB, DL, Ptr, Offset, TargetTy,
 Indices, NamePrefix)) {
- if (P->getType() == PointerTy) {
- // Zap any offset pointer that we ended up computing in previous rounds.
- if (OffsetPtr && OffsetPtr->use_empty())
- if (Instruction *I = dyn_cast<Instruction>(OffsetPtr))
- I->eraseFromParent();
+ // If we have a new natural pointer at the offset, clear out any old
+ // offset pointer we computed. Unless it is the base pointer or
+ // a non-instruction, we built a GEP we don't need. Zap it.
+ if (OffsetPtr && OffsetPtr != OffsetBasePtr)
+ if (Instruction *I = dyn_cast<Instruction>(OffsetPtr)) {
+ assert(I->use_empty() && "Built a GEP with uses somehow!");
+ I->eraseFromParent();
+ }
+ OffsetPtr = P;
+ OffsetBasePtr = Ptr;
+ // If we also found a pointer of the right type, we're done.
+ if (P->getType() == PointerTy)
 return P;
- }
- if (!OffsetPtr) {
- OffsetPtr = P;
- }
 }
 // Stash this pointer if we've found an i8*.
@@ -1505,7 +1786,7 @@ static Value *getAdjustedPtr(IRBuilderTy &IRB, const DataLayout &DL, Value *Ptr,
 break;
 }
 assert(Ptr->getType()->isPointerTy() && "Unexpected operand type!");
- } while (Visited.insert(Ptr));
+ } while (Visited.insert(Ptr).second);
 if (!OffsetPtr) {
 if (!Int8Ptr) {
@@ -1515,9 +1796,10 @@ static Value *getAdjustedPtr(IRBuilderTy &IRB, const DataLayout &DL, Value *Ptr,
 Int8PtrOffset = Offset;
 }
- OffsetPtr = Int8PtrOffset == 0 ? Int8Ptr :
- IRB.CreateInBoundsGEP(Int8Ptr, IRB.getInt(Int8PtrOffset),
- NamePrefix + "sroa_raw_idx");
+ OffsetPtr = Int8PtrOffset == 0
+ ? Int8Ptr
+ : IRB.CreateInBoundsGEP(Int8Ptr, IRB.getInt(Int8PtrOffset),
+ NamePrefix + "sroa_raw_idx");
 }
 Ptr = OffsetPtr;
@@ -1528,6 +1810,27 @@ static Value *getAdjustedPtr(IRBuilderTy &IRB, const DataLayout &DL, Value *Ptr,
 return Ptr;
}
+/// \brief Compute the adjusted alignment for a load or store from an offset.
+static unsigned getAdjustedAlignment(Instruction *I, uint64_t Offset,
+ const DataLayout &DL) {
+ unsigned Alignment;
+ Type *Ty;
+ if (auto *LI = dyn_cast<LoadInst>(I)) {
+ Alignment = LI->getAlignment();
+ Ty = LI->getType();
+ } else if (auto *SI = dyn_cast<StoreInst>(I)) {
+ Alignment = SI->getAlignment();
+ Ty = SI->getValueOperand()->getType();
+ } else {
+ llvm_unreachable("Only loads and stores are allowed!");
+ }
+
+ if (!Alignment)
+ Alignment = DL.getABITypeAlignment(Ty);
+
+ return MinAlign(Alignment, Offset);
+}
+
 /// \brief Test whether we can convert a value from the old to the new type.
 ///
 /// This predicate should be used to guard calls to convertValue in order to
@@ -1621,39 +1924,43 @@ static Value *convertValue(const DataLayout &DL, IRBuilderTy &IRB, Value *V,
 ///
 /// This function is called to test each entry in a partitioning which is slated
 /// for a single slice.
-static bool isVectorPromotionViableForSlice(
- const DataLayout &DL, AllocaSlices &S, uint64_t SliceBeginOffset,
- uint64_t SliceEndOffset, VectorType *Ty, uint64_t ElementSize,
- AllocaSlices::const_iterator I) {
+static bool isVectorPromotionViableForSlice(AllocaSlices::Partition &P,
+ const Slice &S, VectorType *Ty,
+ uint64_t ElementSize,
+ const DataLayout &DL) {
 // First validate the slice offsets.
 uint64_t BeginOffset =
- std::max(I->beginOffset(), SliceBeginOffset) - SliceBeginOffset;
+ std::max(S.beginOffset(), P.beginOffset()) - P.beginOffset();
 uint64_t BeginIndex = BeginOffset / ElementSize;
 if (BeginIndex * ElementSize != BeginOffset ||
 BeginIndex >= Ty->getNumElements())
 return false;
 uint64_t EndOffset =
- std::min(I->endOffset(), SliceEndOffset) - SliceBeginOffset;
+ std::min(S.endOffset(), P.endOffset()) - P.beginOffset();
 uint64_t EndIndex = EndOffset / ElementSize;
 if (EndIndex * ElementSize != EndOffset || EndIndex > Ty->getNumElements())
 return false;
 assert(EndIndex > BeginIndex && "Empty vector!");
 uint64_t NumElements = EndIndex - BeginIndex;
- Type *SliceTy =
- (NumElements == 1) ? Ty->getElementType()
- : VectorType::get(Ty->getElementType(), NumElements);
+ Type *SliceTy = (NumElements == 1)
+ ?
Ty->getElementType() + : VectorType::get(Ty->getElementType(), NumElements); Type *SplitIntTy = Type::getIntNTy(Ty->getContext(), NumElements * ElementSize * 8); - Use *U = I->getUse(); + Use *U = S.getUse(); if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(U->getUser())) { if (MI->isVolatile()) return false; - if (!I->isSplittable()) + if (!S.isSplittable()) return false; // Skip any unsplittable intrinsics. + } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(U->getUser())) { + if (II->getIntrinsicID() != Intrinsic::lifetime_start && + II->getIntrinsicID() != Intrinsic::lifetime_end) + return false; } else if (U->get()->getType()->getPointerElementType()->isStructTy()) { // Disable vector promotion when there are loads or stores of an FCA. return false; @@ -1661,8 +1968,7 @@ static bool isVectorPromotionViableForSlice( if (LI->isVolatile()) return false; Type *LTy = LI->getType(); - if (SliceBeginOffset > I->beginOffset() || - SliceEndOffset < I->endOffset()) { + if (P.beginOffset() > S.beginOffset() || P.endOffset() < S.endOffset()) { assert(LTy->isIntegerTy()); LTy = SplitIntTy; } @@ -1672,8 +1978,7 @@ static bool isVectorPromotionViableForSlice( if (SI->isVolatile()) return false; Type *STy = SI->getValueOperand()->getType(); - if (SliceBeginOffset > I->beginOffset() || - SliceEndOffset < I->endOffset()) { + if (P.beginOffset() > S.beginOffset() || P.endOffset() < S.endOffset()) { assert(STy->isIntegerTy()); STy = SplitIntTy; } @@ -1695,65 +2000,137 @@ static bool isVectorPromotionViableForSlice( /// SSA value. We only can ensure this for a limited set of operations, and we /// don't want to do the rewrites unless we are confident that the result will /// be promotable, so we have an early test here. -static bool -isVectorPromotionViable(const DataLayout &DL, Type *AllocaTy, AllocaSlices &S, - uint64_t SliceBeginOffset, uint64_t SliceEndOffset, - AllocaSlices::const_iterator I, - AllocaSlices::const_iterator E, - ArrayRef<AllocaSlices::iterator> SplitUses) { - VectorType *Ty = dyn_cast<VectorType>(AllocaTy); - if (!Ty) - return false; +static VectorType *isVectorPromotionViable(AllocaSlices::Partition &P, + const DataLayout &DL) { + // Collect the candidate types for vector-based promotion. Also track whether + // we have different element types. + SmallVector<VectorType *, 4> CandidateTys; + Type *CommonEltTy = nullptr; + bool HaveCommonEltTy = true; + auto CheckCandidateType = [&](Type *Ty) { + if (auto *VTy = dyn_cast<VectorType>(Ty)) { + CandidateTys.push_back(VTy); + if (!CommonEltTy) + CommonEltTy = VTy->getElementType(); + else if (CommonEltTy != VTy->getElementType()) + HaveCommonEltTy = false; + } + }; + // Consider any loads or stores that are the exact size of the slice. + for (const Slice &S : P) + if (S.beginOffset() == P.beginOffset() && + S.endOffset() == P.endOffset()) { + if (auto *LI = dyn_cast<LoadInst>(S.getUse()->getUser())) + CheckCandidateType(LI->getType()); + else if (auto *SI = dyn_cast<StoreInst>(S.getUse()->getUser())) + CheckCandidateType(SI->getValueOperand()->getType()); + } - uint64_t ElementSize = DL.getTypeSizeInBits(Ty->getScalarType()); + // If we didn't find a vector type, nothing to do here. + if (CandidateTys.empty()) + return nullptr; - // While the definition of LLVM vectors is bitpacked, we don't support sizes - // that aren't byte sized. 
- if (ElementSize % 8) - return false; - assert((DL.getTypeSizeInBits(Ty) % 8) == 0 && - "vector size not a multiple of element size?"); - ElementSize /= 8; + // Remove non-integer vector types if we had multiple common element types. + // FIXME: It'd be nice to replace them with integer vector types, but we can't + // do that until all the backends are known to produce good code for all + // integer vector types. + if (!HaveCommonEltTy) { + CandidateTys.erase(std::remove_if(CandidateTys.begin(), CandidateTys.end(), + [](VectorType *VTy) { + return !VTy->getElementType()->isIntegerTy(); + }), + CandidateTys.end()); + + // If there were no integer vector types, give up. + if (CandidateTys.empty()) + return nullptr; - for (; I != E; ++I) - if (!isVectorPromotionViableForSlice(DL, S, SliceBeginOffset, - SliceEndOffset, Ty, ElementSize, I)) - return false; + // Rank the remaining candidate vector types. This is easy because we know + // they're all integer vectors. We sort by ascending number of elements. + auto RankVectorTypes = [&DL](VectorType *RHSTy, VectorType *LHSTy) { + assert(DL.getTypeSizeInBits(RHSTy) == DL.getTypeSizeInBits(LHSTy) && + "Cannot have vector types of different sizes!"); + assert(RHSTy->getElementType()->isIntegerTy() && + "All non-integer types eliminated!"); + assert(LHSTy->getElementType()->isIntegerTy() && + "All non-integer types eliminated!"); + return RHSTy->getNumElements() < LHSTy->getNumElements(); + }; + std::sort(CandidateTys.begin(), CandidateTys.end(), RankVectorTypes); + CandidateTys.erase( + std::unique(CandidateTys.begin(), CandidateTys.end(), RankVectorTypes), + CandidateTys.end()); + } else { +// The only way to have the same element type in every vector type is to +// have the same vector type. Check that and remove all but one. +#ifndef NDEBUG + for (VectorType *VTy : CandidateTys) { + assert(VTy->getElementType() == CommonEltTy && + "Unaccounted for element type!"); + assert(VTy == CandidateTys[0] && + "Different vector types with the same element type!"); + } +#endif + CandidateTys.resize(1); + } - for (ArrayRef<AllocaSlices::iterator>::const_iterator SUI = SplitUses.begin(), - SUE = SplitUses.end(); - SUI != SUE; ++SUI) - if (!isVectorPromotionViableForSlice(DL, S, SliceBeginOffset, - SliceEndOffset, Ty, ElementSize, *SUI)) + // Try each vector type, and return the one which works. + auto CheckVectorTypeForPromotion = [&](VectorType *VTy) { + uint64_t ElementSize = DL.getTypeSizeInBits(VTy->getElementType()); + + // While the definition of LLVM vectors is bitpacked, we don't support sizes + // that aren't byte sized. + if (ElementSize % 8) return false; + assert((DL.getTypeSizeInBits(VTy) % 8) == 0 && + "vector size not a multiple of element size?"); + ElementSize /= 8; - return true; + for (const Slice &S : P) + if (!isVectorPromotionViableForSlice(P, S, VTy, ElementSize, DL)) + return false; + + for (const Slice *S : P.splitSliceTails()) + if (!isVectorPromotionViableForSlice(P, *S, VTy, ElementSize, DL)) + return false; + + return true; + }; + for (VectorType *VTy : CandidateTys) + if (CheckVectorTypeForPromotion(VTy)) + return VTy; + + return nullptr; } /// \brief Test whether a slice of an alloca is valid for integer widening. /// /// This implements the necessary checking for the \c isIntegerWideningViable /// test below on a single slice of the alloca. 
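/// (Illustrative: an i64 load covering all eight bytes of an { i32, i32 }
/// alloca counts as a whole-alloca operation and can enable widening.)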
-static bool isIntegerWideningViableForSlice(const DataLayout &DL,
- Type *AllocaTy,
+static bool isIntegerWideningViableForSlice(const Slice &S,
 uint64_t AllocBeginOffset,
- uint64_t Size, AllocaSlices &S,
- AllocaSlices::const_iterator I,
+ Type *AllocaTy,
+ const DataLayout &DL,
 bool &WholeAllocaOp) {
- uint64_t RelBegin = I->beginOffset() - AllocBeginOffset;
- uint64_t RelEnd = I->endOffset() - AllocBeginOffset;
+ uint64_t Size = DL.getTypeStoreSize(AllocaTy);
+
+ uint64_t RelBegin = S.beginOffset() - AllocBeginOffset;
+ uint64_t RelEnd = S.endOffset() - AllocBeginOffset;
 // We can't reasonably handle cases where the load or store extends past
 // the end of the alloca's type and into its padding.
 if (RelEnd > Size)
 return false;
 Use *U = S.getUse();
 if (LoadInst *LI = dyn_cast<LoadInst>(U->getUser())) {
 if (LI->isVolatile())
 return false;
- if (RelBegin == 0 && RelEnd == Size)
+ // Note that we don't count vector loads or stores as whole-alloca
+ // operations which enable integer widening because we would prefer to use
+ // vector widening instead.
+ if (!isa<VectorType>(LI->getType()) && RelBegin == 0 && RelEnd == Size)
 WholeAllocaOp = true;
 if (IntegerType *ITy = dyn_cast<IntegerType>(LI->getType())) {
 if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy))
@@ -1768,7 +2145,10 @@ static bool isIntegerWideningViableForSlice(const DataLayout &DL,
 Type *ValueTy = SI->getValueOperand()->getType();
 if (SI->isVolatile())
 return false;
- if (RelBegin == 0 && RelEnd == Size)
+ // Note that we don't count vector loads or stores as whole-alloca
+ // operations which enable integer widening because we would prefer to use
+ // vector widening instead.
+ if (!isa<VectorType>(ValueTy) && RelBegin == 0 && RelEnd == Size)
 WholeAllocaOp = true;
 if (IntegerType *ITy = dyn_cast<IntegerType>(ValueTy)) {
 if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy))
@@ -1782,7 +2162,7 @@ static bool isIntegerWideningViableForSlice(const DataLayout &DL,
 } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(U->getUser())) {
 if (MI->isVolatile() || !isa<Constant>(MI->getLength()))
 return false;
- if (!I->isSplittable())
+ if (!S.isSplittable())
 return false; // Skip any unsplittable intrinsics.
 } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(U->getUser())) {
 if (II->getIntrinsicID() != Intrinsic::lifetime_start &&
@@ -1801,12 +2181,8 @@ static bool isIntegerWideningViableForSlice(const DataLayout &DL,
 /// This is a quick test to check whether we can rewrite the integer loads and
 /// stores to a particular alloca into wider loads and stores and be able to
 /// promote the resulting alloca.
-static bool
-isIntegerWideningViable(const DataLayout &DL, Type *AllocaTy,
- uint64_t AllocBeginOffset, AllocaSlices &S,
- AllocaSlices::const_iterator I,
- AllocaSlices::const_iterator E,
- ArrayRef<AllocaSlices::iterator> SplitUses) {
+static bool isIntegerWideningViable(AllocaSlices::Partition &P, Type *AllocaTy,
+ const DataLayout &DL) {
 uint64_t SizeInBits = DL.getTypeSizeInBits(AllocaTy);
 // Don't create integer types larger than the maximum bitwidth.
 if (SizeInBits > IntegerType::MAX_INT_BITS)
 return false;
@@ -1824,25 +2200,24 @@ isIntegerWideningViable(const DataLayout &DL, Type *AllocaTy,
 !canConvertValue(DL, IntTy, AllocaTy))
 return false;
- uint64_t Size = DL.getTypeStoreSize(AllocaTy);
-
 // While examining uses, we ensure that the alloca has a covering load or
 // store.
We don't want to widen the integer operations only to fail to // promote due to some other unsplittable entry (which we may make splittable // later). However, if there are only splittable uses, go ahead and assume // that we cover the alloca. - bool WholeAllocaOp = (I != E) ? false : DL.isLegalInteger(SizeInBits); - - for (; I != E; ++I) - if (!isIntegerWideningViableForSlice(DL, AllocaTy, AllocBeginOffset, Size, - S, I, WholeAllocaOp)) + // FIXME: We shouldn't consider split slices that happen to start in the + // partition here... + bool WholeAllocaOp = + P.begin() != P.end() ? false : DL.isLegalInteger(SizeInBits); + + for (const Slice &S : P) + if (!isIntegerWideningViableForSlice(S, P.beginOffset(), AllocaTy, DL, + WholeAllocaOp)) return false; - for (ArrayRef<AllocaSlices::iterator>::const_iterator SUI = SplitUses.begin(), - SUE = SplitUses.end(); - SUI != SUE; ++SUI) - if (!isIntegerWideningViableForSlice(DL, AllocaTy, AllocBeginOffset, Size, - S, *SUI, WholeAllocaOp)) + for (const Slice *S : P.splitSliceTails()) + if (!isIntegerWideningViableForSlice(*S, P.beginOffset(), AllocaTy, DL, + WholeAllocaOp)) return false; return WholeAllocaOp; @@ -1855,9 +2230,9 @@ static Value *extractInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *V, IntegerType *IntTy = cast<IntegerType>(V->getType()); assert(DL.getTypeStoreSize(Ty) + Offset <= DL.getTypeStoreSize(IntTy) && "Element extends past full value"); - uint64_t ShAmt = 8*Offset; + uint64_t ShAmt = 8 * Offset; if (DL.isBigEndian()) - ShAmt = 8*(DL.getTypeStoreSize(IntTy) - DL.getTypeStoreSize(Ty) - Offset); + ShAmt = 8 * (DL.getTypeStoreSize(IntTy) - DL.getTypeStoreSize(Ty) - Offset); if (ShAmt) { V = IRB.CreateLShr(V, ShAmt, Name + ".shift"); DEBUG(dbgs() << " shifted: " << *V << "\n"); @@ -1884,9 +2259,9 @@ static Value *insertInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *Old, } assert(DL.getTypeStoreSize(Ty) + Offset <= DL.getTypeStoreSize(IntTy) && "Element store outside of alloca store"); - uint64_t ShAmt = 8*Offset; + uint64_t ShAmt = 8 * Offset; if (DL.isBigEndian()) - ShAmt = 8*(DL.getTypeStoreSize(IntTy) - DL.getTypeStoreSize(Ty) - Offset); + ShAmt = 8 * (DL.getTypeStoreSize(IntTy) - DL.getTypeStoreSize(Ty) - Offset); if (ShAmt) { V = IRB.CreateShl(V, ShAmt, Name + ".shift"); DEBUG(dbgs() << " shifted: " << *V << "\n"); @@ -1902,9 +2277,8 @@ static Value *insertInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *Old, return V; } -static Value *extractVector(IRBuilderTy &IRB, Value *V, - unsigned BeginIndex, unsigned EndIndex, - const Twine &Name) { +static Value *extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex, + unsigned EndIndex, const Twine &Name) { VectorType *VecTy = cast<VectorType>(V->getType()); unsigned NumElements = EndIndex - BeginIndex; assert(NumElements <= VecTy->getNumElements() && "Too many elements!"); @@ -1919,13 +2293,12 @@ static Value *extractVector(IRBuilderTy &IRB, Value *V, return V; } - SmallVector<Constant*, 8> Mask; + SmallVector<Constant *, 8> Mask; Mask.reserve(NumElements); for (unsigned i = BeginIndex; i != EndIndex; ++i) Mask.push_back(IRB.getInt32(i)); V = IRB.CreateShuffleVector(V, UndefValue::get(V->getType()), - ConstantVector::get(Mask), - Name + ".extract"); + ConstantVector::get(Mask), Name + ".extract"); DEBUG(dbgs() << " shuffle: " << *V << "\n"); return V; } @@ -1940,7 +2313,7 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V, // Single element to insert. 
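 // (e.g. storing a single lane into a <4 x i32> alloca becomes a plain
 // insertelement rather than a shuffle.)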
V = IRB.CreateInsertElement(Old, V, IRB.getInt32(BeginIndex), Name + ".insert"); - DEBUG(dbgs() << " insert: " << *V << "\n"); + DEBUG(dbgs() << " insert: " << *V << "\n"); return V; } @@ -1956,7 +2329,7 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V, // use a shuffle vector to widen it with undef elements, and then // a second shuffle vector to select between the loaded vector and the // incoming vector. - SmallVector<Constant*, 8> Mask; + SmallVector<Constant *, 8> Mask; Mask.reserve(VecTy->getNumElements()); for (unsigned i = 0; i != VecTy->getNumElements(); ++i) if (i >= BeginIndex && i < EndIndex) @@ -1964,8 +2337,7 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V, else Mask.push_back(UndefValue::get(IRB.getInt32Ty())); V = IRB.CreateShuffleVector(V, UndefValue::get(V->getType()), - ConstantVector::get(Mask), - Name + ".expand"); + ConstantVector::get(Mask), Name + ".expand"); DEBUG(dbgs() << " shuffle: " << *V << "\n"); Mask.clear(); @@ -1991,12 +2363,18 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> { typedef llvm::InstVisitor<AllocaSliceRewriter, bool> Base; const DataLayout &DL; - AllocaSlices &S; + AllocaSlices &AS; SROA &Pass; AllocaInst &OldAI, &NewAI; const uint64_t NewAllocaBeginOffset, NewAllocaEndOffset; Type *NewAllocaTy; + // This is a convenience and flag variable that will be null unless the new + // alloca's integer operations should be widened to this integer type due to + // passing isIntegerWideningViable above. If it is non-null, the desired + // integer type will be stored here for easy access during rewriting. + IntegerType *IntTy; + // If we are rewriting an alloca partition which can be written as pure // vector operations, we stash extra information here. When VecTy is // non-null, we have some strict guarantees about the rewritten alloca: @@ -2010,12 +2388,6 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> { Type *ElementTy; uint64_t ElementSize; - // This is a convenience and flag variable that will be null unless the new - // alloca's integer operations should be widened to this integer type due to - // passing isIntegerWideningViable above. If it is non-null, the desired - // integer type will be stored here for easy access during rewriting. - IntegerType *IntTy; - // The original offset of the slice currently being rewritten relative to // the original alloca. uint64_t BeginOffset, EndOffset; @@ -2038,25 +2410,25 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> { IRBuilderTy IRB; public: - AllocaSliceRewriter(const DataLayout &DL, AllocaSlices &S, SROA &Pass, + AllocaSliceRewriter(const DataLayout &DL, AllocaSlices &AS, SROA &Pass, AllocaInst &OldAI, AllocaInst &NewAI, uint64_t NewAllocaBeginOffset, - uint64_t NewAllocaEndOffset, bool IsVectorPromotable, - bool IsIntegerPromotable, + uint64_t NewAllocaEndOffset, bool IsIntegerPromotable, + VectorType *PromotableVecTy, SmallPtrSetImpl<PHINode *> &PHIUsers, SmallPtrSetImpl<SelectInst *> &SelectUsers) - : DL(DL), S(S), Pass(Pass), OldAI(OldAI), NewAI(NewAI), + : DL(DL), AS(AS), Pass(Pass), OldAI(OldAI), NewAI(NewAI), NewAllocaBeginOffset(NewAllocaBeginOffset), NewAllocaEndOffset(NewAllocaEndOffset), NewAllocaTy(NewAI.getAllocatedType()), - VecTy(IsVectorPromotable ? cast<VectorType>(NewAllocaTy) : nullptr), - ElementTy(VecTy ? VecTy->getElementType() : nullptr), - ElementSize(VecTy ? DL.getTypeSizeInBits(ElementTy) / 8 : 0), IntTy(IsIntegerPromotable ? 
Type::getIntNTy( NewAI.getContext(), DL.getTypeSizeInBits(NewAI.getAllocatedType())) : nullptr), + VecTy(PromotableVecTy), + ElementTy(VecTy ? VecTy->getElementType() : nullptr), + ElementSize(VecTy ? DL.getTypeSizeInBits(ElementTy) / 8 : 0), BeginOffset(), EndOffset(), IsSplittable(), IsSplit(), OldUse(), OldPtr(), PHIUsers(PHIUsers), SelectUsers(SelectUsers), IRB(NewAI.getContext(), ConstantFolder()) { @@ -2065,8 +2437,7 @@ public: "Only multiple-of-8 sized vector elements are viable"); ++NumVectorized; } - assert((!IsVectorPromotable && !IsIntegerPromotable) || - IsVectorPromotable != IsIntegerPromotable); + assert((!IntTy && !VecTy) || (IntTy && !VecTy) || (!IntTy && VecTy)); } bool visit(AllocaSlices::const_iterator I) { @@ -2076,6 +2447,9 @@ public: IsSplittable = I->isSplittable(); IsSplit = BeginOffset < NewAllocaBeginOffset || EndOffset > NewAllocaEndOffset; + DEBUG(dbgs() << " rewriting " << (IsSplit ? "split " : "")); + DEBUG(AS.printSlice(dbgs(), I, "")); + DEBUG(dbgs() << "\n"); // Compute the intersecting offset range. assert(BeginOffset < NewAllocaEndOffset); @@ -2146,7 +2520,8 @@ private: ); } - /// \brief Compute suitable alignment to access this slice of the *new* alloca. + /// \brief Compute suitable alignment to access this slice of the *new* + /// alloca. /// /// You can optionally pass a type to this routine and if that type's ABI /// alignment is itself suitable, this will return zero. @@ -2154,7 +2529,8 @@ private: unsigned NewAIAlign = NewAI.getAlignment(); if (!NewAIAlign) NewAIAlign = DL.getABITypeAlignment(NewAI.getAllocatedType()); - unsigned Align = MinAlign(NewAIAlign, NewBeginOffset - NewAllocaBeginOffset); + unsigned Align = + MinAlign(NewAIAlign, NewBeginOffset - NewAllocaBeginOffset); return (Ty && Align == DL.getABITypeAlignment(Ty)) ? 0 : Align; } @@ -2178,16 +2554,14 @@ private: unsigned EndIndex = getIndex(NewEndOffset); assert(EndIndex > BeginIndex && "Empty vector!"); - Value *V = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), - "load"); + Value *V = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "load"); return extractVector(IRB, V, BeginIndex, EndIndex, "vec"); } Value *rewriteIntegerLoad(LoadInst &LI) { assert(IntTy && "We cannot insert an integer to the alloca"); assert(!LI.isVolatile()); - Value *V = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), - "load"); + Value *V = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "load"); V = convertValue(DL, IRB, V, IntTy); assert(NewBeginOffset >= NewAllocaBeginOffset && "Out of bounds offset"); uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset; @@ -2212,8 +2586,8 @@ private: V = rewriteIntegerLoad(LI); } else if (NewBeginOffset == NewAllocaBeginOffset && canConvertValue(DL, NewAllocaTy, LI.getType())) { - V = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), - LI.isVolatile(), LI.getName()); + V = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), LI.isVolatile(), + LI.getName()); } else { Type *LTy = TargetTy->getPointerTo(); V = IRB.CreateAlignedLoad(getNewAllocaSlicePtr(IRB, LTy), @@ -2230,7 +2604,7 @@ private: assert(SliceSize < DL.getTypeStoreSize(LI.getType()) && "Split load isn't smaller than original load"); assert(LI.getType()->getIntegerBitWidth() == - DL.getTypeStoreSizeInBits(LI.getType()) && + DL.getTypeStoreSizeInBits(LI.getType()) && "Non-byte-multiple bit width"); // Move the insertion point just past the load so that we can refer to it. IRB.SetInsertPoint(std::next(BasicBlock::iterator(&LI))); @@ -2238,9 +2612,9 @@ private: // basis for the new value. 
This allows us to replace the uses of LI with // the computed value, and then replace the placeholder with LI, leaving // LI only used for this computation. - Value *Placeholder - = new LoadInst(UndefValue::get(LI.getType()->getPointerTo())); - V = insertInteger(DL, IRB, Placeholder, V, NewBeginOffset, + Value *Placeholder = + new LoadInst(UndefValue::get(LI.getType()->getPointerTo())); + V = insertInteger(DL, IRB, Placeholder, V, NewBeginOffset - BeginOffset, "insert"); LI.replaceAllUsesWith(V); Placeholder->replaceAllUsesWith(&LI); @@ -2262,15 +2636,14 @@ private: assert(EndIndex > BeginIndex && "Empty vector!"); unsigned NumElements = EndIndex - BeginIndex; assert(NumElements <= VecTy->getNumElements() && "Too many elements!"); - Type *SliceTy = - (NumElements == 1) ? ElementTy - : VectorType::get(ElementTy, NumElements); + Type *SliceTy = (NumElements == 1) + ? ElementTy + : VectorType::get(ElementTy, NumElements); if (V->getType() != SliceTy) V = convertValue(DL, IRB, V, SliceTy); // Mix in the existing elements. - Value *Old = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), - "load"); + Value *Old = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "load"); V = insertVector(IRB, Old, V, BeginIndex, "vec"); } StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment()); @@ -2285,13 +2658,12 @@ private: assert(IntTy && "We cannot extract an integer from the alloca"); assert(!SI.isVolatile()); if (DL.getTypeSizeInBits(V->getType()) != IntTy->getBitWidth()) { - Value *Old = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), - "oldload"); + Value *Old = + IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "oldload"); Old = convertValue(DL, IRB, Old, IntTy); assert(BeginOffset >= NewAllocaBeginOffset && "Out of bounds offset"); uint64_t Offset = BeginOffset - NewAllocaBeginOffset; - V = insertInteger(DL, IRB, Old, SI.getValueOperand(), Offset, - "insert"); + V = insertInteger(DL, IRB, Old, SI.getValueOperand(), Offset, "insert"); } V = convertValue(DL, IRB, V, NewAllocaTy); StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment()); @@ -2319,10 +2691,10 @@ private: assert(V->getType()->isIntegerTy() && "Only integer type loads and stores are split"); assert(V->getType()->getIntegerBitWidth() == - DL.getTypeStoreSizeInBits(V->getType()) && + DL.getTypeStoreSizeInBits(V->getType()) && "Non-byte-multiple bit width"); IntegerType *NarrowTy = Type::getIntNTy(SI.getContext(), SliceSize * 8); - V = extractInteger(DL, IRB, V, NarrowTy, NewBeginOffset, + V = extractInteger(DL, IRB, V, NarrowTy, NewBeginOffset - BeginOffset, "extract"); } @@ -2367,14 +2739,14 @@ private: if (Size == 1) return V; - Type *SplatIntTy = Type::getIntNTy(VTy->getContext(), Size*8); - V = IRB.CreateMul(IRB.CreateZExt(V, SplatIntTy, "zext"), - ConstantExpr::getUDiv( - Constant::getAllOnesValue(SplatIntTy), - ConstantExpr::getZExt( - Constant::getAllOnesValue(V->getType()), - SplatIntTy)), - "isplat"); + Type *SplatIntTy = Type::getIntNTy(VTy->getContext(), Size * 8); + V = IRB.CreateMul( + IRB.CreateZExt(V, SplatIntTy, "zext"), + ConstantExpr::getUDiv( + Constant::getAllOnesValue(SplatIntTy), + ConstantExpr::getZExt(Constant::getAllOnesValue(V->getType()), + SplatIntTy)), + "isplat"); return V; } @@ -2411,11 +2783,11 @@ private: // If this doesn't map cleanly onto the alloca type, and that type isn't // a single value type, just emit a memset. 
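The "isplat" construction above relies on a classic arithmetic identity: multiplying an M-bit value by (2^N - 1) / (2^M - 1) replicates it across an N-bit integer, and that quotient is exactly what the ConstantExpr::getUDiv of the two all-ones values computes. A minimal sketch (plain C++, hypothetical function name) for the byte case:

    #include <cassert>
    #include <cstdint>

    // Replicate one byte across a 32-bit integer. 0xFFFFFFFF / 0xFF is
    // 0x01010101, the same constant the rewriter forms by dividing two
    // all-ones values.
    uint32_t splatByte(uint8_t V) {
      return uint32_t(V) * (0xFFFFFFFFu / 0xFFu);
    }

    int main() {
      assert(splatByte(0xAB) == 0xABABABABu);
      return 0;
    }
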
if (!VecTy && !IntTy && - (BeginOffset > NewAllocaBeginOffset || - EndOffset < NewAllocaEndOffset || + (BeginOffset > NewAllocaBeginOffset || EndOffset < NewAllocaEndOffset || + SliceSize != DL.getTypeStoreSize(AllocaTy) || !AllocaTy->isSingleValueType() || !DL.isLegalInteger(DL.getTypeSizeInBits(ScalarTy)) || - DL.getTypeSizeInBits(ScalarTy)%8 != 0)) { + DL.getTypeSizeInBits(ScalarTy) % 8 != 0)) { Type *SizeTy = II.getLength()->getType(); Constant *Size = ConstantInt::get(SizeTy, NewEndOffset - NewBeginOffset); CallInst *New = IRB.CreateMemSet( @@ -2449,8 +2821,8 @@ private: if (NumElements > 1) Splat = getVectorSplat(Splat, NumElements); - Value *Old = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), - "oldload"); + Value *Old = + IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "oldload"); V = insertVector(IRB, Old, Splat, BeginIndex, "vec"); } else if (IntTy) { // If this is a memset on an alloca where we can widen stores, insert the @@ -2462,8 +2834,8 @@ private: if (IntTy && (BeginOffset != NewAllocaBeginOffset || EndOffset != NewAllocaBeginOffset)) { - Value *Old = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), - "oldload"); + Value *Old = + IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "oldload"); Old = convertValue(DL, IRB, Old, IntTy); uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset; V = insertInteger(DL, IRB, Old, V, Offset, "insert"); @@ -2535,10 +2907,11 @@ private: // If this doesn't map cleanly onto the alloca type, and that type isn't // a single value type, just emit a memcpy. - bool EmitMemCpy - = !VecTy && !IntTy && (BeginOffset > NewAllocaBeginOffset || - EndOffset < NewAllocaEndOffset || - !NewAI.getAllocatedType()->isSingleValueType()); + bool EmitMemCpy = + !VecTy && !IntTy && + (BeginOffset > NewAllocaBeginOffset || EndOffset < NewAllocaEndOffset || + SliceSize != DL.getTypeStoreSize(NewAI.getAllocatedType()) || + !NewAI.getAllocatedType()->isSingleValueType()); // If we're just going to emit a memcpy, the alloca hasn't changed, and the // size hasn't been shrunk based on analysis of the viable range, this is @@ -2559,8 +2932,8 @@ private: // Strip all inbounds GEPs and pointer casts to try to dig out any root // alloca that should be re-examined after rewriting this instruction. Value *OtherPtr = IsDest ? II.getRawSource() : II.getRawDest(); - if (AllocaInst *AI - = dyn_cast<AllocaInst>(OtherPtr->stripInBoundsOffsets())) { + if (AllocaInst *AI = + dyn_cast<AllocaInst>(OtherPtr->stripInBoundsOffsets())) { assert(AI != &OldAI && AI != &NewAI && "Splittable transfers cannot reach the same alloca on both ends."); Pass.Worklist.insert(AI); @@ -2599,8 +2972,8 @@ private: unsigned BeginIndex = VecTy ? getIndex(NewBeginOffset) : 0; unsigned EndIndex = VecTy ? getIndex(NewEndOffset) : 0; unsigned NumElements = EndIndex - BeginIndex; - IntegerType *SubIntTy - = IntTy ? Type::getIntNTy(IntTy->getContext(), Size*8) : nullptr; + IntegerType *SubIntTy = + IntTy ? Type::getIntNTy(IntTy->getContext(), Size * 8) : nullptr; // Reset the other pointer type to match the register type we're going to // use, but using the address space of the original other pointer. 
@@ -2629,27 +3002,25 @@ private: Value *Src; if (VecTy && !IsWholeAlloca && !IsDest) { - Src = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), - "load"); + Src = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "load"); Src = extractVector(IRB, Src, BeginIndex, EndIndex, "vec"); } else if (IntTy && !IsWholeAlloca && !IsDest) { - Src = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), - "load"); + Src = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "load"); Src = convertValue(DL, IRB, Src, IntTy); uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset; Src = extractInteger(DL, IRB, Src, SubIntTy, Offset, "extract"); } else { - Src = IRB.CreateAlignedLoad(SrcPtr, SrcAlign, II.isVolatile(), - "copyload"); + Src = + IRB.CreateAlignedLoad(SrcPtr, SrcAlign, II.isVolatile(), "copyload"); } if (VecTy && !IsWholeAlloca && IsDest) { - Value *Old = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), - "oldload"); + Value *Old = + IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "oldload"); Src = insertVector(IRB, Old, Src, BeginIndex, "vec"); } else if (IntTy && !IsWholeAlloca && IsDest) { - Value *Old = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), - "oldload"); + Value *Old = + IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "oldload"); Old = convertValue(DL, IRB, Old, IntTy); uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset; Src = insertInteger(DL, IRB, Old, Src, Offset, "insert"); @@ -2672,8 +3043,8 @@ private: // Record this instruction for deletion. Pass.DeadInsts.insert(&II); - ConstantInt *Size - = ConstantInt::get(cast<IntegerType>(II.getArgOperand(0)->getType()), + ConstantInt *Size = + ConstantInt::get(cast<IntegerType>(II.getArgOperand(0)->getType()), NewEndOffset - NewBeginOffset); Value *Ptr = getNewAllocaSlicePtr(IRB, OldPtr->getType()); Value *New; @@ -2740,7 +3111,6 @@ private: SelectUsers.insert(&SI); return true; } - }; } @@ -2787,7 +3157,7 @@ private: /// This uses a set to de-duplicate users. void enqueueUsers(Instruction &I) { for (Use &U : I.uses()) - if (Visited.insert(U.getUser())) + if (Visited.insert(U.getUser()).second) Queue.push_back(&U); } @@ -2795,8 +3165,7 @@ private: bool visitInstruction(Instruction &I) { return false; } /// \brief Generic recursive split emission class. - template <typename Derived> - class OpSplitter { + template <typename Derived> class OpSplitter { protected: /// The builder used to form new instructions. IRBuilderTy IRB; @@ -2813,7 +3182,7 @@ private: /// Initialize the splitter with an insertion point, Ptr and start with a /// single zero GEP index. OpSplitter(Instruction *InsertionPoint, Value *Ptr) - : IRB(InsertionPoint), GEPIndices(1, IRB.getInt32(0)), Ptr(Ptr) {} + : IRB(InsertionPoint), GEPIndices(1, IRB.getInt32(0)), Ptr(Ptr) {} public: /// \brief Generic recursive split emission routine. @@ -2869,7 +3238,7 @@ private: struct LoadOpSplitter : public OpSplitter<LoadOpSplitter> { LoadOpSplitter(Instruction *InsertionPoint, Value *Ptr) - : OpSplitter<LoadOpSplitter>(InsertionPoint, Ptr) {} + : OpSplitter<LoadOpSplitter>(InsertionPoint, Ptr) {} /// Emit a leaf load of a single value. This is called at the leaves of the /// recursive emission to actually load values. @@ -2900,7 +3269,7 @@ private: struct StoreOpSplitter : public OpSplitter<StoreOpSplitter> { StoreOpSplitter(Instruction *InsertionPoint, Value *Ptr) - : OpSplitter<StoreOpSplitter>(InsertionPoint, Ptr) {} + : OpSplitter<StoreOpSplitter>(InsertionPoint, Ptr) {} /// Emit a leaf store of a single value. 
This is called at the leaves of the /// recursive emission to actually produce stores. @@ -2908,8 +3277,8 @@ private: assert(Ty->isSingleValueType()); // Extract the single value and store it using the indices. Value *Store = IRB.CreateStore( - IRB.CreateExtractValue(Agg, Indices, Name + ".extract"), - IRB.CreateInBoundsGEP(Ptr, GEPIndices, Name + ".gep")); + IRB.CreateExtractValue(Agg, Indices, Name + ".extract"), + IRB.CreateInBoundsGEP(Ptr, GEPIndices, Name + ".gep")); (void)Store; DEBUG(dbgs() << " to: " << *Store << "\n"); } @@ -2995,8 +3364,8 @@ static Type *stripAggregateTypeWrapping(const DataLayout &DL, Type *Ty) { /// when the size or offset cause either end of type-based partition to be off. /// Also, this is a best-effort routine. It is reasonable to give up and not /// return a type if necessary. -static Type *getTypePartition(const DataLayout &DL, Type *Ty, - uint64_t Offset, uint64_t Size) { +static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset, + uint64_t Size) { if (Offset == 0 && DL.getTypeAllocSize(Ty) == Size) return stripAggregateTypeWrapping(DL, Ty); if (Offset > DL.getTypeAllocSize(Ty) || @@ -3088,8 +3457,8 @@ static Type *getTypePartition(const DataLayout &DL, Type *Ty, } // Try to build up a sub-structure. - StructType *SubTy = StructType::get(STy->getContext(), makeArrayRef(EI, EE), - STy->isPacked()); + StructType *SubTy = + StructType::get(STy->getContext(), makeArrayRef(EI, EE), STy->isPacked()); const StructLayout *SubSL = DL.getStructLayout(SubTy); if (Size != SubSL->getSizeInBytes()) return nullptr; // The sub-struct doesn't have quite the size needed. @@ -3097,6 +3466,494 @@ static Type *getTypePartition(const DataLayout &DL, Type *Ty, return SubTy; } +/// \brief Pre-split loads and stores to simplify rewriting. +/// +/// We want to break up the splittable load+store pairs as much as +/// possible. This is important to do as a preprocessing step, as once we +/// start rewriting the accesses to partitions of the alloca we lose the +/// necessary information to correctly split apart paired loads and stores +/// which both point into this alloca. The case to consider is something like +/// the following: +/// +/// %a = alloca [12 x i8] +/// %gep1 = getelementptr [12 x i8]* %a, i32 0, i32 0 +/// %gep2 = getelementptr [12 x i8]* %a, i32 0, i32 4 +/// %gep3 = getelementptr [12 x i8]* %a, i32 0, i32 8 +/// %iptr1 = bitcast i8* %gep1 to i64* +/// %iptr2 = bitcast i8* %gep2 to i64* +/// %fptr1 = bitcast i8* %gep1 to float* +/// %fptr2 = bitcast i8* %gep2 to float* +/// %fptr3 = bitcast i8* %gep3 to float* +/// store float 0.0, float* %fptr1 +/// store float 1.0, float* %fptr2 +/// %v = load i64* %iptr1 +/// store i64 %v, i64* %iptr2 +/// %f1 = load float* %fptr2 +/// %f2 = load float* %fptr3 +/// +/// Here we want to form 3 partitions of the alloca, each 4 bytes large, and +/// promote everything so we recover the 2 SSA values that should have been +/// there all along. +/// +/// \returns true if any changes are made. +bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { + DEBUG(dbgs() << "Pre-splitting loads and stores\n"); + + // Track the loads and stores which are candidates for pre-splitting here, in + // the order they first appear during the partition scan. These give stable + // iteration order and a basis for tracking which loads and stores we + // actually split. 
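To make the example above concrete: the three partitions are [0,4), [4,8), and [8,12), so the i64 load covering [0,8) and the i64 store covering [4,12) each record one split at relative offset 4. A toy model (plain C++, hypothetical names) of how such a relative split list expands into part ranges, mirroring the part-enumeration loop used later in this function:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      uint64_t AccessSize = 8;            // the i64 load/store in the example
      std::vector<uint64_t> Splits = {4}; // partition boundary at +4 bytes
      uint64_t PartOffset = 0, PartSize = Splits.front();
      for (size_t Idx = 0;;) {
        std::printf("part [%llu, %llu)\n", (unsigned long long)PartOffset,
                    (unsigned long long)(PartOffset + PartSize));
        if (Idx >= Splits.size())
          break;
        PartOffset = Splits[Idx];
        ++Idx;
        PartSize = (Idx < Splits.size() ? Splits[Idx] : AccessSize) - PartOffset;
      }
      return 0; // prints "part [0, 4)" and "part [4, 8)"
    }
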
+ SmallVector<LoadInst *, 4> Loads; + SmallVector<StoreInst *, 4> Stores; + + // We need to accumulate the splits required of each load or store where we + // can find them via a direct lookup. This is important to cross-check loads + // and stores against each other. We also track the slice so that we can kill + // all the slices that end up split. + struct SplitOffsets { + Slice *S; + std::vector<uint64_t> Splits; + }; + SmallDenseMap<Instruction *, SplitOffsets, 8> SplitOffsetsMap; + + // Track loads out of this alloca which cannot, for any reason, be pre-split. + // This is important as we also cannot pre-split stores of those loads! + // FIXME: This is all pretty gross. It means that we can be more aggressive + // in pre-splitting when the load feeding the store happens to come from + // a separate alloca. Put another way, the effectiveness of SROA would be + // decreased by a frontend which just concatenated all of its local allocas + // into one big flat alloca. But defeating such patterns is exactly the job + // SROA is tasked with! Sadly, to not have this discrepancy we would have to + // change store pre-splitting to actually force pre-splitting of the load + // that feeds it *and all stores*. That makes pre-splitting much harder, but + // maybe it would make it more principled? + SmallPtrSet<LoadInst *, 8> UnsplittableLoads; + + DEBUG(dbgs() << "  Searching for candidate loads and stores\n"); + for (auto &P : AS.partitions()) { + for (Slice &S : P) { + Instruction *I = cast<Instruction>(S.getUse()->getUser()); + if (!S.isSplittable() || S.endOffset() <= P.endOffset()) { + // If this was a load we have to track that it can't participate in any + // pre-splitting! + if (auto *LI = dyn_cast<LoadInst>(I)) + UnsplittableLoads.insert(LI); + continue; + } + assert(P.endOffset() > S.beginOffset() && + "Empty or backwards partition!"); + + // Determine if this is a pre-splittable slice. + if (auto *LI = dyn_cast<LoadInst>(I)) { + assert(!LI->isVolatile() && "Cannot split volatile loads!"); + + // The load must be used exclusively to store into other pointers for + // us to be able to arbitrarily pre-split it. The stores must also be + // simple to avoid changing semantics. + auto IsLoadSimplyStored = [](LoadInst *LI) { + for (User *LU : LI->users()) { + auto *SI = dyn_cast<StoreInst>(LU); + if (!SI || !SI->isSimple()) + return false; + } + return true; + }; + if (!IsLoadSimplyStored(LI)) { + UnsplittableLoads.insert(LI); + continue; + } + + Loads.push_back(LI); + } else if (auto *SI = dyn_cast<StoreInst>(S.getUse()->getUser())) { + if (!SI || + S.getUse() != &SI->getOperandUse(SI->getPointerOperandIndex())) + continue; + auto *StoredLoad = dyn_cast<LoadInst>(SI->getValueOperand()); + if (!StoredLoad || !StoredLoad->isSimple()) + continue; + assert(!SI->isVolatile() && "Cannot split volatile stores!"); + + Stores.push_back(SI); + } else { + // Other uses cannot be pre-split. + continue; + } + + // Record the initial split. + DEBUG(dbgs() << "    Candidate: " << *I << "\n"); + auto &Offsets = SplitOffsetsMap[I]; + assert(Offsets.Splits.empty() && + "Should not have splits the first time we see an instruction!"); + Offsets.S = &S; + Offsets.Splits.push_back(P.endOffset() - S.beginOffset()); + } + + // Now scan the already split slices, and add a split for any of them which + // we're going to pre-split.
+ for (Slice *S : P.splitSliceTails()) { + auto SplitOffsetsMapI = + SplitOffsetsMap.find(cast<Instruction>(S->getUse()->getUser())); + if (SplitOffsetsMapI == SplitOffsetsMap.end()) + continue; + auto &Offsets = SplitOffsetsMapI->second; + + assert(Offsets.S == S && "Found a mismatched slice!"); + assert(!Offsets.Splits.empty() && + "Cannot have an empty set of splits on the second partition!"); + assert(Offsets.Splits.back() == + P.beginOffset() - Offsets.S->beginOffset() && + "Previous split does not end where this one begins!"); + + // Record each split. The last partition's end isn't needed as the size + // of the slice dictates that. + if (S->endOffset() > P.endOffset()) + Offsets.Splits.push_back(P.endOffset() - Offsets.S->beginOffset()); + } + } + + // We may have split loads where some of their stores are split stores. For + // such loads and stores, we can only pre-split them if their splits exactly + // match relative to their starting offset. We have to verify this prior to + // any rewriting. + Stores.erase( + std::remove_if(Stores.begin(), Stores.end(), + [&UnsplittableLoads, &SplitOffsetsMap](StoreInst *SI) { + // Lookup the load we are storing in our map of split + // offsets. + auto *LI = cast<LoadInst>(SI->getValueOperand()); + // If it was completely unsplittable, then we're done, + // and this store can't be pre-split. + if (UnsplittableLoads.count(LI)) + return true; + + auto LoadOffsetsI = SplitOffsetsMap.find(LI); + if (LoadOffsetsI == SplitOffsetsMap.end()) + return false; // Unrelated loads are definitely safe. + auto &LoadOffsets = LoadOffsetsI->second; + + // Now lookup the store's offsets. + auto &StoreOffsets = SplitOffsetsMap[SI]; + + // If the relative offsets of each split in the load and + // store match exactly, then we can split them and we + // don't need to remove them here. + if (LoadOffsets.Splits == StoreOffsets.Splits) + return false; + + DEBUG(dbgs() + << "    Mismatched splits for load and store:\n" + << "      " << *LI << "\n" + << "      " << *SI << "\n"); + + // We've found a store and load that we need to split + // with mismatched relative splits. Just give up on them + // and remove both instructions from our list of + // candidates. + UnsplittableLoads.insert(LI); + return true; + }), + Stores.end()); + // Now we have to go *back* through all the stores, because a later store may + // have caused an earlier store's load to become unsplittable and if it is + // unsplittable for the later store, then we can't rely on it being split in + // the earlier store either. + Stores.erase(std::remove_if(Stores.begin(), Stores.end(), + [&UnsplittableLoads](StoreInst *SI) { + auto *LI = + cast<LoadInst>(SI->getValueOperand()); + return UnsplittableLoads.count(LI); + }), + Stores.end()); + // Once we've established all the loads that can't be split for some reason, + // filter out any that made it into our list. + Loads.erase(std::remove_if(Loads.begin(), Loads.end(), + [&UnsplittableLoads](LoadInst *LI) { + return UnsplittableLoads.count(LI); + }), + Loads.end()); + + + // If no loads or stores are left, there is no pre-splitting to be done for + // this alloca. + if (Loads.empty() && Stores.empty()) + return false; + + // From here on, we can't fail and will be building new accesses, so rig up + // an IR builder. + IRBuilderTy IRB(&AI); + + // Collect the new slices which we will merge into the alloca slices. + SmallVector<Slice, 4> NewSlices; + + // Track any allocas we end up splitting loads and stores for so we iterate + // on them.
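A tiny illustration (hypothetical values) of the rule enforced by the first erase above: a store is only kept when the load it stores records exactly the same relative split offsets, and a mismatch poisons the load for every other store of it as well. In the 12-byte example earlier, both the i64 load ([0,8)) and the i64 store ([4,12)) record a single relative split at +4, so they match.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      std::vector<uint64_t> LoadSplits = {4}; // load split at +4
      std::vector<uint64_t> StoreA = {4};     // same relative split: keep
      std::vector<uint64_t> StoreB = {2, 6};  // mismatched splits: reject, and
                                              // the load becomes unsplittable
                                              // for StoreA too
      assert(LoadSplits == StoreA);
      assert(LoadSplits != StoreB);
      return 0;
    }
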
+ SmallPtrSet<AllocaInst *, 4> ResplitPromotableAllocas; + + // At this point, we have collected all of the loads and stores we can + // pre-split, and the specific splits needed for them. We actually do the + // splitting in a specific order in order to handle when one of the loads is + // the value operand to one of the stores. + // + // First, we rewrite all of the split loads, and just accumulate each split + // load in a parallel structure. We also build the slices for them and append + // them to the alloca slices. + SmallDenseMap<LoadInst *, std::vector<LoadInst *>, 1> SplitLoadsMap; + std::vector<LoadInst *> SplitLoads; + for (LoadInst *LI : Loads) { + SplitLoads.clear(); + + IntegerType *Ty = cast<IntegerType>(LI->getType()); + uint64_t LoadSize = Ty->getBitWidth() / 8; + assert(LoadSize > 0 && "Cannot have a zero-sized integer load!"); + + auto &Offsets = SplitOffsetsMap[LI]; + assert(LoadSize == Offsets.S->endOffset() - Offsets.S->beginOffset() && + "Slice size should always match load size exactly!"); + uint64_t BaseOffset = Offsets.S->beginOffset(); + assert(BaseOffset + LoadSize > BaseOffset && + "Cannot represent alloca access size using 64-bit integers!"); + + Instruction *BasePtr = cast<Instruction>(LI->getPointerOperand()); + IRB.SetInsertPoint(BasicBlock::iterator(LI)); + + DEBUG(dbgs() << "  Splitting load: " << *LI << "\n"); + + uint64_t PartOffset = 0, PartSize = Offsets.Splits.front(); + int Idx = 0, Size = Offsets.Splits.size(); + for (;;) { + auto *PartTy = Type::getIntNTy(Ty->getContext(), PartSize * 8); + auto *PartPtrTy = PartTy->getPointerTo(LI->getPointerAddressSpace()); + LoadInst *PLoad = IRB.CreateAlignedLoad( + getAdjustedPtr(IRB, *DL, BasePtr, + APInt(DL->getPointerSizeInBits(), PartOffset), + PartPtrTy, BasePtr->getName() + "."), + getAdjustedAlignment(LI, PartOffset, *DL), /*IsVolatile*/ false, + LI->getName()); + + // Append this load onto the list of split loads so we can find it later + // to rewrite the stores. + SplitLoads.push_back(PLoad); + + // Now build a new slice for the alloca. + NewSlices.push_back( + Slice(BaseOffset + PartOffset, BaseOffset + PartOffset + PartSize, + &PLoad->getOperandUse(PLoad->getPointerOperandIndex()), + /*IsSplittable*/ false)); + DEBUG(dbgs() << "    new slice [" << NewSlices.back().beginOffset() + << ", " << NewSlices.back().endOffset() << "): " << *PLoad + << "\n"); + + // See if we've handled all the splits. + if (Idx >= Size) + break; + + // Setup the next partition. + PartOffset = Offsets.Splits[Idx]; + ++Idx; + PartSize = (Idx < Size ? Offsets.Splits[Idx] : LoadSize) - PartOffset; + } + + // Now that we have the split loads, do the slow walk over all uses of the + // load and rewrite them as split stores, or save the split loads to use + // below if the store is going to be split there anyways. + bool DeferredStores = false; + for (User *LU : LI->users()) { + StoreInst *SI = cast<StoreInst>(LU); + if (!Stores.empty() && SplitOffsetsMap.count(SI)) { + DeferredStores = true; + DEBUG(dbgs() << "    Deferred splitting of store: " << *SI << "\n"); + continue; + } + + Value *StoreBasePtr = SI->getPointerOperand(); + IRB.SetInsertPoint(BasicBlock::iterator(SI)); + + DEBUG(dbgs() << "    Splitting store of load: " << *SI << "\n"); + + for (int Idx = 0, Size = SplitLoads.size(); Idx < Size; ++Idx) { + LoadInst *PLoad = SplitLoads[Idx]; + uint64_t PartOffset = Idx == 0 ?
0 : Offsets.Splits[Idx - 1]; + auto *PartPtrTy = + PLoad->getType()->getPointerTo(SI->getPointerAddressSpace()); + + StoreInst *PStore = IRB.CreateAlignedStore( + PLoad, getAdjustedPtr(IRB, *DL, StoreBasePtr, + APInt(DL->getPointerSizeInBits(), PartOffset), + PartPtrTy, StoreBasePtr->getName() + "."), + getAdjustedAlignment(SI, PartOffset, *DL), /*IsVolatile*/ false); + (void)PStore; + DEBUG(dbgs() << " +" << PartOffset << ":" << *PStore << "\n"); + } + + // We want to immediately iterate on any allocas impacted by splitting + // this store, and we have to track any promotable alloca (indicated by + // a direct store) as needing to be resplit because it is no longer + // promotable. + if (AllocaInst *OtherAI = dyn_cast<AllocaInst>(StoreBasePtr)) { + ResplitPromotableAllocas.insert(OtherAI); + Worklist.insert(OtherAI); + } else if (AllocaInst *OtherAI = dyn_cast<AllocaInst>( + StoreBasePtr->stripInBoundsOffsets())) { + Worklist.insert(OtherAI); + } + + // Mark the original store as dead. + DeadInsts.insert(SI); + } + + // Save the split loads if there are deferred stores among the users. + if (DeferredStores) + SplitLoadsMap.insert(std::make_pair(LI, std::move(SplitLoads))); + + // Mark the original load as dead and kill the original slice. + DeadInsts.insert(LI); + Offsets.S->kill(); + } + + // Second, we rewrite all of the split stores. At this point, we know that + // all loads from this alloca have been split already. For stores of such + // loads, we can simply look up the pre-existing split loads. For stores of + // other loads, we split those loads first and then write split stores of + // them. + for (StoreInst *SI : Stores) { + auto *LI = cast<LoadInst>(SI->getValueOperand()); + IntegerType *Ty = cast<IntegerType>(LI->getType()); + uint64_t StoreSize = Ty->getBitWidth() / 8; + assert(StoreSize > 0 && "Cannot have a zero-sized integer store!"); + + auto &Offsets = SplitOffsetsMap[SI]; + assert(StoreSize == Offsets.S->endOffset() - Offsets.S->beginOffset() && + "Slice size should always match load size exactly!"); + uint64_t BaseOffset = Offsets.S->beginOffset(); + assert(BaseOffset + StoreSize > BaseOffset && + "Cannot represent alloca access size using 64-bit integers!"); + + Value *LoadBasePtr = LI->getPointerOperand(); + Instruction *StoreBasePtr = cast<Instruction>(SI->getPointerOperand()); + + DEBUG(dbgs() << " Splitting store: " << *SI << "\n"); + + // Check whether we have an already split load. + auto SplitLoadsMapI = SplitLoadsMap.find(LI); + std::vector<LoadInst *> *SplitLoads = nullptr; + if (SplitLoadsMapI != SplitLoadsMap.end()) { + SplitLoads = &SplitLoadsMapI->second; + assert(SplitLoads->size() == Offsets.Splits.size() + 1 && + "Too few split loads for the number of splits in the store!"); + } else { + DEBUG(dbgs() << " of load: " << *LI << "\n"); + } + + uint64_t PartOffset = 0, PartSize = Offsets.Splits.front(); + int Idx = 0, Size = Offsets.Splits.size(); + for (;;) { + auto *PartTy = Type::getIntNTy(Ty->getContext(), PartSize * 8); + auto *PartPtrTy = PartTy->getPointerTo(SI->getPointerAddressSpace()); + + // Either lookup a split load or create one. + LoadInst *PLoad; + if (SplitLoads) { + PLoad = (*SplitLoads)[Idx]; + } else { + IRB.SetInsertPoint(BasicBlock::iterator(LI)); + PLoad = IRB.CreateAlignedLoad( + getAdjustedPtr(IRB, *DL, LoadBasePtr, + APInt(DL->getPointerSizeInBits(), PartOffset), + PartPtrTy, LoadBasePtr->getName() + "."), + getAdjustedAlignment(LI, PartOffset, *DL), /*IsVolatile*/ false, + LI->getName()); + } + + // And store this partition. 
+ IRB.SetInsertPoint(BasicBlock::iterator(SI)); + StoreInst *PStore = IRB.CreateAlignedStore( + PLoad, getAdjustedPtr(IRB, *DL, StoreBasePtr, + APInt(DL->getPointerSizeInBits(), PartOffset), + PartPtrTy, StoreBasePtr->getName() + "."), + getAdjustedAlignment(SI, PartOffset, *DL), /*IsVolatile*/ false); + + // Now build a new slice for the alloca. + NewSlices.push_back( + Slice(BaseOffset + PartOffset, BaseOffset + PartOffset + PartSize, + &PStore->getOperandUse(PStore->getPointerOperandIndex()), + /*IsSplittable*/ false)); + DEBUG(dbgs() << "    new slice [" << NewSlices.back().beginOffset() + << ", " << NewSlices.back().endOffset() << "): " << *PStore + << "\n"); + if (!SplitLoads) { + DEBUG(dbgs() << "      of split load: " << *PLoad << "\n"); + } + + // See if we've finished all the splits. + if (Idx >= Size) + break; + + // Setup the next partition. + PartOffset = Offsets.Splits[Idx]; + ++Idx; + PartSize = (Idx < Size ? Offsets.Splits[Idx] : StoreSize) - PartOffset; + } + + // We want to immediately iterate on any allocas impacted by splitting + // this load, which is only relevant if it isn't a load of this alloca and + // thus we didn't already split the loads above. We also have to keep track + // of any promotable allocas we split loads on as they can no longer be + // promoted. + if (!SplitLoads) { + if (AllocaInst *OtherAI = dyn_cast<AllocaInst>(LoadBasePtr)) { + assert(OtherAI != &AI && "We can't re-split our own alloca!"); + ResplitPromotableAllocas.insert(OtherAI); + Worklist.insert(OtherAI); + } else if (AllocaInst *OtherAI = dyn_cast<AllocaInst>( + LoadBasePtr->stripInBoundsOffsets())) { + assert(OtherAI != &AI && "We can't re-split our own alloca!"); + Worklist.insert(OtherAI); + } + } + + // Mark the original store as dead now that we've split it up and kill its + // slice. Note that we leave the original load in place unless this store + // was its only use. It may in turn be split up if it is an alloca load + // for some other alloca, but it may be a normal load. This may introduce + // redundant loads, but where those can be merged the rest of the optimizer + // should handle the merging, and this uncovers SSA splits which is more + // important. In practice, the original loads will almost always be fully + // split and removed eventually, and the splits will be merged by any + // trivial CSE, including instcombine. + if (LI->hasOneUse()) { + assert(*LI->user_begin() == SI && "Single use isn't this store!"); + DeadInsts.insert(LI); + } + DeadInsts.insert(SI); + Offsets.S->kill(); + } + + // Remove the killed slices that have been pre-split. + AS.erase(std::remove_if(AS.begin(), AS.end(), [](const Slice &S) { + return S.isDead(); + }), AS.end()); + + // Insert our new slices. This will sort and merge them into the sorted + // sequence. + AS.insert(NewSlices); + + DEBUG(dbgs() << "  Pre-split slices:\n"); +#ifndef NDEBUG + for (auto I = AS.begin(), E = AS.end(); I != E; ++I) + DEBUG(AS.print(dbgs(), I, "    ")); +#endif + + // Finally, don't try to promote any allocas that now require re-splitting. + // They have already been added to the worklist above. + PromotableAllocas.erase( + std::remove_if( + PromotableAllocas.begin(), PromotableAllocas.end(), + [&](AllocaInst *AI) { return ResplitPromotableAllocas.count(AI); }), + PromotableAllocas.end()); + + return true; +} + /// \brief Rewrite an alloca partition's users. /// /// This routine drives both of the rewriting goals of the SROA pass.
It tries @@ -3107,38 +3964,33 @@ static Type *getTypePartition(const DataLayout &DL, Type *Ty, /// appropriate new offsets. It also evaluates how successful the rewrite was /// at enabling promotion and if it was successful queues the alloca to be /// promoted. -bool SROA::rewritePartition(AllocaInst &AI, AllocaSlices &S, - AllocaSlices::iterator B, AllocaSlices::iterator E, - int64_t BeginOffset, int64_t EndOffset, - ArrayRef<AllocaSlices::iterator> SplitUses) { - assert(BeginOffset < EndOffset); - uint64_t SliceSize = EndOffset - BeginOffset; - +bool SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS, + AllocaSlices::Partition &P) { // Try to compute a friendly type for this partition of the alloca. This // won't always succeed, in which case we fall back to a legal integer type // or an i8 array of an appropriate size. Type *SliceTy = nullptr; - if (Type *CommonUseTy = findCommonType(B, E, EndOffset)) - if (DL->getTypeAllocSize(CommonUseTy) >= SliceSize) + if (Type *CommonUseTy = findCommonType(P.begin(), P.end(), P.endOffset())) + if (DL->getTypeAllocSize(CommonUseTy) >= P.size()) SliceTy = CommonUseTy; if (!SliceTy) if (Type *TypePartitionTy = getTypePartition(*DL, AI.getAllocatedType(), - BeginOffset, SliceSize)) + P.beginOffset(), P.size())) SliceTy = TypePartitionTy; if ((!SliceTy || (SliceTy->isArrayTy() && SliceTy->getArrayElementType()->isIntegerTy())) && - DL->isLegalInteger(SliceSize * 8)) - SliceTy = Type::getIntNTy(*C, SliceSize * 8); + DL->isLegalInteger(P.size() * 8)) + SliceTy = Type::getIntNTy(*C, P.size() * 8); if (!SliceTy) - SliceTy = ArrayType::get(Type::getInt8Ty(*C), SliceSize); - assert(DL->getTypeAllocSize(SliceTy) >= SliceSize); + SliceTy = ArrayType::get(Type::getInt8Ty(*C), P.size()); + assert(DL->getTypeAllocSize(SliceTy) >= P.size()); - bool IsVectorPromotable = isVectorPromotionViable( - *DL, SliceTy, S, BeginOffset, EndOffset, B, E, SplitUses); + bool IsIntegerPromotable = isIntegerWideningViable(P, SliceTy, *DL); - bool IsIntegerPromotable = - !IsVectorPromotable && - isIntegerWideningViable(*DL, SliceTy, BeginOffset, S, B, E, SplitUses); + VectorType *VecTy = + IsIntegerPromotable ? nullptr : isVectorPromotionViable(P, *DL); + if (VecTy) + SliceTy = VecTy; // Check for the case where we're going to rewrite to a new alloca of the // exact same type as the original, and with the same access offsets. In that @@ -3146,7 +3998,7 @@ bool SROA::rewritePartition(AllocaInst &AI, AllocaSlices &S, // perform phi and select speculation. AllocaInst *NewAI; if (SliceTy == AI.getAllocatedType()) { - assert(BeginOffset == 0 && + assert(P.beginOffset() == 0 && "Non-zero begin offset but same alloca type"); NewAI = &AI; // FIXME: We should be able to bail at this point with "nothing changed". @@ -3159,19 +4011,20 @@ bool SROA::rewritePartition(AllocaInst &AI, AllocaSlices &S, // type. Alignment = DL->getABITypeAlignment(AI.getAllocatedType()); } - Alignment = MinAlign(Alignment, BeginOffset); + Alignment = MinAlign(Alignment, P.beginOffset()); // If we will get at least this much alignment from the type alone, leave // the alloca's alignment unconstrained. if (Alignment <= DL->getABITypeAlignment(SliceTy)) Alignment = 0; - NewAI = new AllocaInst(SliceTy, nullptr, Alignment, - AI.getName() + ".sroa." + Twine(B - S.begin()), &AI); + NewAI = new AllocaInst( + SliceTy, nullptr, Alignment, + AI.getName() + ".sroa." 
+ Twine(P.begin() - AS.begin()), &AI); ++NumNewAllocas; } DEBUG(dbgs() << "Rewriting alloca partition " - << "[" << BeginOffset << "," << EndOffset << ") to: " << *NewAI - << "\n"); + << "[" << P.beginOffset() << "," << P.endOffset() + << ") to: " << *NewAI << "\n"); // Track the high watermark on the worklist as it is only relevant for // promoted allocas. We will reset it to this point if the alloca is not in @@ -3181,22 +4034,16 @@ SmallPtrSet<PHINode *, 8> PHIUsers; SmallPtrSet<SelectInst *, 8> SelectUsers; - AllocaSliceRewriter Rewriter(*DL, S, *this, AI, *NewAI, BeginOffset, - EndOffset, IsVectorPromotable, - IsIntegerPromotable, PHIUsers, SelectUsers); + AllocaSliceRewriter Rewriter(*DL, AS, *this, AI, *NewAI, P.beginOffset(), + P.endOffset(), IsIntegerPromotable, VecTy, + PHIUsers, SelectUsers); bool Promotable = true; - for (ArrayRef<AllocaSlices::iterator>::const_iterator SUI = SplitUses.begin(), - SUE = SplitUses.end(); - SUI != SUE; ++SUI) { - DEBUG(dbgs() << "  rewriting split "); - DEBUG(S.printSlice(dbgs(), *SUI, "")); - Promotable &= Rewriter.visit(*SUI); + for (Slice *S : P.splitSliceTails()) { + Promotable &= Rewriter.visit(S); ++NumUses; } - for (AllocaSlices::iterator I = B; I != E; ++I) { - DEBUG(dbgs() << "  rewriting "); - DEBUG(S.printSlice(dbgs(), I, "")); - Promotable &= Rewriter.visit(I); + for (Slice &S : P) { + Promotable &= Rewriter.visit(&S); ++NumUses; } @@ -3233,14 +4080,10 @@ // If we have either PHIs or Selects to speculate, add them to those // worklists and re-queue the new alloca so that we promote it on the // next iteration. - for (SmallPtrSetImpl<PHINode *>::iterator I = PHIUsers.begin(), - E = PHIUsers.end(); - I != E; ++I) - SpeculatablePHIs.insert(*I); - for (SmallPtrSetImpl<SelectInst *>::iterator I = SelectUsers.begin(), - E = SelectUsers.end(); - I != E; ++I) - SpeculatableSelects.insert(*I); + for (PHINode *PHIUser : PHIUsers) + SpeculatablePHIs.insert(PHIUser); + for (SelectInst *SelectUser : SelectUsers) + SpeculatableSelects.insert(SelectUser); Worklist.insert(NewAI); } } else { @@ -3258,136 +4101,46 @@ return true; } -static void -removeFinishedSplitUses(SmallVectorImpl<AllocaSlices::iterator> &SplitUses, - uint64_t &MaxSplitUseEndOffset, uint64_t Offset) { - if (Offset >= MaxSplitUseEndOffset) { - SplitUses.clear(); - MaxSplitUseEndOffset = 0; - return; - } - - size_t SplitUsesOldSize = SplitUses.size(); - SplitUses.erase(std::remove_if(SplitUses.begin(), SplitUses.end(), - [Offset](const AllocaSlices::iterator &I) { - return I->endOffset() <= Offset; - }), - SplitUses.end()); - if (SplitUsesOldSize == SplitUses.size()) - return; - - // Recompute the max. While this is linear, so is remove_if. - MaxSplitUseEndOffset = 0; - for (SmallVectorImpl<AllocaSlices::iterator>::iterator - SUI = SplitUses.begin(), - SUE = SplitUses.end(); - SUI != SUE; ++SUI) - MaxSplitUseEndOffset = std::max((*SUI)->endOffset(), MaxSplitUseEndOffset); -} - /// \brief Walks the slices of an alloca and forms partitions based on them, /// rewriting each of their uses.
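A toy model of that partition formation (plain C++, hypothetical names; it ignores the splittable/unsplittable distinction the real partition iterator has to honor): maximal runs of overlapping [begin, end) slices in the sorted order become partitions.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct Interval { uint64_t Begin, End; };

    int main() {
      // Sorted slices of a 12-byte alloca, e.g. from the pre-splitting example.
      std::vector<Interval> Slices = {{0, 4}, {0, 8}, {4, 8}, {8, 12}};
      for (size_t I = 0; I != Slices.size();) {
        uint64_t Begin = Slices[I].Begin, End = Slices[I].End;
        size_t J = I + 1;
        // Grow the partition while the next slice overlaps it.
        while (J != Slices.size() && Slices[J].Begin < End) {
          End = std::max(End, Slices[J].End);
          ++J;
        }
        std::printf("partition [%llu, %llu)\n", (unsigned long long)Begin,
                    (unsigned long long)End);
        I = J;
      }
      return 0; // prints "partition [0, 8)" and "partition [8, 12)"
    }
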
-bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &S) { - if (S.begin() == S.end()) +bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { + if (AS.begin() == AS.end()) return false; unsigned NumPartitions = 0; bool Changed = false; - SmallVector<AllocaSlices::iterator, 4> SplitUses; - uint64_t MaxSplitUseEndOffset = 0; - - uint64_t BeginOffset = S.begin()->beginOffset(); - - for (AllocaSlices::iterator SI = S.begin(), SJ = std::next(SI), SE = S.end(); - SI != SE; SI = SJ) { - uint64_t MaxEndOffset = SI->endOffset(); - - if (!SI->isSplittable()) { - // When we're forming an unsplittable region, it must always start at the - // first slice and will extend through its end. - assert(BeginOffset == SI->beginOffset()); - - // Form a partition including all of the overlapping slices with this - // unsplittable slice. - while (SJ != SE && SJ->beginOffset() < MaxEndOffset) { - if (!SJ->isSplittable()) - MaxEndOffset = std::max(MaxEndOffset, SJ->endOffset()); - ++SJ; - } - } else { - assert(SI->isSplittable()); // Established above. - // Collect all of the overlapping splittable slices. - while (SJ != SE && SJ->beginOffset() < MaxEndOffset && - SJ->isSplittable()) { - MaxEndOffset = std::max(MaxEndOffset, SJ->endOffset()); - ++SJ; - } + // First try to pre-split loads and stores. + Changed |= presplitLoadsAndStores(AI, AS); - // Back up MaxEndOffset and SJ if we ended the span early when - // encountering an unsplittable slice. - if (SJ != SE && SJ->beginOffset() < MaxEndOffset) { - assert(!SJ->isSplittable()); - MaxEndOffset = SJ->beginOffset(); - } - } - - // Check if we have managed to move the end offset forward yet. If so, - // we'll have to rewrite uses and erase old split uses. - if (BeginOffset < MaxEndOffset) { - // Rewrite a sequence of overlapping slices. - Changed |= - rewritePartition(AI, S, SI, SJ, BeginOffset, MaxEndOffset, SplitUses); - ++NumPartitions; - - removeFinishedSplitUses(SplitUses, MaxSplitUseEndOffset, MaxEndOffset); - } - - // Accumulate all the splittable slices from the [SI,SJ) region which - // overlap going forward. - for (AllocaSlices::iterator SK = SI; SK != SJ; ++SK) - if (SK->isSplittable() && SK->endOffset() > MaxEndOffset) { - SplitUses.push_back(SK); - MaxSplitUseEndOffset = std::max(SK->endOffset(), MaxSplitUseEndOffset); - } - - // If we're already at the end and we have no split uses, we're done. - if (SJ == SE && SplitUses.empty()) - break; - - // If we have no split uses or no gap in offsets, we're ready to move to - // the next slice. - if (SplitUses.empty() || (SJ != SE && MaxEndOffset == SJ->beginOffset())) { - BeginOffset = SJ->beginOffset(); + // Now that we have identified any pre-splitting opportunities, mark any + // splittable (non-whole-alloca) loads and stores as unsplittable. If we fail + // to split these during pre-splitting, we want to force them to be + // rewritten into a partition. + bool IsSorted = true; + for (Slice &S : AS) { + if (!S.isSplittable()) continue; - } - - // Even if we have split slices, if the next slice is splittable and the - // split slices reach it, we can simply set up the beginning offset of the - // next iteration to bridge between them. - if (SJ != SE && SJ->isSplittable() && - MaxSplitUseEndOffset > SJ->beginOffset()) { - BeginOffset = MaxEndOffset; + // FIXME: We currently leave whole-alloca splittable loads and stores. 
This + // used to be the only splittable loads and stores and we need to be + // confident that the above handling of splittable loads and stores is + // completely sufficient before we forcibly disable the remaining handling. + if (S.beginOffset() == 0 && + S.endOffset() >= DL->getTypeAllocSize(AI.getAllocatedType())) continue; + if (isa<LoadInst>(S.getUse()->getUser()) || + isa<StoreInst>(S.getUse()->getUser())) { + S.makeUnsplittable(); + IsSorted = false; } + } + if (!IsSorted) + std::sort(AS.begin(), AS.end()); - // Otherwise, we have a tail of split slices. Rewrite them with an empty - // range of slices. - uint64_t PostSplitEndOffset = - SJ == SE ? MaxSplitUseEndOffset : SJ->beginOffset(); - - Changed |= rewritePartition(AI, S, SJ, SJ, MaxEndOffset, PostSplitEndOffset, - SplitUses); + // Rewrite each partition. + for (auto &P : AS.partitions()) { + Changed |= rewritePartition(AI, AS, P); ++NumPartitions; - - if (SJ == SE) - break; // Skip the rest, we don't need to do any cleanup. - - removeFinishedSplitUses(SplitUses, MaxSplitUseEndOffset, - PostSplitEndOffset); - - // Now just reset the begin offset for the next iteration. - BeginOffset = SJ->beginOffset(); } NumAllocaPartitions += NumPartitions; @@ -3440,38 +4193,34 @@ bool SROA::runOnAlloca(AllocaInst &AI) { Changed |= AggRewriter.rewrite(AI); // Build the slices using a recursive instruction-visiting builder. - AllocaSlices S(*DL, AI); - DEBUG(S.print(dbgs())); - if (S.isEscaped()) + AllocaSlices AS(*DL, AI); + DEBUG(AS.print(dbgs())); + if (AS.isEscaped()) return Changed; // Delete all the dead users of this alloca before splitting and rewriting it. - for (AllocaSlices::dead_user_iterator DI = S.dead_user_begin(), - DE = S.dead_user_end(); - DI != DE; ++DI) { + for (Instruction *DeadUser : AS.getDeadUsers()) { // Free up everything used by this instruction. - for (Use &DeadOp : (*DI)->operands()) + for (Use &DeadOp : DeadUser->operands()) clobberUse(DeadOp); // Now replace the uses of this instruction. - (*DI)->replaceAllUsesWith(UndefValue::get((*DI)->getType())); + DeadUser->replaceAllUsesWith(UndefValue::get(DeadUser->getType())); // And mark it for deletion. - DeadInsts.insert(*DI); + DeadInsts.insert(DeadUser); Changed = true; } - for (AllocaSlices::dead_op_iterator DO = S.dead_op_begin(), - DE = S.dead_op_end(); - DO != DE; ++DO) { - clobberUse(**DO); + for (Use *DeadOp : AS.getDeadOperands()) { + clobberUse(*DeadOp); Changed = true; } // No slices to split. Leave the dead alloca for a later pass to clean up. - if (S.begin() == S.end()) + if (AS.begin() == AS.end()) return Changed; - Changed |= splitAlloca(AI, S); + Changed |= splitAlloca(AI, AS); DEBUG(dbgs() << " Speculating PHIs\n"); while (!SpeculatablePHIs.empty()) @@ -3493,7 +4242,8 @@ bool SROA::runOnAlloca(AllocaInst &AI) { /// /// We also record the alloca instructions deleted here so that they aren't /// subsequently handed to mem2reg to promote. 
-void SROA::deleteDeadInstructions(SmallPtrSet<AllocaInst*, 4> &DeletedAllocas) { +void SROA::deleteDeadInstructions( + SmallPtrSetImpl<AllocaInst *> &DeletedAllocas) { while (!DeadInsts.empty()) { Instruction *I = DeadInsts.pop_back_val(); DEBUG(dbgs() << "Deleting dead instruction: " << *I << "\n"); @@ -3518,9 +4268,9 @@ void SROA::deleteDeadInstructions(SmallPtrSet<AllocaInst*, 4> &DeletedAllocas) { static void enqueueUsersInWorklist(Instruction &I, SmallVectorImpl<Instruction *> &Worklist, - SmallPtrSet<Instruction *, 8> &Visited) { + SmallPtrSetImpl<Instruction *> &Visited) { for (User *U : I.users()) - if (Visited.insert(cast<Instruction>(U))) + if (Visited.insert(cast<Instruction>(U)).second) Worklist.push_back(cast<Instruction>(U)); } @@ -3540,14 +4290,14 @@ bool SROA::promoteAllocas(Function &F) { if (DT && !ForceSSAUpdater) { DEBUG(dbgs() << "Promoting allocas with mem2reg...\n"); - PromoteMemToReg(PromotableAllocas, *DT); + PromoteMemToReg(PromotableAllocas, *DT, nullptr, AC); PromotableAllocas.clear(); return true; } DEBUG(dbgs() << "Promoting allocas with SSAUpdater...\n"); SSAUpdater SSA; - DIBuilder DIB(*F.getParent()); + DIBuilder DIB(*F.getParent(), /*AllowUnresolved*/ false); SmallVector<Instruction *, 64> Insts; // We need a worklist to walk the uses of each alloca. @@ -3622,6 +4372,7 @@ bool SROA::runOnFunction(Function &F) { DominatorTreeWrapperPass *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); DT = DTWP ? &DTWP->getDomTree() : nullptr; + AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); BasicBlock &EntryBB = F.getEntryBlock(); for (BasicBlock::iterator I = EntryBB.begin(), E = std::prev(EntryBB.end()); @@ -3642,9 +4393,7 @@ bool SROA::runOnFunction(Function &F) { // Remove the deleted allocas from various lists so that we don't try to // continue processing them. 
if (!DeletedAllocas.empty()) { - auto IsInSet = [&](AllocaInst *AI) { - return DeletedAllocas.count(AI); - }; + auto IsInSet = [&](AllocaInst *AI) { return DeletedAllocas.count(AI); }; Worklist.remove_if(IsInSet); PostPromotionWorklist.remove_if(IsInSet); PromotableAllocas.erase(std::remove_if(PromotableAllocas.begin(), @@ -3665,6 +4414,7 @@ bool SROA::runOnFunction(Function &F) { } void SROA::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<AssumptionCacheTracker>(); if (RequiresDomTree) AU.addRequired<DominatorTreeWrapperPass>(); AU.setPreservesCFG(); diff --git a/lib/Transforms/Scalar/SampleProfile.cpp b/lib/Transforms/Scalar/SampleProfile.cpp index 73c97ffeef4f..179bbf78366d 100644 --- a/lib/Transforms/Scalar/SampleProfile.cpp +++ b/lib/Transforms/Scalar/SampleProfile.cpp @@ -26,7 +26,6 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/PostDominators.h" @@ -42,15 +41,14 @@ #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/Pass.h" +#include "llvm/ProfileData/SampleProfReader.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/LineIterator.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/Regex.h" #include "llvm/Support/raw_ostream.h" #include <cctype> using namespace llvm; +using namespace sampleprof; #define DEBUG_TYPE "sample-profile" @@ -65,76 +63,48 @@ static cl::opt<unsigned> SampleProfileMaxPropagateIterations( "sample block/edge weights through the CFG.")); namespace { -/// \brief Represents the relative location of an instruction. -/// -/// Instruction locations are specified by the line offset from the -/// beginning of the function (marked by the line where the function -/// header is) and the discriminator value within that line. -/// -/// The discriminator value is useful to distinguish instructions -/// that are on the same line but belong to different basic blocks -/// (e.g., the two post-increment instructions in "if (p) x++; else y++;"). 
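A short worked example of this location scheme, using the figures from the format documentation further down in this patch: if foo() has its header at line 280, then a body-profile entry of the form

    13.1: 250

refers to line 293 (280 + 13), discriminator 1, and 250 samples collected at that location. The discriminator component is optional and is only meaningful when the program was built with DWARF discriminator support.
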
-struct InstructionLocation { - InstructionLocation(int L, unsigned D) : LineOffset(L), Discriminator(D) {} - int LineOffset; - unsigned Discriminator; -}; -} - -namespace llvm { -template <> struct DenseMapInfo<InstructionLocation> { - typedef DenseMapInfo<int> OffsetInfo; - typedef DenseMapInfo<unsigned> DiscriminatorInfo; - static inline InstructionLocation getEmptyKey() { - return InstructionLocation(OffsetInfo::getEmptyKey(), - DiscriminatorInfo::getEmptyKey()); - } - static inline InstructionLocation getTombstoneKey() { - return InstructionLocation(OffsetInfo::getTombstoneKey(), - DiscriminatorInfo::getTombstoneKey()); - } - static inline unsigned getHashValue(InstructionLocation Val) { - return DenseMapInfo<std::pair<int, unsigned>>::getHashValue( - std::pair<int, unsigned>(Val.LineOffset, Val.Discriminator)); - } - static inline bool isEqual(InstructionLocation LHS, InstructionLocation RHS) { - return LHS.LineOffset == RHS.LineOffset && - LHS.Discriminator == RHS.Discriminator; - } -}; -} - -namespace { -typedef DenseMap<InstructionLocation, unsigned> BodySampleMap; typedef DenseMap<BasicBlock *, unsigned> BlockWeightMap; typedef DenseMap<BasicBlock *, BasicBlock *> EquivalenceClassMap; typedef std::pair<BasicBlock *, BasicBlock *> Edge; typedef DenseMap<Edge, unsigned> EdgeWeightMap; typedef DenseMap<BasicBlock *, SmallVector<BasicBlock *, 8>> BlockEdgeMap; -/// \brief Representation of the runtime profile for a function. +/// \brief Sample profile pass. /// -/// This data structure contains the runtime profile for a given -/// function. It contains the total number of samples collected -/// in the function and a map of samples collected in every statement. -class SampleFunctionProfile { +/// This pass reads profile data from the file specified by +/// -sample-profile-file and annotates every affected function with the +/// profile information found in that file. 
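For context, a pass like this is normally driven from opt; a sketch of the invocation (the pass and flag names appear in the surrounding code and comments, the file names are hypothetical):

    opt -sample-profile -sample-profile-file=foo.prof input.ll -S -o annotated.ll

As the format documentation below notes, the profiled binary must be built with at least line-table debug info (e.g. -gline-tables-only) so the samples can be matched back to IR instructions.
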
+class SampleProfileLoader : public FunctionPass { public: - SampleFunctionProfile() - : TotalSamples(0), TotalHeadSamples(0), HeaderLineno(0), DT(nullptr), - PDT(nullptr), LI(nullptr), Ctx(nullptr) {} + // Class identification, replacement for typeinfo + static char ID; + + SampleProfileLoader(StringRef Name = SampleProfileFile) + : FunctionPass(ID), DT(nullptr), PDT(nullptr), LI(nullptr), Ctx(nullptr), + Reader(), Samples(nullptr), Filename(Name), ProfileIsValid(false) { + initializeSampleProfileLoaderPass(*PassRegistry::getPassRegistry()); + } + + bool doInitialization(Module &M) override; + + void dump() { Reader->dump(); } + + const char *getPassName() const override { return "Sample profile pass"; } + + bool runOnFunction(Function &F) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired<LoopInfo>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<PostDominatorTree>(); + } +protected: unsigned getFunctionLoc(Function &F); - bool emitAnnotations(Function &F, DominatorTree *DomTree, - PostDominatorTree *PostDomTree, LoopInfo *Loops); + bool emitAnnotations(Function &F); unsigned getInstWeight(Instruction &I); - unsigned getBlockWeight(BasicBlock *B); - void addTotalSamples(unsigned Num) { TotalSamples += Num; } - void addHeadSamples(unsigned Num) { TotalHeadSamples += Num; } - void addBodySamples(int LineOffset, unsigned Discriminator, unsigned Num) { - assert(LineOffset >= 0); - BodySamples[InstructionLocation(LineOffset, Discriminator)] += Num; - } - void print(raw_ostream &OS); + unsigned getBlockWeight(BasicBlock *BB); void printEdgeWeight(raw_ostream &OS, Edge E); void printBlockWeight(raw_ostream &OS, BasicBlock *BB); void printBlockEquivalence(raw_ostream &OS, BasicBlock *BB); @@ -147,32 +117,11 @@ public: unsigned visitEdge(Edge E, unsigned *NumUnknownEdges, Edge *UnknownEdge); void buildEdges(Function &F); bool propagateThroughEdges(Function &F); - bool empty() { return BodySamples.empty(); } -protected: - /// \brief Total number of samples collected inside this function. - /// - /// Samples are cumulative, they include all the samples collected - /// inside this function and all its inlined callees. - unsigned TotalSamples; - - /// \brief Total number of samples collected at the head of the function. - /// FIXME: Use head samples to estimate a cold/hot attribute for the function. - unsigned TotalHeadSamples; - - /// \brief Line number for the function header. Used to compute relative - /// line numbers from the absolute line LOCs found in instruction locations. - /// The relative line numbers are needed to address the samples from the - /// profile file. + /// \brief Line number for the function header. Used to compute absolute + /// line numbers from the relative line numbers found in the profile. unsigned HeaderLineno; - /// \brief Map line offsets to collected samples. - /// - /// Each entry in this map contains the number of samples - /// collected at the corresponding line offset. All line locations - /// are an offset from the start of the function. - BodySampleMap BodySamples; - /// \brief Map basic blocks to their computed weights. /// /// The weight of a basic block is defined to be the maximum @@ -212,105 +161,12 @@ protected: /// \brief LLVM context holding the debug data we need. LLVMContext *Ctx; -}; - -/// \brief Sample-based profile reader. -/// -/// Each profile contains sample counts for all the functions -/// executed. 
Inside each function, statements are annotated with the -/// collected samples on all the instructions associated with that -/// statement. -/// -/// For this to produce meaningful data, the program needs to be -/// compiled with some debug information (at minimum, line numbers: -/// -gline-tables-only). Otherwise, it will be impossible to match IR -/// instructions to the line numbers collected by the profiler. -/// -/// From the profile file, we are interested in collecting the -/// following information: -/// -/// * A list of functions included in the profile (mangled names). -/// -/// * For each function F: -/// 1. The total number of samples collected in F. -/// -/// 2. The samples collected at each line in F. To provide some -/// protection against source code shuffling, line numbers should -/// be relative to the start of the function. -class SampleModuleProfile { -public: - SampleModuleProfile(const Module &M, StringRef F) - : Profiles(0), Filename(F), M(M) {} - - void dump(); - bool loadText(); - void loadNative() { llvm_unreachable("not implemented"); } - void printFunctionProfile(raw_ostream &OS, StringRef FName); - void dumpFunctionProfile(StringRef FName); - SampleFunctionProfile &getProfile(const Function &F) { - return Profiles[F.getName()]; - } - /// \brief Report a parse error message. - void reportParseError(int64_t LineNumber, Twine Msg) const { - DiagnosticInfoSampleProfile Diag(Filename.data(), LineNumber, Msg); - M.getContext().diagnose(Diag); - } - -protected: - /// \brief Map every function to its associated profile. - /// - /// The profile of every function executed at runtime is collected - /// in the structure SampleFunctionProfile. This maps function objects - /// to their corresponding profiles. - StringMap<SampleFunctionProfile> Profiles; - - /// \brief Path name to the file holding the profile data. - /// - /// The format of this file is defined by each profiler - /// independently. If possible, the profiler should have a text - /// version of the profile format to be used in constructing test - /// cases and debugging. - StringRef Filename; - - /// \brief Module being compiled. Used mainly to access the current - /// LLVM context for diagnostics. - const Module &M; -}; - -/// \brief Sample profile pass. -/// -/// This pass reads profile data from the file specified by -/// -sample-profile-file and annotates every affected function with the -/// profile information found in that file. -class SampleProfileLoader : public FunctionPass { -public: - // Class identification, replacement for typeinfo - static char ID; - - SampleProfileLoader(StringRef Name = SampleProfileFile) - : FunctionPass(ID), Profiler(), Filename(Name), ProfileIsValid(false) { - initializeSampleProfileLoaderPass(*PassRegistry::getPassRegistry()); - } - - bool doInitialization(Module &M) override; - - void dump() { Profiler->dump(); } - - const char *getPassName() const override { return "Sample profile pass"; } - - bool runOnFunction(Function &F) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addRequired<LoopInfo>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<PostDominatorTree>(); - } - -protected: /// \brief Profile reader object. - std::unique_ptr<SampleModuleProfile> Profiler; + std::unique_ptr<SampleProfileReader> Reader; + + /// \brief Samples collected for the body of this function. + FunctionSamples *Samples; /// \brief Name of the profile file to load. 
StringRef Filename; @@ -320,26 +176,11 @@ protected: }; } -/// \brief Print this function profile on stream \p OS. -/// -/// \param OS Stream to emit the output to. -void SampleFunctionProfile::print(raw_ostream &OS) { - OS << TotalSamples << ", " << TotalHeadSamples << ", " << BodySamples.size() - << " sampled lines\n"; - for (BodySampleMap::const_iterator SI = BodySamples.begin(), - SE = BodySamples.end(); - SI != SE; ++SI) - OS << "\tline offset: " << SI->first.LineOffset - << ", discriminator: " << SI->first.Discriminator - << ", number of samples: " << SI->second << "\n"; - OS << "\n"; -} - /// \brief Print the weight of edge \p E on stream \p OS. /// /// \param OS Stream to emit the output to. /// \param E Edge to print. -void SampleFunctionProfile::printEdgeWeight(raw_ostream &OS, Edge E) { +void SampleProfileLoader::printEdgeWeight(raw_ostream &OS, Edge E) { OS << "weight[" << E.first->getName() << "->" << E.second->getName() << "]: " << EdgeWeights[E] << "\n"; } @@ -348,8 +189,8 @@ void SampleFunctionProfile::printEdgeWeight(raw_ostream &OS, Edge E) { /// /// \param OS Stream to emit the output to. /// \param BB Block to print. -void SampleFunctionProfile::printBlockEquivalence(raw_ostream &OS, - BasicBlock *BB) { +void SampleProfileLoader::printBlockEquivalence(raw_ostream &OS, + BasicBlock *BB) { BasicBlock *Equiv = EquivalenceClass[BB]; OS << "equivalence[" << BB->getName() << "]: " << ((Equiv) ? EquivalenceClass[BB]->getName() : "NONE") << "\n"; @@ -359,174 +200,10 @@ void SampleFunctionProfile::printBlockEquivalence(raw_ostream &OS, /// /// \param OS Stream to emit the output to. /// \param BB Block to print. -void SampleFunctionProfile::printBlockWeight(raw_ostream &OS, BasicBlock *BB) { +void SampleProfileLoader::printBlockWeight(raw_ostream &OS, BasicBlock *BB) { OS << "weight[" << BB->getName() << "]: " << BlockWeights[BB] << "\n"; } -/// \brief Print the function profile for \p FName on stream \p OS. -/// -/// \param OS Stream to emit the output to. -/// \param FName Name of the function to print. -void SampleModuleProfile::printFunctionProfile(raw_ostream &OS, - StringRef FName) { - OS << "Function: " << FName << ":\n"; - Profiles[FName].print(OS); -} - -/// \brief Dump the function profile for \p FName. -/// -/// \param FName Name of the function to print. -void SampleModuleProfile::dumpFunctionProfile(StringRef FName) { - printFunctionProfile(dbgs(), FName); -} - -/// \brief Dump all the function profiles found. -void SampleModuleProfile::dump() { - for (StringMap<SampleFunctionProfile>::const_iterator I = Profiles.begin(), - E = Profiles.end(); - I != E; ++I) - dumpFunctionProfile(I->getKey()); -} - -/// \brief Load samples from a text file. -/// -/// The file contains a list of samples for every function executed at -/// runtime. Each function profile has the following format: -/// -/// function1:total_samples:total_head_samples -/// offset1[.discriminator]: number_of_samples [fn1:num fn2:num ... ] -/// offset2[.discriminator]: number_of_samples [fn3:num fn4:num ... ] -/// ... -/// offsetN[.discriminator]: number_of_samples [fn5:num fn6:num ... ] -/// -/// Function names must be mangled in order for the profile loader to -/// match them in the current translation unit. The two numbers in the -/// function header specify how many total samples were accumulated in -/// the function (first number), and the total number of samples accumulated -/// at the prologue of the function (second number). 
This head sample -/// count provides an indicator of how frequent is the function invoked. -/// -/// Each sampled line may contain several items. Some are optional -/// (marked below): -/// -/// a- Source line offset. This number represents the line number -/// in the function where the sample was collected. The line number -/// is always relative to the line where symbol of the function -/// is defined. So, if the function has its header at line 280, -/// the offset 13 is at line 293 in the file. -/// -/// b- [OPTIONAL] Discriminator. This is used if the sampled program -/// was compiled with DWARF discriminator support -/// (http://wiki.dwarfstd.org/index.php?title=Path_Discriminators) -/// -/// c- Number of samples. This is the number of samples collected by -/// the profiler at this source location. -/// -/// d- [OPTIONAL] Potential call targets and samples. If present, this -/// line contains a call instruction. This models both direct and -/// indirect calls. Each called target is listed together with the -/// number of samples. For example, -/// -/// 130: 7 foo:3 bar:2 baz:7 -/// -/// The above means that at relative line offset 130 there is a -/// call instruction that calls one of foo(), bar() and baz(). With -/// baz() being the relatively more frequent call target. -/// -/// FIXME: This is currently unhandled, but it has a lot of -/// potential for aiding the inliner. -/// -/// -/// Since this is a flat profile, a function that shows up more than -/// once gets all its samples aggregated across all its instances. -/// -/// FIXME: flat profiles are too imprecise to provide good optimization -/// opportunities. Convert them to context-sensitive profile. -/// -/// This textual representation is useful to generate unit tests and -/// for debugging purposes, but it should not be used to generate -/// profiles for large programs, as the representation is extremely -/// inefficient. -/// -/// \returns true if the file was loaded successfully, false otherwise. -bool SampleModuleProfile::loadText() { - ErrorOr<std::unique_ptr<MemoryBuffer>> BufferOrErr = - MemoryBuffer::getFile(Filename); - if (std::error_code EC = BufferOrErr.getError()) { - std::string Msg(EC.message()); - M.getContext().diagnose(DiagnosticInfoSampleProfile(Filename.data(), Msg)); - return false; - } - std::unique_ptr<MemoryBuffer> Buffer = std::move(BufferOrErr.get()); - line_iterator LineIt(*Buffer, '#'); - - // Read the profile of each function. Since each function may be - // mentioned more than once, and we are collecting flat profiles, - // accumulate samples as we parse them. - Regex HeadRE("^([^0-9].*):([0-9]+):([0-9]+)$"); - Regex LineSample("^([0-9]+)\\.?([0-9]+)?: ([0-9]+)(.*)$"); - while (!LineIt.is_at_eof()) { - // Read the header of each function. - // - // Note that for function identifiers we are actually expecting - // mangled names, but we may not always get them. This happens when - // the compiler decides not to emit the function (e.g., it was inlined - // and removed). In this case, the binary will not have the linkage - // name for the function, so the profiler will emit the function's - // unmangled name, which may contain characters like ':' and '>' in its - // name (member functions, templates, etc). - // - // The only requirement we place on the identifier, then, is that it - // should not begin with a number. 
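For concreteness, a tiny profile in this textual format might look like the following (the function name and counts are hypothetical, shown only to illustrate the grammar described above; the header line is what HeadRE matches, the body lines are what LineSample matches):

    _Z3fooi:1200:10
    1: 100
    2.1: 300
    5: 800 _Z3bari:600

Here _Z3fooi accumulated 1200 total samples, 10 of them at its prologue; the sample at line offset 2 carries discriminator 1; and the call at offset 5 was observed calling _Z3bari 600 times.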
- SmallVector<StringRef, 3> Matches; - if (!HeadRE.match(*LineIt, &Matches)) { - reportParseError(LineIt.line_number(), - "Expected 'mangled_name:NUM:NUM', found " + *LineIt); - return false; - } - assert(Matches.size() == 4); - StringRef FName = Matches[1]; - unsigned NumSamples, NumHeadSamples; - Matches[2].getAsInteger(10, NumSamples); - Matches[3].getAsInteger(10, NumHeadSamples); - Profiles[FName] = SampleFunctionProfile(); - SampleFunctionProfile &FProfile = Profiles[FName]; - FProfile.addTotalSamples(NumSamples); - FProfile.addHeadSamples(NumHeadSamples); - ++LineIt; - - // Now read the body. The body of the function ends when we reach - // EOF or when we see the start of the next function. - while (!LineIt.is_at_eof() && isdigit((*LineIt)[0])) { - if (!LineSample.match(*LineIt, &Matches)) { - reportParseError( - LineIt.line_number(), - "Expected 'NUM[.NUM]: NUM[ mangled_name:NUM]*', found " + *LineIt); - return false; - } - assert(Matches.size() == 5); - unsigned LineOffset, NumSamples, Discriminator = 0; - Matches[1].getAsInteger(10, LineOffset); - if (Matches[2] != "") - Matches[2].getAsInteger(10, Discriminator); - Matches[3].getAsInteger(10, NumSamples); - - // FIXME: Handle called targets (in Matches[4]). - - // When dealing with instruction weights, we use the value - // zero to indicate the absence of a sample. If we read an - // actual zero from the profile file, return it as 1 to - // avoid the confusion later on. - if (NumSamples == 0) - NumSamples = 1; - FProfile.addBodySamples(LineOffset, Discriminator, NumSamples); - ++LineIt; - } - } - - return true; -} - /// \brief Get the weight for an instruction. /// /// The "weight" of an instruction \p Inst is the number of samples @@ -538,7 +215,7 @@ bool SampleModuleProfile::loadText() { /// \param Inst Instruction to query. /// /// \returns The profiled weight of I. -unsigned SampleFunctionProfile::getInstWeight(Instruction &Inst) { +unsigned SampleProfileLoader::getInstWeight(Instruction &Inst) { DebugLoc DLoc = Inst.getDebugLoc(); unsigned Lineno = DLoc.getLine(); if (Lineno < HeaderLineno) @@ -547,8 +224,7 @@ unsigned SampleFunctionProfile::getInstWeight(Instruction &Inst) { DILocation DIL(DLoc.getAsMDNode(*Ctx)); int LOffset = Lineno - HeaderLineno; unsigned Discriminator = DIL.getDiscriminator(); - unsigned Weight = - BodySamples.lookup(InstructionLocation(LOffset, Discriminator)); + unsigned Weight = Samples->samplesAt(LOffset, Discriminator); DEBUG(dbgs() << " " << Lineno << "." << Discriminator << ":" << Inst << " (line offset: " << LOffset << "." << Discriminator << " - weight: " << Weight << ")\n"); @@ -557,24 +233,24 @@ unsigned SampleFunctionProfile::getInstWeight(Instruction &Inst) { /// \brief Compute the weight of a basic block. /// -/// The weight of basic block \p B is the maximum weight of all the -/// instructions in B. The weight of \p B is computed and cached in +/// The weight of basic block \p BB is the maximum weight of all the +/// instructions in BB. The weight of \p BB is computed and cached in /// the BlockWeights map. /// -/// \param B The basic block to query. +/// \param BB The basic block to query. /// -/// \returns The computed weight of B. -unsigned SampleFunctionProfile::getBlockWeight(BasicBlock *B) { - // If we've computed B's weight before, return it. +/// \returns The computed weight of BB. +unsigned SampleProfileLoader::getBlockWeight(BasicBlock *BB) { + // If we've computed BB's weight before, return it. 
std::pair<BlockWeightMap::iterator, bool> Entry = - BlockWeights.insert(std::make_pair(B, 0)); + BlockWeights.insert(std::make_pair(BB, 0)); if (!Entry.second) return Entry.first->second; - // Otherwise, compute and cache B's weight. + // Otherwise, compute and cache BB's weight. unsigned Weight = 0; - for (BasicBlock::iterator I = B->begin(), E = B->end(); I != E; ++I) { - unsigned InstWeight = getInstWeight(*I); + for (auto &I : BB->getInstList()) { + unsigned InstWeight = getInstWeight(I); if (InstWeight > Weight) Weight = InstWeight; } @@ -588,13 +264,13 @@ unsigned SampleFunctionProfile::getBlockWeight(BasicBlock *B) { /// the weights of every basic block in the CFG. /// /// \param F The function to query. -bool SampleFunctionProfile::computeBlockWeights(Function &F) { +bool SampleProfileLoader::computeBlockWeights(Function &F) { bool Changed = false; DEBUG(dbgs() << "Block weights\n"); - for (Function::iterator B = F.begin(), E = F.end(); B != E; ++B) { - unsigned Weight = getBlockWeight(B); + for (auto &BB : F) { + unsigned Weight = getBlockWeight(&BB); Changed |= (Weight > 0); - DEBUG(printBlockWeight(dbgs(), B)); + DEBUG(printBlockWeight(dbgs(), &BB)); } return Changed; @@ -623,16 +299,13 @@ bool SampleFunctionProfile::computeBlockWeights(Function &F) { /// \param DomTree Opposite dominator tree. If \p Descendants is filled /// with blocks from \p BB1's dominator tree, then /// this is the post-dominator tree, and vice versa. -void SampleFunctionProfile::findEquivalencesFor( +void SampleProfileLoader::findEquivalencesFor( BasicBlock *BB1, SmallVector<BasicBlock *, 8> Descendants, DominatorTreeBase<BasicBlock> *DomTree) { - for (SmallVectorImpl<BasicBlock *>::iterator I = Descendants.begin(), - E = Descendants.end(); - I != E; ++I) { - BasicBlock *BB2 = *I; + for (auto *BB2 : Descendants) { bool IsDomParent = DomTree->dominates(BB2, BB1); bool IsInSameLoop = LI->getLoopFor(BB1) == LI->getLoopFor(BB2); - if (BB1 != BB2 && VisitedBlocks.insert(BB2) && IsDomParent && + if (BB1 != BB2 && VisitedBlocks.insert(BB2).second && IsDomParent && IsInSameLoop) { EquivalenceClass[BB2] = BB1; @@ -660,12 +333,12 @@ void SampleFunctionProfile::findEquivalencesFor( /// dominates B2, B2 post-dominates B1 and both are in the same loop. /// /// \param F The function to query. -void SampleFunctionProfile::findEquivalenceClasses(Function &F) { +void SampleProfileLoader::findEquivalenceClasses(Function &F) { SmallVector<BasicBlock *, 8> DominatedBBs; DEBUG(dbgs() << "\nBlock equivalence classes\n"); // Find equivalence sets based on dominance and post-dominance information. - for (Function::iterator B = F.begin(), E = F.end(); B != E; ++B) { - BasicBlock *BB1 = B; + for (auto &BB : F) { + BasicBlock *BB1 = &BB; // Compute BB1's equivalence class once. if (EquivalenceClass.count(BB1)) { @@ -712,8 +385,8 @@ void SampleFunctionProfile::findEquivalenceClasses(Function &F) { // each equivalence class has the largest weight, assign that weight // to all the blocks in that equivalence class. DEBUG(dbgs() << "\nAssign the same weight to all blocks in the same class\n"); - for (Function::iterator B = F.begin(), E = F.end(); B != E; ++B) { - BasicBlock *BB = B; + for (auto &BI : F) { + BasicBlock *BB = &BI; BasicBlock *EquivBB = EquivalenceClass[BB]; if (BB != EquivBB) BlockWeights[BB] = BlockWeights[EquivBB]; @@ -731,8 +404,8 @@ void SampleFunctionProfile::findEquivalenceClasses(Function &F) { /// \param UnknownEdge Set if E has not been visited before. /// /// \returns E's weight, if known. Otherwise, return 0. 
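To make the propagation rule concrete before the code that implements it, here is a small self-contained sketch (simplified stand-in types and names, not the pass's own):

    #include <cstdint>
    #include <vector>

    // Simplified model of the rule behind visitEdge/propagateThroughEdges:
    // when a block's weight is known and all but one incident edge is known,
    // the unknown edge gets the difference, clamped at zero so that edge
    // weights never go negative.
    uint64_t inferUnknownEdgeWeight(uint64_t BlockWeight,
                                    const std::vector<uint64_t> &KnownEdges) {
      uint64_t KnownSum = 0;
      for (uint64_t W : KnownEdges)
        KnownSum += W;
      return BlockWeight > KnownSum ? BlockWeight - KnownSum : 0;
    }

For example, a block of weight 100 with known incoming edges {60, 30} yields 10 for its single unknown edge; with known edges {80, 40} it yields 0.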
-unsigned SampleFunctionProfile::visitEdge(Edge E, unsigned *NumUnknownEdges, - Edge *UnknownEdge) { +unsigned SampleProfileLoader::visitEdge(Edge E, unsigned *NumUnknownEdges, + Edge *UnknownEdge) { if (!VisitedEdges.count(E)) { (*NumUnknownEdges)++; *UnknownEdge = E; @@ -753,11 +426,11 @@ unsigned SampleFunctionProfile::visitEdge(Edge E, unsigned *NumUnknownEdges, /// \param F Function to process. /// /// \returns True if new weights were assigned to edges or blocks. -bool SampleFunctionProfile::propagateThroughEdges(Function &F) { +bool SampleProfileLoader::propagateThroughEdges(Function &F) { bool Changed = false; DEBUG(dbgs() << "\nPropagation through edges\n"); - for (Function::iterator BI = F.begin(), EI = F.end(); BI != EI; ++BI) { - BasicBlock *BB = BI; + for (auto &BI : F) { + BasicBlock *BB = &BI; // Visit all the predecessor and successor edges to determine // which ones have a weight assigned already. Note that it doesn't @@ -771,16 +444,16 @@ bool SampleFunctionProfile::propagateThroughEdges(Function &F) { if (i == 0) { // First, visit all predecessor edges. - for (size_t I = 0; I < Predecessors[BB].size(); I++) { - Edge E = std::make_pair(Predecessors[BB][I], BB); + for (auto *Pred : Predecessors[BB]) { + Edge E = std::make_pair(Pred, BB); TotalWeight += visitEdge(E, &NumUnknownEdges, &UnknownEdge); if (E.first == E.second) SelfReferentialEdge = E; } } else { // On the second round, visit all successor edges. - for (size_t I = 0; I < Successors[BB].size(); I++) { - Edge E = std::make_pair(BB, Successors[BB][I]); + for (auto *Succ : Successors[BB]) { + Edge E = std::make_pair(BB, Succ); TotalWeight += visitEdge(E, &NumUnknownEdges, &UnknownEdge); } } @@ -821,7 +494,7 @@ bool SampleFunctionProfile::propagateThroughEdges(Function &F) { << " known. Set weight for block: "; printBlockWeight(dbgs(), BB);); } - if (VisitedBlocks.insert(BB)) + if (VisitedBlocks.insert(BB).second) Changed = true; } else if (NumUnknownEdges == 1 && VisitedBlocks.count(BB)) { // If there is a single unknown edge and the block has been @@ -857,9 +530,9 @@ bool SampleFunctionProfile::propagateThroughEdges(Function &F) { /// /// We are interested in unique edges. If a block B1 has multiple /// edges to another block B2, we only add a single B1->B2 edge. -void SampleFunctionProfile::buildEdges(Function &F) { - for (Function::iterator I = F.begin(), E = F.end(); I != E; ++I) { - BasicBlock *B1 = I; +void SampleProfileLoader::buildEdges(Function &F) { + for (auto &BI : F) { + BasicBlock *B1 = &BI; // Add predecessors for B1. 
SmallPtrSet<BasicBlock *, 16> Visited; @@ -867,7 +540,7 @@ void SampleFunctionProfile::buildEdges(Function &F) { llvm_unreachable("Found a stale predecessors list in a basic block."); for (pred_iterator PI = pred_begin(B1), PE = pred_end(B1); PI != PE; ++PI) { BasicBlock *B2 = *PI; - if (Visited.insert(B2)) + if (Visited.insert(B2).second) Predecessors[B1].push_back(B2); } @@ -877,7 +550,7 @@ void SampleFunctionProfile::buildEdges(Function &F) { llvm_unreachable("Found a stale successors list in a basic block."); for (succ_iterator SI = succ_begin(B1), SE = succ_end(B1); SI != SE; ++SI) { BasicBlock *B2 = *SI; - if (Visited.insert(B2)) + if (Visited.insert(B2).second) Successors[B1].push_back(B2); } } @@ -885,22 +558,22 @@ void SampleFunctionProfile::buildEdges(Function &F) { /// \brief Propagate weights into edges /// -/// The following rules are applied to every block B in the CFG: +/// The following rules are applied to every block BB in the CFG: /// -/// - If B has a single predecessor/successor, then the weight +/// - If BB has a single predecessor/successor, then the weight /// of that edge is the weight of the block. /// /// - If all incoming or outgoing edges are known except one, and the /// weight of the block is already known, the weight of the unknown /// edge will be the weight of the block minus the sum of all the known -/// edges. If the sum of all the known edges is larger than B's weight, +/// edges. If the sum of all the known edges is larger than BB's weight, /// we set the unknown edge weight to zero. /// /// - If there is a self-referential edge, and the weight of the block is /// known, the weight for that edge is set to the weight of the block /// minus the weight of the other incoming edges to that block (if /// known). -void SampleFunctionProfile::propagateWeights(Function &F) { +void SampleProfileLoader::propagateWeights(Function &F) { bool Changed = true; unsigned i = 0; @@ -920,9 +593,9 @@ void SampleFunctionProfile::propagateWeights(Function &F) { // edge weights computed during propagation. DEBUG(dbgs() << "\nPropagation complete. Setting branch weights\n"); MDBuilder MDB(F.getContext()); - for (Function::iterator I = F.begin(), E = F.end(); I != E; ++I) { - BasicBlock *B = I; - TerminatorInst *TI = B->getTerminator(); + for (auto &BI : F) { + BasicBlock *BB = &BI; + TerminatorInst *TI = BB->getTerminator(); if (TI->getNumSuccessors() == 1) continue; if (!isa<BranchInst>(TI) && !isa<SwitchInst>(TI)) @@ -934,7 +607,7 @@ void SampleFunctionProfile::propagateWeights(Function &F) { bool AllWeightsZero = true; for (unsigned I = 0; I < TI->getNumSuccessors(); ++I) { BasicBlock *Succ = TI->getSuccessor(I); - Edge E = std::make_pair(B, Succ); + Edge E = std::make_pair(BB, Succ); unsigned Weight = EdgeWeights[E]; DEBUG(dbgs() << "\t"; printEdgeWeight(dbgs(), E)); Weights.push_back(Weight); @@ -965,22 +638,17 @@ void SampleFunctionProfile::propagateWeights(Function &F) { /// /// \returns the line number where \p F is defined. If it returns 0, /// it means that there is no debug information available for \p F. 
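As a side note on what the annotation step above ultimately produces: propagateWeights attaches MD_prof metadata to each multi-successor terminator. A minimal standalone sketch of that call sequence (an illustrative helper, not the pass's code; MDBuilder and MD_prof are the real LLVM interfaces):

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/IR/InstrTypes.h"
    #include "llvm/IR/LLVMContext.h"
    #include "llvm/IR/MDBuilder.h"

    // Attach one weight per successor as branch_weights metadata, the form
    // later consumed by BranchProbabilityInfo.
    static void annotateTerminator(llvm::TerminatorInst *TI,
                                   llvm::ArrayRef<uint32_t> Weights) {
      llvm::MDBuilder MDB(TI->getContext());
      TI->setMetadata(llvm::LLVMContext::MD_prof,
                      MDB.createBranchWeights(Weights));
    }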
-unsigned SampleFunctionProfile::getFunctionLoc(Function &F) {
-  NamedMDNode *CUNodes = F.getParent()->getNamedMetadata("llvm.dbg.cu");
-  if (CUNodes) {
-    for (unsigned I = 0, E1 = CUNodes->getNumOperands(); I != E1; ++I) {
-      DICompileUnit CU(CUNodes->getOperand(I));
-      DIArray Subprograms = CU.getSubprograms();
-      for (unsigned J = 0, E2 = Subprograms.getNumElements(); J != E2; ++J) {
-        DISubprogram Subprogram(Subprograms.getElement(J));
-        if (Subprogram.describes(&F))
-          return Subprogram.getLineNumber();
-      }
-    }
-  }
+unsigned SampleProfileLoader::getFunctionLoc(Function &F) {
+  DISubprogram S = getDISubprogram(&F);
+  if (S.isSubprogram())
+    return S.getLineNumber();
 
+  // If we could not find the start of \p F, emit a diagnostic to inform the
+  // user about the missed opportunity.
   F.getContext().diagnose(DiagnosticInfoSampleProfile(
-      "No debug information found in function " + F.getName()));
+      "No debug information found in function " + F.getName() +
+          ": Function profile not used",
+      DS_Warning));
   return 0;
 }
 
@@ -1002,15 +670,15 @@ unsigned SampleFunctionProfile::getFunctionLoc(Function &F)
 ///
 /// 3- Propagation of block weights into edges. This uses a simple
 ///    propagation heuristic. The following rules are applied to every
-///    block B in the CFG:
+///    block BB in the CFG:
 ///
-///    - If B has a single predecessor/successor, then the weight
+///    - If BB has a single predecessor/successor, then the weight
 ///      of that edge is the weight of the block.
 ///
 ///    - If all the edges are known except one, and the weight of the
 ///      block is already known, the weight of the unknown edge will
 ///      be the weight of the block minus the sum of all the known
-///      edges. If the sum of all the known edges is larger than B's weight,
+///      edges. If the sum of all the known edges is larger than BB's weight,
 ///      we set the unknown edge weight to zero.
 ///
 ///    - If there is a self-referential edge, and the weight of the block is
@@ -1028,14 +696,12 @@ unsigned SampleFunctionProfile::getFunctionLoc(Function &F)
 ///    work here.
 ///
 /// Once all the branch weights are computed, we emit the MD_prof
-/// metadata on B using the computed values for each of its branches.
+/// metadata on BB using the computed values for each of its branches.
 ///
 /// \param F The function to query.
 ///
 /// \returns true if \p F was modified. Returns false, otherwise.
-bool SampleFunctionProfile::emitAnnotations(Function &F, DominatorTree *DomTree,
-                                            PostDominatorTree *PostDomTree,
-                                            LoopInfo *Loops) {
+bool SampleProfileLoader::emitAnnotations(Function &F) {
   bool Changed = false;
 
   // Initialize invariants used during computation and propagation.
@@ -1045,10 +711,6 @@ bool SampleFunctionProfile::emitAnnotations(Function &F, DominatorTree *DomTree,
   DEBUG(dbgs() << "Line number for the first instruction in " << F.getName()
                << ": " << HeaderLineno << "\n");
 
-  DT = DomTree;
-  PDT = PostDomTree;
-  LI = Loops;
-  Ctx = &F.getParent()->getContext();
 
   // Compute basic block weights.
Changed |= computeBlockWeights(F); @@ -1075,8 +737,14 @@ INITIALIZE_PASS_END(SampleProfileLoader, "sample-profile", "Sample Profile loader", false, false) bool SampleProfileLoader::doInitialization(Module &M) { - Profiler.reset(new SampleModuleProfile(M, Filename)); - ProfileIsValid = Profiler->loadText(); + auto ReaderOrErr = SampleProfileReader::create(Filename, M.getContext()); + if (std::error_code EC = ReaderOrErr.getError()) { + std::string Msg = "Could not open profile: " + EC.message(); + M.getContext().diagnose(DiagnosticInfoSampleProfile(Filename.data(), Msg)); + return false; + } + Reader = std::move(ReaderOrErr.get()); + ProfileIsValid = (Reader->read() == sampleprof_error::success); return true; } @@ -1091,11 +759,13 @@ FunctionPass *llvm::createSampleProfileLoaderPass(StringRef Name) { bool SampleProfileLoader::runOnFunction(Function &F) { if (!ProfileIsValid) return false; - DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - PostDominatorTree *PDT = &getAnalysis<PostDominatorTree>(); - LoopInfo *LI = &getAnalysis<LoopInfo>(); - SampleFunctionProfile &FunctionProfile = Profiler->getProfile(F); - if (!FunctionProfile.empty()) - return FunctionProfile.emitAnnotations(F, DT, PDT, LI); + + DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + PDT = &getAnalysis<PostDominatorTree>(); + LI = &getAnalysis<LoopInfo>(); + Ctx = &F.getParent()->getContext(); + Samples = Reader->getSamplesFor(F); + if (!Samples->empty()) + return emitAnnotations(F); return false; } diff --git a/lib/Transforms/Scalar/Scalar.cpp b/lib/Transforms/Scalar/Scalar.cpp index de724d419a48..a16e9e29a1f1 100644 --- a/lib/Transforms/Scalar/Scalar.cpp +++ b/lib/Transforms/Scalar/Scalar.cpp @@ -28,6 +28,7 @@ using namespace llvm; /// ScalarOpts library. 
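The doInitialization above follows the usual ErrorOr idiom; as a generic, self-contained sketch of that pattern (Thing and loadThing are hypothetical stand-ins for SampleProfileReader and its factory):

    #include "llvm/Support/ErrorOr.h"
    #include <memory>
    #include <string>
    #include <system_error>

    struct Thing {};

    // Hypothetical factory; always fails here, as if the file were missing.
    static llvm::ErrorOr<std::unique_ptr<Thing>> loadThing(const std::string &) {
      return std::make_error_code(std::errc::no_such_file_or_directory);
    }

    static bool initFromFile(const std::string &Path, std::string &ErrMsg) {
      auto ThingOrErr = loadThing(Path);
      if (std::error_code EC = ThingOrErr.getError()) {
        ErrMsg = "Could not open " + Path + ": " + EC.message();
        return false; // Diagnose, then degrade gracefully, as the pass does.
      }
      std::unique_ptr<Thing> T = std::move(ThingOrErr.get());
      return T != nullptr;
    }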
void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeADCEPass(Registry); + initializeAlignmentFromAssumptionsPass(Registry); initializeSampleProfileLoaderPass(Registry); initializeConstantHoistingPass(Registry); initializeConstantPropagationPass(Registry); @@ -38,6 +39,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeDSEPass(Registry); initializeGVNPass(Registry); initializeEarlyCSEPass(Registry); + initializeFlattenCFGPassPass(Registry); initializeIndVarSimplifyPass(Registry); initializeJumpThreadingPass(Registry); initializeLICMPass(Registry); @@ -77,6 +79,10 @@ void LLVMAddAggressiveDCEPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createAggressiveDCEPass()); } +void LLVMAddAlignmentFromAssumptionsPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createAlignmentFromAssumptionsPass()); +} + void LLVMAddCFGSimplificationPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createCFGSimplificationPass()); } @@ -145,6 +151,10 @@ void LLVMAddPartiallyInlineLibCallsPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createPartiallyInlineLibCallsPass()); } +void LLVMAddLowerSwitchPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLowerSwitchPass()); +} + void LLVMAddPromoteMemoryToRegisterPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createPromoteMemoryToRegisterPass()); } @@ -203,6 +213,10 @@ void LLVMAddTypeBasedAliasAnalysisPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createTypeBasedAliasAnalysisPass()); } +void LLVMAddScopedNoAliasAAPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createScopedNoAliasAAPass()); +} + void LLVMAddBasicAliasAnalysisPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createBasicAliasAnalysisPass()); } diff --git a/lib/Transforms/Scalar/ScalarReplAggregates.cpp b/lib/Transforms/Scalar/ScalarReplAggregates.cpp index e2a24a7fd4a7..5c49a5504b47 100644 --- a/lib/Transforms/Scalar/ScalarReplAggregates.cpp +++ b/lib/Transforms/Scalar/ScalarReplAggregates.cpp @@ -23,6 +23,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CallSite.h" @@ -197,6 +198,7 @@ namespace { // getAnalysisUsage - This pass does not require any passes, but we know it // will not alter the CFG, so say so. void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.setPreservesCFG(); } @@ -214,6 +216,7 @@ namespace { // getAnalysisUsage - This pass does not require any passes, but we know it // will not alter the CFG, so say so. 
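For context, the C API wrappers added above slot into the usual pass-manager pattern; a brief usage sketch (the three Add* bindings are the ones introduced in this patch, the pass-manager calls are the long-standing C API):

    #include "llvm-c/Core.h"
    #include "llvm-c/Transforms/Scalar.h"

    // Build a small function-pass pipeline through the C API.
    static void addNewScalarPasses(LLVMModuleRef M) {
      LLVMPassManagerRef PM = LLVMCreateFunctionPassManagerForModule(M);
      LLVMAddScopedNoAliasAAPass(PM);          // added in this patch
      LLVMAddAlignmentFromAssumptionsPass(PM); // added in this patch
      LLVMAddLowerSwitchPass(PM);              // added in this patch
      LLVMDisposePassManager(PM);
    }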
void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); AU.setPreservesCFG(); } }; @@ -225,12 +228,14 @@ char SROA_SSAUp::ID = 0; INITIALIZE_PASS_BEGIN(SROA_DT, "scalarrepl", "Scalar Replacement of Aggregates (DT)", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_END(SROA_DT, "scalarrepl", "Scalar Replacement of Aggregates (DT)", false, false) INITIALIZE_PASS_BEGIN(SROA_SSAUp, "scalarrepl-ssa", "Scalar Replacement of Aggregates (SSAUp)", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_END(SROA_SSAUp, "scalarrepl-ssa", "Scalar Replacement of Aggregates (SSAUp)", false, false) @@ -1063,12 +1068,14 @@ public: void run(AllocaInst *AI, const SmallVectorImpl<Instruction*> &Insts) { // Remember which alloca we're promoting (for isInstInList). this->AI = AI; - if (MDNode *DebugNode = MDNode::getIfExists(AI->getContext(), AI)) { - for (User *U : DebugNode->users()) - if (DbgDeclareInst *DDI = dyn_cast<DbgDeclareInst>(U)) - DDIs.push_back(DDI); - else if (DbgValueInst *DVI = dyn_cast<DbgValueInst>(U)) - DVIs.push_back(DVI); + if (auto *L = LocalAsMetadata::getIfExists(AI)) { + if (auto *DebugNode = MetadataAsValue::getIfExists(AI->getContext(), L)) { + for (User *U : DebugNode->users()) + if (DbgDeclareInst *DDI = dyn_cast<DbgDeclareInst>(U)) + DDIs.push_back(DDI); + else if (DbgValueInst *DVI = dyn_cast<DbgValueInst>(U)) + DVIs.push_back(DVI); + } } LoadAndStorePromoter::run(Insts); @@ -1119,9 +1126,9 @@ public: } else { continue; } - Instruction *DbgVal = - DIB->insertDbgValueIntrinsic(Arg, 0, DIVariable(DVI->getVariable()), - Inst); + Instruction *DbgVal = DIB->insertDbgValueIntrinsic( + Arg, 0, DIVariable(DVI->getVariable()), + DIExpression(DVI->getExpression()), Inst); DbgVal->setDebugLoc(DVI->getDebugLoc()); } } @@ -1333,12 +1340,15 @@ static bool tryToMakeAllocaBePromotable(AllocaInst *AI, const DataLayout *DL) { LoadInst *FalseLoad = Builder.CreateLoad(SI->getFalseValue(), LI->getName()+".f"); - // Transfer alignment and TBAA info if present. + // Transfer alignment and AA info if present. TrueLoad->setAlignment(LI->getAlignment()); FalseLoad->setAlignment(LI->getAlignment()); - if (MDNode *Tag = LI->getMetadata(LLVMContext::MD_tbaa)) { - TrueLoad->setMetadata(LLVMContext::MD_tbaa, Tag); - FalseLoad->setMetadata(LLVMContext::MD_tbaa, Tag); + + AAMDNodes Tags; + LI->getAAMetadata(Tags); + if (Tags) { + TrueLoad->setAAMetadata(Tags); + FalseLoad->setAAMetadata(Tags); } Value *V = Builder.CreateSelect(SI->getCondition(), TrueLoad, FalseLoad); @@ -1364,10 +1374,12 @@ static bool tryToMakeAllocaBePromotable(AllocaInst *AI, const DataLayout *DL) { PHINode *NewPN = PHINode::Create(LoadTy, PN->getNumIncomingValues(), PN->getName()+".ld", PN); - // Get the TBAA tag and alignment to use from one of the loads. It doesn't + // Get the AA tags and alignment to use from one of the loads. It doesn't // matter which one we get and if any differ, it doesn't matter. LoadInst *SomeLoad = cast<LoadInst>(PN->user_back()); - MDNode *TBAATag = SomeLoad->getMetadata(LLVMContext::MD_tbaa); + + AAMDNodes AATags; + SomeLoad->getAAMetadata(AATags); unsigned Align = SomeLoad->getAlignment(); // Rewrite all loads of the PN to use the new PHI. @@ -1389,7 +1401,7 @@ static bool tryToMakeAllocaBePromotable(AllocaInst *AI, const DataLayout *DL) { PN->getName() + "." 
+                                     Pred->getName(), Pred->getTerminator());
       Load->setAlignment(Align);
-      if (TBAATag) Load->setMetadata(LLVMContext::MD_tbaa, TBAATag);
+      if (AATags) Load->setAAMetadata(AATags);
     }
 
     NewPN->addIncoming(Load, Pred);
@@ -1407,9 +1419,11 @@ bool SROA::performPromotion(Function &F) {
   DominatorTree *DT = nullptr;
   if (HasDomTree)
     DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  AssumptionCache &AC =
+      getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
 
   BasicBlock &BB = F.getEntryBlock(); // Get the entry node for the function
-  DIBuilder DIB(*F.getParent());
+  DIBuilder DIB(*F.getParent(), /*AllowUnresolved*/ false);
   bool Changed = false;
   SmallVector<Instruction*, 64> Insts;
   while (1) {
@@ -1425,7 +1439,7 @@ bool SROA::performPromotion(Function &F) {
     if (Allocas.empty()) break;
 
     if (HasDomTree)
-      PromoteMemToReg(Allocas, *DT);
+      PromoteMemToReg(Allocas, *DT, nullptr, &AC);
     else {
       SSAUpdater SSA;
       for (unsigned i = 0, e = Allocas.size(); i != e; ++i) {
@@ -1658,7 +1672,7 @@ void SROA::isSafePHISelectUseForScalarRepl(Instruction *I, uint64_t Offset,
                                            AllocaInfo &Info) {
   // If we've already checked this PHI, don't do it again.
   if (PHINode *PN = dyn_cast<PHINode>(I))
-    if (!Info.CheckedPHIs.insert(PN))
+    if (!Info.CheckedPHIs.insert(PN).second)
       return;
 
   for (User *U : I->users()) {
diff --git a/lib/Transforms/Scalar/Scalarizer.cpp b/lib/Transforms/Scalar/Scalarizer.cpp
index 7a73f113b1d9..6036c099be0e 100644
--- a/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/lib/Transforms/Scalar/Scalarizer.cpp
@@ -150,6 +150,16 @@ public:
   bool visitLoadInst(LoadInst &);
   bool visitStoreInst(StoreInst &);
 
+  static void registerOptions() {
+    // This is disabled by default because having separate loads and stores
+    // makes it more likely that the -combiner-alias-analysis limits will be
+    // reached.
+    OptionRegistry::registerOption<bool, Scalarizer,
+                                   &Scalarizer::ScalarizeLoadStore>(
+        "scalarize-load-store",
+        "Allow the scalarizer pass to scalarize loads and stores", false);
+  }
+
 private:
   Scatterer scatter(Instruction *, Value *);
   void gather(Instruction *, const ValueVector &);
@@ -164,19 +174,14 @@ private:
   GatherList Gathered;
   unsigned ParallelLoopAccessMDKind;
   const DataLayout *DL;
+  bool ScalarizeLoadStore;
 };
 
 char Scalarizer::ID = 0;
 } // end anonymous namespace
 
-// This is disabled by default because having separate loads and stores makes
-// it more likely that the -combiner-alias-analysis limits will be reached.
-static cl::opt<bool> ScalarizeLoadStore
-  ("scalarize-load-store", cl::Hidden, cl::init(false),
-   cl::desc("Allow the scalarizer pass to scalarize loads and store"));
-
-INITIALIZE_PASS(Scalarizer, "scalarizer", "Scalarize vector operations",
-                false, false)
+INITIALIZE_PASS_WITH_OPTIONS(Scalarizer, "scalarizer",
+                             "Scalarize vector operations", false, false)
 
 Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
                      ValueVector *cachePtr)
@@ -236,7 +241,9 @@ Value *Scatterer::operator[](unsigned I) {
 
 bool Scalarizer::doInitialization(Module &M) {
   ParallelLoopAccessMDKind =
-    M.getContext().getMDKindID("llvm.mem.parallel_loop_access");
+      M.getContext().getMDKindID("llvm.mem.parallel_loop_access");
+  ScalarizeLoadStore =
+      M.getContext().getOption<bool, Scalarizer, &Scalarizer::ScalarizeLoadStore>();
   return false;
 }
 
@@ -312,6 +319,8 @@ bool Scalarizer::canTransferMetadata(unsigned Tag) {
           || Tag == LLVMContext::MD_fpmath
           || Tag == LLVMContext::MD_tbaa_struct
           || Tag == LLVMContext::MD_invariant_load
+          || Tag == LLVMContext::MD_alias_scope
+          || Tag == LLVMContext::MD_noalias
           || Tag == ParallelLoopAccessMDKind);
 }
 
@@ -322,8 +331,10 @@ void Scalarizer::transferMetadata(Instruction *Op, const ValueVector &CV) {
   Op->getAllMetadataOtherThanDebugLoc(MDs);
   for (unsigned I = 0, E = CV.size(); I != E; ++I) {
     if (Instruction *New = dyn_cast<Instruction>(CV[I])) {
-      for (SmallVectorImpl<std::pair<unsigned, MDNode *> >::iterator
-             MI = MDs.begin(), ME = MDs.end(); MI != ME; ++MI)
+      for (SmallVectorImpl<std::pair<unsigned, MDNode *>>::iterator
+               MI = MDs.begin(),
+               ME = MDs.end();
+           MI != ME; ++MI)
         if (canTransferMetadata(MI->first))
           New->setMetadata(MI->first, MI->second);
       New->setDebugLoc(Op->getDebugLoc());
diff --git a/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
index 6557ce4575dd..6157746af48c 100644
--- a/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
+++ b/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
@@ -79,6 +79,81 @@
 //   ld.global.f32 %f3, [%rl6+128]; // much better
 //   ld.global.f32 %f4, [%rl6+132]; // much better
 //
+// Another improvement enabled by the LowerGEP flag is to lower a GEP with
+// multiple indices to either multiple GEPs with a single index or arithmetic
+// operations (depending on whether the target uses alias analysis in codegen).
+// Such a transformation can have the following benefits:
+// (1) It can always extract constants in the indices of structure types.
+// (2) After such lowering, there are more optimization opportunities such as
+//     CSE, LICM and CGP.
+//
+// E.g. The following GEPs have multiple indices:
+//  BB1:
+//    %p = getelementptr [10 x %struct]* %ptr, i64 %i, i64 %j1, i32 3
+//    load %p
+//    ...
+//  BB2:
+//    %p2 = getelementptr [10 x %struct]* %ptr, i64 %i, i64 %j1, i32 2
+//    load %p2
+//    ...
+//
+// We cannot do CSE on the common part related to index "i64 %i". Lowering
+// GEPs makes this possible.
+// If the target does not use alias analysis in codegen, this pass will
+// lower a GEP with multiple indices into arithmetic operations:
+//  BB1:
+//    %1 = ptrtoint [10 x %struct]* %ptr to i64    ; CSE opportunity
+//    %2 = mul i64 %i, length_of_10xstruct         ; CSE opportunity
+//    %3 = add i64 %1, %2                          ; CSE opportunity
+//    %4 = mul i64 %j1, length_of_struct
+//    %5 = add i64 %3, %4
+//    %6 = add i64 %3, struct_field_3              ; Constant offset
+//    %p = inttoptr i64 %6 to i32*
+//    load %p
+//    ...
+//  BB2:
+//    %7 = ptrtoint [10 x %struct]* %ptr to i64    ; CSE opportunity
+//    %8 = mul i64 %i, length_of_10xstruct         ; CSE opportunity
+//    %9 = add i64 %7, %8                          ; CSE opportunity
+//    %10 = mul i64 %j2, length_of_struct
+//    %11 = add i64 %9, %10
+//    %12 = add i64 %11, struct_field_2            ; Constant offset
+//    %p = inttoptr i64 %12 to i32*
+//    load %p2
+//    ...
+//
+// If the target uses alias analysis in codegen, this pass will lower a GEP
+// with multiple indices into multiple GEPs with a single index:
+//  BB1:
+//    %1 = bitcast [10 x %struct]* %ptr to i8*     ; CSE opportunity
+//    %2 = mul i64 %i, length_of_10xstruct         ; CSE opportunity
+//    %3 = getelementptr i8* %1, i64 %2            ; CSE opportunity
+//    %4 = mul i64 %j1, length_of_struct
+//    %5 = getelementptr i8* %3, i64 %4
+//    %6 = getelementptr i8* %5, struct_field_3    ; Constant offset
+//    %p = bitcast i8* %6 to i32*
+//    load %p
+//    ...
+//  BB2:
+//    %7 = bitcast [10 x %struct]* %ptr to i8*     ; CSE opportunity
+//    %8 = mul i64 %i, length_of_10xstruct         ; CSE opportunity
+//    %9 = getelementptr i8* %7, i64 %8            ; CSE opportunity
+//    %10 = mul i64 %j2, length_of_struct
+//    %11 = getelementptr i8* %9, i64 %10
+//    %12 = getelementptr i8* %11, struct_field_2  ; Constant offset
+//    %p2 = bitcast i8* %12 to i32*
+//    load %p2
+//    ...
+//
+// Lowering GEPs can also benefit other passes such as LICM and CGP.
+// LICM (Loop Invariant Code Motion) cannot hoist/sink a GEP with multiple
+// indices if one of the indices is variant. If we lower such a GEP into
+// invariant parts and variant parts, LICM can hoist/sink those invariant
+// parts.
+// CGP (CodeGen Prepare) tries to sink address calculations that match the
+// target's addressing modes. A GEP with multiple indices may not match and will
+// not be sunk. If we lower such a GEP into smaller parts, CGP may sink some of
+// them. So we end up with a better addressing mode.
+//
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Analysis/TargetTransformInfo.h"
@@ -92,6 +167,9 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Scalar.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/Target/TargetSubtargetInfo.h"
+#include "llvm/IR/IRBuilder.h"
 
 using namespace llvm;
 
@@ -117,18 +195,17 @@ namespace {
 /// -instcombine probably already optimized (3 * (a + 5)) to (3 * a + 15).
 class ConstantOffsetExtractor {
 public:
-  /// Extracts a constant offset from the given GEP index. It outputs the
-  /// numeric value of the extracted constant offset (0 if failed), and a
+  /// Extracts a constant offset from the given GEP index. It returns the
   /// new index representing the remainder (equal to the original index minus
-  /// the constant offset).
+  /// the constant offset), or nullptr if we cannot extract a constant offset.
   /// \p Idx    The given GEP index
-  /// \p NewIdx The new index to replace (output)
   /// \p DL     The datalayout of the module
   /// \p GEP    The given GEP
-  static int64_t Extract(Value *Idx, Value *&NewIdx, const DataLayout *DL,
-                         GetElementPtrInst *GEP);
-  /// Looks for a constant offset without extracting it. The meaning of the
-  /// arguments and the return value are the same as Extract.
+  static Value *Extract(Value *Idx, const DataLayout *DL,
+                        GetElementPtrInst *GEP);
+  /// Looks for a constant offset from the given GEP index without extracting
+  /// it. It returns the numeric value of the extracted constant offset (0 if
+  /// failed). The meaning of the arguments is the same as in Extract.
  static int64_t Find(Value *Idx, const DataLayout *DL,
                      GetElementPtrInst *GEP);
 
 private:
@@ -228,7 +305,9 @@ class ConstantOffsetExtractor {
 class SeparateConstOffsetFromGEP : public FunctionPass {
 public:
   static char ID;
-  SeparateConstOffsetFromGEP() : FunctionPass(ID) {
+  SeparateConstOffsetFromGEP(const TargetMachine *TM = nullptr,
+                             bool LowerGEP = false)
+      : FunctionPass(ID), TM(TM), LowerGEP(LowerGEP) {
     initializeSeparateConstOffsetFromGEPPass(*PassRegistry::getPassRegistry());
   }
 
@@ -251,10 +330,29 @@ class SeparateConstOffsetFromGEP : public FunctionPass {
   /// Tries to split the given GEP into a variadic base and a constant offset,
   /// and returns true if the splitting succeeds.
   bool splitGEP(GetElementPtrInst *GEP);
-  /// Finds the constant offset within each index, and accumulates them. This
-  /// function only inspects the GEP without changing it. The output
-  /// NeedsExtraction indicates whether we can extract a non-zero constant
-  /// offset from any index.
+  /// Lower a GEP with multiple indices into multiple GEPs with a single index.
+  /// Function splitGEP already split the original GEP into a variadic part and
+  /// a constant offset (i.e., AccumulativeByteOffset). This function lowers the
+  /// variadic part into a set of GEPs with a single index and applies
+  /// AccumulativeByteOffset to it.
+  /// \p Variadic                  The variadic part of the original GEP.
+  /// \p AccumulativeByteOffset    The constant offset.
+  void lowerToSingleIndexGEPs(GetElementPtrInst *Variadic,
+                              int64_t AccumulativeByteOffset);
+  /// Lower a GEP with multiple indices into ptrtoint+arithmetics+inttoptr form.
+  /// Function splitGEP already split the original GEP into a variadic part and
+  /// a constant offset (i.e., AccumulativeByteOffset). This function lowers the
+  /// variadic part into a set of arithmetic operations and applies
+  /// AccumulativeByteOffset to it.
+  /// \p Variadic                  The variadic part of the original GEP.
+  /// \p AccumulativeByteOffset    The constant offset.
+  void lowerToArithmetics(GetElementPtrInst *Variadic,
+                          int64_t AccumulativeByteOffset);
+  /// Finds the constant offset within each index and accumulates them. If
+  /// LowerGEP is true, it looks at indices of both sequential and structure
+  /// types; otherwise, it only looks at indices of sequential types. The
+  /// output NeedsExtraction indicates whether we successfully found a
+  /// non-zero constant offset.
   int64_t accumulateByteOffset(GetElementPtrInst *GEP, bool &NeedsExtraction);
   /// Canonicalize array indices to pointer-size integers. This helps to
   /// simplify the logic of splitting a GEP. For example, if a + b is a
@@ -274,6 +372,10 @@ class SeparateConstOffsetFromGEP : public FunctionPass {
   bool canonicalizeArrayIndicesToPointerSize(GetElementPtrInst *GEP);
 
   const DataLayout *DL;
+  const TargetMachine *TM;
+  /// Whether to lower a GEP with multiple indices into arithmetic operations or
+  /// multiple GEPs with a single index.
+ bool LowerGEP; }; } // anonymous namespace @@ -289,8 +391,10 @@ INITIALIZE_PASS_END( "Split GEPs to a variadic base and a constant offset for better CSE", false, false) -FunctionPass *llvm::createSeparateConstOffsetFromGEPPass() { - return new SeparateConstOffsetFromGEP(); +FunctionPass * +llvm::createSeparateConstOffsetFromGEPPass(const TargetMachine *TM, + bool LowerGEP) { + return new SeparateConstOffsetFromGEP(TM, LowerGEP); } bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended, @@ -519,8 +623,13 @@ Value *ConstantOffsetExtractor::removeConstOffset(unsigned ChainIndex) { // // Replacing the "or" with "add" is fine, because // a | (b + 5) = a + (b + 5) = (a + b) + 5 - return BinaryOperator::CreateAdd(BO->getOperand(0), BO->getOperand(1), - BO->getName(), IP); + if (OpNo == 0) { + return BinaryOperator::CreateAdd(NextInChain, TheOther, BO->getName(), + IP); + } else { + return BinaryOperator::CreateAdd(TheOther, NextInChain, BO->getName(), + IP); + } } // We can reuse BO in this case, because the new expression shares the same @@ -537,19 +646,17 @@ Value *ConstantOffsetExtractor::removeConstOffset(unsigned ChainIndex) { return BO; } -int64_t ConstantOffsetExtractor::Extract(Value *Idx, Value *&NewIdx, - const DataLayout *DL, - GetElementPtrInst *GEP) { +Value *ConstantOffsetExtractor::Extract(Value *Idx, const DataLayout *DL, + GetElementPtrInst *GEP) { ConstantOffsetExtractor Extractor(DL, GEP); // Find a non-zero constant offset first. APInt ConstantOffset = Extractor.find(Idx, /* SignExtended */ false, /* ZeroExtended */ false, GEP->isInBounds()); - if (ConstantOffset != 0) { - // Separates the constant offset from the GEP index. - NewIdx = Extractor.rebuildWithoutConstOffset(); - } - return ConstantOffset.getSExtValue(); + if (ConstantOffset == 0) + return nullptr; + // Separates the constant offset from the GEP index. + return Extractor.rebuildWithoutConstOffset(); } int64_t ConstantOffsetExtractor::Find(Value *Idx, const DataLayout *DL, @@ -615,11 +722,116 @@ SeparateConstOffsetFromGEP::accumulateByteOffset(GetElementPtrInst *GEP, AccumulativeByteOffset += ConstantOffset * DL->getTypeAllocSize(GTI.getIndexedType()); } + } else if (LowerGEP) { + StructType *StTy = cast<StructType>(*GTI); + uint64_t Field = cast<ConstantInt>(GEP->getOperand(I))->getZExtValue(); + // Skip field 0 as the offset is always 0. + if (Field != 0) { + NeedsExtraction = true; + AccumulativeByteOffset += + DL->getStructLayout(StTy)->getElementOffset(Field); + } } } return AccumulativeByteOffset; } +void SeparateConstOffsetFromGEP::lowerToSingleIndexGEPs( + GetElementPtrInst *Variadic, int64_t AccumulativeByteOffset) { + IRBuilder<> Builder(Variadic); + Type *IntPtrTy = DL->getIntPtrType(Variadic->getType()); + + Type *I8PtrTy = + Builder.getInt8PtrTy(Variadic->getType()->getPointerAddressSpace()); + Value *ResultPtr = Variadic->getOperand(0); + if (ResultPtr->getType() != I8PtrTy) + ResultPtr = Builder.CreateBitCast(ResultPtr, I8PtrTy); + + gep_type_iterator GTI = gep_type_begin(*Variadic); + // Create an ugly GEP for each sequential index. We don't create GEPs for + // structure indices, as they are accumulated in the constant offset index. + for (unsigned I = 1, E = Variadic->getNumOperands(); I != E; ++I, ++GTI) { + if (isa<SequentialType>(*GTI)) { + Value *Idx = Variadic->getOperand(I); + // Skip zero indices. 
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(Idx))
+        if (CI->isZero())
+          continue;
+
+      APInt ElementSize = APInt(IntPtrTy->getIntegerBitWidth(),
+                                DL->getTypeAllocSize(GTI.getIndexedType()));
+      // Scale the index by element size.
+      if (ElementSize != 1) {
+        if (ElementSize.isPowerOf2()) {
+          Idx = Builder.CreateShl(
+              Idx, ConstantInt::get(IntPtrTy, ElementSize.logBase2()));
+        } else {
+          Idx = Builder.CreateMul(Idx, ConstantInt::get(IntPtrTy, ElementSize));
+        }
+      }
+      // Create an ugly GEP with a single index for each index.
+      ResultPtr = Builder.CreateGEP(ResultPtr, Idx, "uglygep");
+    }
+  }
+
+  // Create a GEP with the constant offset index.
+  if (AccumulativeByteOffset != 0) {
+    Value *Offset = ConstantInt::get(IntPtrTy, AccumulativeByteOffset);
+    ResultPtr = Builder.CreateGEP(ResultPtr, Offset, "uglygep");
+  }
+  if (ResultPtr->getType() != Variadic->getType())
+    ResultPtr = Builder.CreateBitCast(ResultPtr, Variadic->getType());
+
+  Variadic->replaceAllUsesWith(ResultPtr);
+  Variadic->eraseFromParent();
+}
+
+void
+SeparateConstOffsetFromGEP::lowerToArithmetics(GetElementPtrInst *Variadic,
+                                               int64_t AccumulativeByteOffset) {
+  IRBuilder<> Builder(Variadic);
+  Type *IntPtrTy = DL->getIntPtrType(Variadic->getType());
+
+  Value *ResultPtr = Builder.CreatePtrToInt(Variadic->getOperand(0), IntPtrTy);
+  gep_type_iterator GTI = gep_type_begin(*Variadic);
+  // Create ADD/SHL/MUL arithmetic operations for each sequential index. We
+  // don't create arithmetic operations for structure indices, as they are
+  // accumulated in the constant offset index.
+  for (unsigned I = 1, E = Variadic->getNumOperands(); I != E; ++I, ++GTI) {
+    if (isa<SequentialType>(*GTI)) {
+      Value *Idx = Variadic->getOperand(I);
+      // Skip zero indices.
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(Idx))
+        if (CI->isZero())
+          continue;
+
+      APInt ElementSize = APInt(IntPtrTy->getIntegerBitWidth(),
+                                DL->getTypeAllocSize(GTI.getIndexedType()));
+      // Scale the index by element size.
+      if (ElementSize != 1) {
+        if (ElementSize.isPowerOf2()) {
+          Idx = Builder.CreateShl(
+              Idx, ConstantInt::get(IntPtrTy, ElementSize.logBase2()));
+        } else {
+          Idx = Builder.CreateMul(Idx, ConstantInt::get(IntPtrTy, ElementSize));
+        }
+      }
+      // Create an ADD for each index.
+      ResultPtr = Builder.CreateAdd(ResultPtr, Idx);
+    }
+  }
+
+  // Create an ADD for the constant offset index.
+  if (AccumulativeByteOffset != 0) {
+    ResultPtr = Builder.CreateAdd(
+        ResultPtr, ConstantInt::get(IntPtrTy, AccumulativeByteOffset));
+  }
+
+  ResultPtr = Builder.CreateIntToPtr(ResultPtr, Variadic->getType());
+  Variadic->replaceAllUsesWith(ResultPtr);
+  Variadic->eraseFromParent();
+}
+
 bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) {
   // Skip vector GEPs.
   if (GEP->getType()->isVectorTy())
@@ -637,32 +849,42 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) {
   if (!NeedsExtraction)
     return Changed;
 
-  // Before really splitting the GEP, check whether the backend supports the
-  // addressing mode we are about to produce. If no, this splitting probably
-  // won't be beneficial.
-  TargetTransformInfo &TTI = getAnalysis<TargetTransformInfo>();
-  if (!TTI.isLegalAddressingMode(GEP->getType()->getElementType(),
-                                 /*BaseGV=*/nullptr, AccumulativeByteOffset,
-                                 /*HasBaseReg=*/true, /*Scale=*/0)) {
-    return Changed;
+  // If LowerGEP is disabled, before really splitting the GEP, check whether
+  // the backend supports the addressing mode we are about to produce. If not,
+  // this splitting probably won't be beneficial.
+  // If LowerGEP is enabled, even if the extracted constant offset cannot match
+  // the addressing mode, we can still do optimizations on the other lowered
+  // parts of the variable indices. Therefore, we don't check for addressing
+  // modes in that case.
+  if (!LowerGEP) {
+    TargetTransformInfo &TTI = getAnalysis<TargetTransformInfo>();
+    if (!TTI.isLegalAddressingMode(GEP->getType()->getElementType(),
+                                   /*BaseGV=*/nullptr, AccumulativeByteOffset,
+                                   /*HasBaseReg=*/true, /*Scale=*/0)) {
+      return Changed;
+    }
   }
 
-  // Remove the constant offset in each GEP index. The resultant GEP computes
-  // the variadic base.
+  // Remove the constant offset in each sequential index. The resultant GEP
+  // computes the variadic base.
+  // Notice that we don't remove struct field indices here. If LowerGEP is
+  // disabled, a structure index is not accumulated and we still use the old
+  // one. If LowerGEP is enabled, a structure index is accumulated in the
+  // constant offset. lowerToSingleIndexGEPs or lowerToArithmetics will later
+  // handle the constant offset and won't need a new structure index.
   gep_type_iterator GTI = gep_type_begin(*GEP);
   for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) {
     if (isa<SequentialType>(*GTI)) {
-      Value *NewIdx = nullptr;
-      // Tries to extract a constant offset from this GEP index.
-      int64_t ConstantOffset =
-          ConstantOffsetExtractor::Extract(GEP->getOperand(I), NewIdx, DL, GEP);
-      if (ConstantOffset != 0) {
-        assert(NewIdx != nullptr &&
-               "ConstantOffset != 0 implies NewIdx is set");
+      // Splits this GEP index into a variadic part and a constant offset, and
+      // uses the variadic part as the new index.
+      Value *NewIdx =
+          ConstantOffsetExtractor::Extract(GEP->getOperand(I), DL, GEP);
+      if (NewIdx != nullptr) {
         GEP->setOperand(I, NewIdx);
       }
     }
   }
+
   // Clear the inbounds attribute because the new index may be off-bound.
   // e.g.,
   //
@@ -684,6 +906,21 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) {
   // possible. GEPs with inbounds are more friendly to alias analysis.
   GEP->setIsInBounds(false);
 
+  // Lowers a GEP to either GEPs with a single index or arithmetic operations.
+  if (LowerGEP) {
+    // As currently BasicAA does not analyze ptrtoint/inttoptr, do not lower to
+    // arithmetic operations if the target uses alias analysis in codegen.
+    if (TM && TM->getSubtarget<TargetSubtargetInfo>().useAA())
+      lowerToSingleIndexGEPs(GEP, AccumulativeByteOffset);
+    else
+      lowerToArithmetics(GEP, AccumulativeByteOffset);
+    return true;
+  }
+
+  // No need to create another GEP if the accumulative byte offset is 0.
+  if (AccumulativeByteOffset == 0)
+    return true;
+
   // Offsets the base with the accumulative byte offset.
   //
   //   %gep                        ; the base
@@ -715,16 +952,16 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) {
 
   Instruction *NewGEP = GEP->clone();
   NewGEP->insertBefore(GEP);
 
-  uint64_t ElementTypeSizeOfGEP =
-      DL->getTypeAllocSize(GEP->getType()->getElementType());
+  // Per ANSI C standard, signed / unsigned = unsigned and signed % unsigned =
+  // unsigned. Therefore, we cast ElementTypeSizeOfGEP to signed because it is
+  // used with the signed AccumulativeByteOffset below.
+  int64_t ElementTypeSizeOfGEP = static_cast<int64_t>(
+      DL->getTypeAllocSize(GEP->getType()->getElementType()));
   Type *IntPtrTy = DL->getIntPtrType(GEP->getType());
   if (AccumulativeByteOffset % ElementTypeSizeOfGEP == 0) {
     // Very likely. As long as %gep is natually aligned, the byte offset we
     // extracted should be a multiple of sizeof(*%gep).
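A worked instance of the divisibility check here, with made-up numbers:

    // Illustration only, not pass code: for an i32 element type,
    // ElementTypeSizeOfGEP is 4.
    //   AccumulativeByteOffset = 132  ->  132 % 4 == 0, so the new GEP can
    //   use the typed index 132 / 4 = 33.
    //   AccumulativeByteOffset = 130  ->  130 % 4 == 2, so the pass must
    //   fall back to an i8 "uglygep" with a raw byte offset of 130.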
- // Per ANSI C standard, signed / unsigned = unsigned. Therefore, we - // cast ElementTypeSizeOfGEP to signed. - int64_t Index = - AccumulativeByteOffset / static_cast<int64_t>(ElementTypeSizeOfGEP); + int64_t Index = AccumulativeByteOffset / ElementTypeSizeOfGEP; NewGEP = GetElementPtrInst::Create( NewGEP, ConstantInt::get(IntPtrTy, Index, true), GEP->getName(), GEP); } else { diff --git a/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/lib/Transforms/Scalar/SimplifyCFGPass.cpp index 5d5606ba47b0..2e317f9d0999 100644 --- a/lib/Transforms/Scalar/SimplifyCFGPass.cpp +++ b/lib/Transforms/Scalar/SimplifyCFGPass.cpp @@ -25,6 +25,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/CFG.h" @@ -34,22 +35,30 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; #define DEBUG_TYPE "simplifycfg" +static cl::opt<unsigned> +UserBonusInstThreshold("bonus-inst-threshold", cl::Hidden, cl::init(1), + cl::desc("Control the number of bonus instructions (default = 1)")); + STATISTIC(NumSimpl, "Number of blocks simplified"); namespace { struct CFGSimplifyPass : public FunctionPass { static char ID; // Pass identification, replacement for typeid - CFGSimplifyPass() : FunctionPass(ID) { + unsigned BonusInstThreshold; + CFGSimplifyPass(int T = -1) : FunctionPass(ID) { + BonusInstThreshold = (T == -1) ? UserBonusInstThreshold : unsigned(T); initializeCFGSimplifyPassPass(*PassRegistry::getPassRegistry()); } bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetTransformInfo>(); } }; @@ -59,12 +68,13 @@ char CFGSimplifyPass::ID = 0; INITIALIZE_PASS_BEGIN(CFGSimplifyPass, "simplifycfg", "Simplify the CFG", false, false) INITIALIZE_AG_DEPENDENCY(TargetTransformInfo) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_END(CFGSimplifyPass, "simplifycfg", "Simplify the CFG", false, false) // Public interface to the CFGSimplification pass -FunctionPass *llvm::createCFGSimplificationPass() { - return new CFGSimplifyPass(); +FunctionPass *llvm::createCFGSimplificationPass(int Threshold) { + return new CFGSimplifyPass(Threshold); } /// mergeEmptyReturnBlocks - If we have more than one empty (other than phi @@ -146,7 +156,8 @@ static bool mergeEmptyReturnBlocks(Function &F) { /// iterativelySimplifyCFG - Call SimplifyCFG on all the blocks in the function, /// iterating until no more changes are made. static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI, - const DataLayout *DL) { + const DataLayout *DL, AssumptionCache *AC, + unsigned BonusInstThreshold) { bool Changed = false; bool LocalChange = true; while (LocalChange) { @@ -155,7 +166,7 @@ static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI, // Loop over all of the basic blocks and remove them if they are unneeded... 
// for (Function::iterator BBIt = F.begin(); BBIt != F.end(); ) { - if (SimplifyCFG(BBIt++, TTI, DL)) { + if (SimplifyCFG(BBIt++, TTI, BonusInstThreshold, DL, AC)) { LocalChange = true; ++NumSimpl; } @@ -172,12 +183,14 @@ bool CFGSimplifyPass::runOnFunction(Function &F) { if (skipOptnoneFunction(F)) return false; + AssumptionCache *AC = + &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); const TargetTransformInfo &TTI = getAnalysis<TargetTransformInfo>(); DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); const DataLayout *DL = DLP ? &DLP->getDataLayout() : nullptr; bool EverChanged = removeUnreachableBlocks(F); EverChanged |= mergeEmptyReturnBlocks(F); - EverChanged |= iterativelySimplifyCFG(F, TTI, DL); + EverChanged |= iterativelySimplifyCFG(F, TTI, DL, AC, BonusInstThreshold); // If neither pass changed anything, we're done. if (!EverChanged) return false; @@ -191,7 +204,7 @@ bool CFGSimplifyPass::runOnFunction(Function &F) { return true; do { - EverChanged = iterativelySimplifyCFG(F, TTI, DL); + EverChanged = iterativelySimplifyCFG(F, TTI, DL, AC, BonusInstThreshold); EverChanged |= removeUnreachableBlocks(F); } while (EverChanged); diff --git a/lib/Transforms/Scalar/Sink.cpp b/lib/Transforms/Scalar/Sink.cpp index 7348c45c5d37..903b675fdd56 100644 --- a/lib/Transforms/Scalar/Sink.cpp +++ b/lib/Transforms/Scalar/Sink.cpp @@ -56,7 +56,7 @@ namespace { } private: bool ProcessBlock(BasicBlock &BB); - bool SinkInstruction(Instruction *I, SmallPtrSet<Instruction *, 8> &Stores); + bool SinkInstruction(Instruction *I, SmallPtrSetImpl<Instruction*> &Stores); bool AllUsesDominatedByBlock(Instruction *Inst, BasicBlock *BB) const; bool IsAcceptableTarget(Instruction *Inst, BasicBlock *SuccToSinkTo) const; }; @@ -157,7 +157,7 @@ bool Sinking::ProcessBlock(BasicBlock &BB) { } static bool isSafeToMove(Instruction *Inst, AliasAnalysis *AA, - SmallPtrSet<Instruction *, 8> &Stores) { + SmallPtrSetImpl<Instruction *> &Stores) { if (Inst->mayWriteToMemory()) { Stores.insert(Inst); @@ -166,9 +166,8 @@ static bool isSafeToMove(Instruction *Inst, AliasAnalysis *AA, if (LoadInst *L = dyn_cast<LoadInst>(Inst)) { AliasAnalysis::Location Loc = AA->getLocation(L); - for (SmallPtrSet<Instruction *, 8>::iterator I = Stores.begin(), - E = Stores.end(); I != E; ++I) - if (AA->getModRefInfo(*I, Loc) & AliasAnalysis::Mod) + for (Instruction *S : Stores) + if (AA->getModRefInfo(S, Loc) & AliasAnalysis::Mod) return false; } @@ -220,7 +219,7 @@ bool Sinking::IsAcceptableTarget(Instruction *Inst, /// SinkInstruction - Determine whether it is safe to sink the specified machine /// instruction out of its current block into a successor. bool Sinking::SinkInstruction(Instruction *Inst, - SmallPtrSet<Instruction *, 8> &Stores) { + SmallPtrSetImpl<Instruction *> &Stores) { // Don't sink static alloca instructions. CodeGen assumes allocas outside the // entry block are dynamically sized stack objects. 
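The bonus-instruction budget added to SimplifyCFG above is reachable both programmatically and from the command line; a short sketch (the -1 sentinel and the int parameter are the interfaces introduced in this patch):

    #include "llvm/Pass.h"
    #include "llvm/Transforms/Scalar.h"

    // Request a bonus-instruction budget of 2; passing -1 (the default)
    // defers to the -bonus-inst-threshold command-line option instead.
    static llvm::FunctionPass *makeSimplifyCFGWithBudget() {
      return llvm::createCFGSimplificationPass(/*Threshold=*/2);
    }

From the command line, the equivalent would be opt -simplifycfg -bonus-inst-threshold=2.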
diff --git a/lib/Transforms/Scalar/StructurizeCFG.cpp b/lib/Transforms/Scalar/StructurizeCFG.cpp index b9673ed655e0..7fe87f9319b6 100644 --- a/lib/Transforms/Scalar/StructurizeCFG.cpp +++ b/lib/Transforms/Scalar/StructurizeCFG.cpp @@ -10,6 +10,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SCCIterator.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/RegionInfo.h" #include "llvm/Analysis/RegionIterator.h" #include "llvm/Analysis/RegionPass.h" @@ -166,6 +167,7 @@ class StructurizeCFG : public RegionPass { Region *ParentRegion; DominatorTree *DT; + LoopInfo *LI; RNVector Order; BBSet Visited; @@ -247,6 +249,7 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequiredID(LowerSwitchID); AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<LoopInfo>(); AU.addPreserved<DominatorTreeWrapperPass>(); RegionPass::getAnalysisUsage(AU); } @@ -301,8 +304,9 @@ void StructurizeCFG::analyzeLoops(RegionNode *N) { for (unsigned i = 0, e = Term->getNumSuccessors(); i != e; ++i) { BasicBlock *Succ = Term->getSuccessor(i); - if (Visited.count(Succ)) + if (Visited.count(Succ) && LI->isLoopHeader(Succ)) { Loops[Succ] = BB; + } } } } @@ -862,6 +866,7 @@ bool StructurizeCFG::runOnRegion(Region *R, RGPassManager &RGM) { ParentRegion = R; DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + LI = &getAnalysis<LoopInfo>(); orderNodes(); collectInfos(); diff --git a/lib/Transforms/Scalar/TailRecursionElimination.cpp b/lib/Transforms/Scalar/TailRecursionElimination.cpp index b7580255150c..f3c3e3054b60 100644 --- a/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -63,6 +63,7 @@ #include "llvm/IR/CFG.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Function.h" @@ -86,6 +87,7 @@ STATISTIC(NumAccumAdded, "Number of accumulators introduced"); namespace { struct TailCallElim : public FunctionPass { const TargetTransformInfo *TTI; + const DataLayout *DL; static char ID; // Pass identification, replacement for typeid TailCallElim() : FunctionPass(ID) { @@ -157,6 +159,8 @@ bool TailCallElim::runOnFunction(Function &F) { if (skipOptnoneFunction(F)) return false; + DL = F.getParent()->getDataLayout(); + bool AllCallsAreTailCalls = false; bool Modified = markTails(F, AllCallsAreTailCalls); if (AllCallsAreTailCalls) @@ -175,7 +179,7 @@ struct AllocaDerivedValueTracker { auto AddUsesToWorklist = [&](Value *V) { for (auto &U : V->uses()) { - if (!Visited.insert(&U)) + if (!Visited.insert(&U).second) continue; Worklist.push_back(&U); } @@ -400,18 +404,28 @@ bool TailCallElim::runTRE(Function &F) { // alloca' is changed from being a static alloca to being a dynamic alloca. // Until this is resolved, disable this transformation if that would ever // happen. This bug is PR962. + SmallVector<BasicBlock*, 8> BBToErase; for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator())) { bool Change = ProcessReturningBlock(Ret, OldEntry, TailCallsAreMarkedTail, ArgumentPHIs, !CanTRETailMarkedCall); - if (!Change && BB->getFirstNonPHIOrDbg() == Ret) + if (!Change && BB->getFirstNonPHIOrDbg() == Ret) { Change = FoldReturnAndProcessPred(BB, Ret, OldEntry, TailCallsAreMarkedTail, ArgumentPHIs, !CanTRETailMarkedCall); + // FoldReturnAndProcessPred may have emptied some BBs.
Remember to + // erase them. + if (Change && BB->empty()) + BBToErase.push_back(BB); + + } MadeChange |= Change; } } + for (auto BB : BBToErase) + BB->eraseFromParent(); + // If we eliminated any tail recursions, it's possible that we inserted some // silly PHI nodes which just merge an initial value (the incoming operand) // with themselves. Check to see if we did and clean up our mess if so. This @@ -450,7 +464,7 @@ bool TailCallElim::CanMoveAboveCall(Instruction *I, CallInst *CI) { // being loaded from. if (CI->mayWriteToMemory() || !isSafeToLoadUnconditionally(L->getPointerOperand(), L, - L->getAlignment())) + L->getAlignment(), DL)) return false; } } @@ -819,8 +833,20 @@ bool TailCallElim::FoldReturnAndProcessPred(BasicBlock *BB, if (CallInst *CI = FindTRECandidate(BI, CannotTailCallElimCallsMarkedTail)){ DEBUG(dbgs() << "FOLDING: " << *BB << "INTO UNCOND BRANCH PRED: " << *Pred); - EliminateRecursiveTailCall(CI, FoldReturnIntoUncondBranch(Ret, BB, Pred), - OldEntry, TailCallsAreMarkedTail, ArgumentPHIs, + ReturnInst *RI = FoldReturnIntoUncondBranch(Ret, BB, Pred); + + // Cleanup: if all predecessors of BB have been eliminated by + // FoldReturnIntoUncondBranch, we would like to delete it, but we + // cannot just nuke it as it is being used as an iterator by our caller. + // Just empty it, and the caller will erase it when it is safe to do so. + // It is important to empty it, because the ret instruction in there is + // still using a value which EliminateRecursiveTailCall will attempt + // to remove. + if (!BB->hasAddressTaken() && pred_begin(BB) == pred_end(BB)) + BB->getInstList().clear(); + + EliminateRecursiveTailCall(CI, RI, OldEntry, TailCallsAreMarkedTail, + ArgumentPHIs, CannotTailCallElimCallsMarkedTail); ++NumRetDuped; Change = true; diff --git a/lib/Transforms/Utils/AddDiscriminators.cpp b/lib/Transforms/Utils/AddDiscriminators.cpp index 196ac79aaf29..820544bcebf0 100644 --- a/lib/Transforms/Utils/AddDiscriminators.cpp +++ b/lib/Transforms/Utils/AddDiscriminators.cpp @@ -167,7 +167,7 @@ bool AddDiscriminators::runOnFunction(Function &F) { bool Changed = false; Module *M = F.getParent(); LLVMContext &Ctx = M->getContext(); - DIBuilder Builder(*M); + DIBuilder Builder(*M, /*AllowUnresolved*/ false); // Traverse all the blocks looking for instructions in different // blocks that are at the same file:line location. @@ -193,13 +193,11 @@ bool AddDiscriminators::runOnFunction(Function &F) { // Create a new lexical scope and compute a new discriminator // number for it.
StringRef Filename = FirstDIL.getFilename(); - unsigned LineNumber = FirstDIL.getLineNumber(); - unsigned ColumnNumber = FirstDIL.getColumnNumber(); DIScope Scope = FirstDIL.getScope(); DIFile File = Builder.createFile(Filename, Scope.getDirectory()); unsigned Discriminator = FirstDIL.computeNewDiscriminator(Ctx); - DILexicalBlock NewScope = Builder.createLexicalBlock( - Scope, File, LineNumber, ColumnNumber, Discriminator); + DILexicalBlockFile NewScope = + Builder.createLexicalBlockFile(Scope, File, Discriminator); DILocation NewDIL = FirstDIL.copyWithNewScope(Ctx, NewScope); DebugLoc newDebugLoc = DebugLoc::getFromDILocation(NewDIL); diff --git a/lib/Transforms/Utils/BasicBlockUtils.cpp b/lib/Transforms/Utils/BasicBlockUtils.cpp index 602e8ba55107..983f025a1a3a 100644 --- a/lib/Transforms/Utils/BasicBlockUtils.cpp +++ b/lib/Transforms/Utils/BasicBlockUtils.cpp @@ -265,6 +265,18 @@ BasicBlock *llvm::SplitEdge(BasicBlock *BB, BasicBlock *Succ, Pass *P) { return SplitBlock(BB, BB->getTerminator(), P); } +unsigned llvm::SplitAllCriticalEdges(Function &F, Pass *P) { + unsigned NumBroken = 0; + for (Function::iterator I = F.begin(), E = F.end(); I != E; ++I) { + TerminatorInst *TI = I->getTerminator(); + if (TI->getNumSuccessors() > 1 && !isa<IndirectBrInst>(TI)) + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) + if (SplitCriticalEdge(TI, i, P)) + ++NumBroken; + } + return NumBroken; +} + /// SplitBlock - Split the specified block at the specified instruction - everything before SplitPt stays in Old and everything starting with SplitPt moves to a new block. The two blocks are joined by an unconditional branch and diff --git a/lib/Transforms/Utils/BreakCriticalEdges.cpp b/lib/Transforms/Utils/BreakCriticalEdges.cpp index 80bd51637514..eda22cfc1bab 100644 --- a/lib/Transforms/Utils/BreakCriticalEdges.cpp +++ b/lib/Transforms/Utils/BreakCriticalEdges.cpp @@ -40,7 +40,11 @@ namespace { initializeBreakCriticalEdgesPass(*PassRegistry::getPassRegistry()); } - bool runOnFunction(Function &F) override; + bool runOnFunction(Function &F) override { + unsigned N = SplitAllCriticalEdges(F, this); + NumBroken += N; + return N > 0; + } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addPreserved<DominatorTreeWrapperPass>(); @@ -62,24 +66,6 @@ FunctionPass *llvm::createBreakCriticalEdgesPass() { return new BreakCriticalEdges(); } -// runOnFunction - Loop over all of the edges in the CFG, breaking critical -// edges as they are found.
-// -bool BreakCriticalEdges::runOnFunction(Function &F) { - bool Changed = false; - for (Function::iterator I = F.begin(), E = F.end(); I != E; ++I) { - TerminatorInst *TI = I->getTerminator(); - if (TI->getNumSuccessors() > 1 && !isa<IndirectBrInst>(TI)) - for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) - if (SplitCriticalEdge(TI, i, this)) { - ++NumBroken; - Changed = true; - } - } - - return Changed; -} - //===----------------------------------------------------------------------===// // Implementation of the external critical edge manipulation functions //===----------------------------------------------------------------------===// diff --git a/lib/Transforms/Utils/BuildLibCalls.cpp b/lib/Transforms/Utils/BuildLibCalls.cpp index be00b6956199..322485d9e32a 100644 --- a/lib/Transforms/Utils/BuildLibCalls.cpp +++ b/lib/Transforms/Utils/BuildLibCalls.cpp @@ -42,8 +42,7 @@ Value *llvm::EmitStrLen(Value *Ptr, IRBuilder<> &B, const DataLayout *TD, AttributeSet AS[2]; AS[0] = AttributeSet::get(M->getContext(), 1, Attribute::NoCapture); Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, Attribute::NoUnwind }; - AS[1] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - ArrayRef<Attribute::AttrKind>(AVs, 2)); + AS[1] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, AVs); LLVMContext &Context = B.GetInsertBlock()->getContext(); Constant *StrLen = M->getOrInsertFunction("strlen", @@ -51,7 +50,7 @@ Value *llvm::EmitStrLen(Value *Ptr, IRBuilder<> &B, const DataLayout *TD, AS), TD->getIntPtrType(Context), B.getInt8PtrTy(), - NULL); + nullptr); CallInst *CI = B.CreateCall(StrLen, CastToCStr(Ptr, B), "strlen"); if (const Function *F = dyn_cast<Function>(StrLen->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -71,8 +70,7 @@ Value *llvm::EmitStrNLen(Value *Ptr, Value *MaxLen, IRBuilder<> &B, AttributeSet AS[2]; AS[0] = AttributeSet::get(M->getContext(), 1, Attribute::NoCapture); Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, Attribute::NoUnwind }; - AS[1] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - ArrayRef<Attribute::AttrKind>(AVs, 2)); + AS[1] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, AVs); LLVMContext &Context = B.GetInsertBlock()->getContext(); Constant *StrNLen = M->getOrInsertFunction("strnlen", @@ -81,7 +79,7 @@ Value *llvm::EmitStrNLen(Value *Ptr, Value *MaxLen, IRBuilder<> &B, TD->getIntPtrType(Context), B.getInt8PtrTy(), TD->getIntPtrType(Context), - NULL); + nullptr); CallInst *CI = B.CreateCall2(StrNLen, CastToCStr(Ptr, B), MaxLen, "strnlen"); if (const Function *F = dyn_cast<Function>(StrNLen->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -100,15 +98,14 @@ Value *llvm::EmitStrChr(Value *Ptr, char C, IRBuilder<> &B, Module *M = B.GetInsertBlock()->getParent()->getParent(); Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, Attribute::NoUnwind }; AttributeSet AS = - AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - ArrayRef<Attribute::AttrKind>(AVs, 2)); + AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, AVs); Type *I8Ptr = B.getInt8PtrTy(); Type *I32Ty = B.getInt32Ty(); Constant *StrChr = M->getOrInsertFunction("strchr", AttributeSet::get(M->getContext(), AS), - I8Ptr, I8Ptr, I32Ty, NULL); + I8Ptr, I8Ptr, I32Ty, nullptr); CallInst *CI = B.CreateCall2(StrChr, CastToCStr(Ptr, B), ConstantInt::get(I32Ty, C), "strchr"); if (const Function *F = dyn_cast<Function>(StrChr->stripPointerCasts())) @@ -128,8 +125,7 @@ Value 
*llvm::EmitStrNCmp(Value *Ptr1, Value *Ptr2, Value *Len, AS[0] = AttributeSet::get(M->getContext(), 1, Attribute::NoCapture); AS[1] = AttributeSet::get(M->getContext(), 2, Attribute::NoCapture); Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, Attribute::NoUnwind }; - AS[2] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - ArrayRef<Attribute::AttrKind>(AVs, 2)); + AS[2] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, AVs); LLVMContext &Context = B.GetInsertBlock()->getContext(); Value *StrNCmp = M->getOrInsertFunction("strncmp", @@ -138,7 +134,7 @@ Value *llvm::EmitStrNCmp(Value *Ptr1, Value *Ptr2, Value *Len, B.getInt32Ty(), B.getInt8PtrTy(), B.getInt8PtrTy(), - TD->getIntPtrType(Context), NULL); + TD->getIntPtrType(Context), nullptr); CallInst *CI = B.CreateCall3(StrNCmp, CastToCStr(Ptr1, B), CastToCStr(Ptr2, B), Len, "strncmp"); @@ -164,7 +160,7 @@ Value *llvm::EmitStrCpy(Value *Dst, Value *Src, IRBuilder<> &B, Type *I8Ptr = B.getInt8PtrTy(); Value *StrCpy = M->getOrInsertFunction(Name, AttributeSet::get(M->getContext(), AS), - I8Ptr, I8Ptr, I8Ptr, NULL); + I8Ptr, I8Ptr, I8Ptr, nullptr); CallInst *CI = B.CreateCall2(StrCpy, CastToCStr(Dst, B), CastToCStr(Src, B), Name); if (const Function *F = dyn_cast<Function>(StrCpy->stripPointerCasts())) @@ -190,7 +186,7 @@ Value *llvm::EmitStrNCpy(Value *Dst, Value *Src, Value *Len, AttributeSet::get(M->getContext(), AS), I8Ptr, I8Ptr, I8Ptr, - Len->getType(), NULL); + Len->getType(), nullptr); CallInst *CI = B.CreateCall3(StrNCpy, CastToCStr(Dst, B), CastToCStr(Src, B), Len, "strncpy"); if (const Function *F = dyn_cast<Function>(StrNCpy->stripPointerCasts())) @@ -218,7 +214,7 @@ Value *llvm::EmitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize, B.getInt8PtrTy(), B.getInt8PtrTy(), TD->getIntPtrType(Context), - TD->getIntPtrType(Context), NULL); + TD->getIntPtrType(Context), nullptr); Dst = CastToCStr(Dst, B); Src = CastToCStr(Src, B); CallInst *CI = B.CreateCall4(MemCpy, Dst, Src, Len, ObjSize); @@ -238,8 +234,7 @@ Value *llvm::EmitMemChr(Value *Ptr, Value *Val, Module *M = B.GetInsertBlock()->getParent()->getParent(); AttributeSet AS; Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, Attribute::NoUnwind }; - AS = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - ArrayRef<Attribute::AttrKind>(AVs, 2)); + AS = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, AVs); LLVMContext &Context = B.GetInsertBlock()->getContext(); Value *MemChr = M->getOrInsertFunction("memchr", AttributeSet::get(M->getContext(), AS), @@ -247,7 +242,7 @@ Value *llvm::EmitMemChr(Value *Ptr, Value *Val, B.getInt8PtrTy(), B.getInt32Ty(), TD->getIntPtrType(Context), - NULL); + nullptr); CallInst *CI = B.CreateCall3(MemChr, CastToCStr(Ptr, B), Val, Len, "memchr"); if (const Function *F = dyn_cast<Function>(MemChr->stripPointerCasts())) @@ -268,8 +263,7 @@ Value *llvm::EmitMemCmp(Value *Ptr1, Value *Ptr2, AS[0] = AttributeSet::get(M->getContext(), 1, Attribute::NoCapture); AS[1] = AttributeSet::get(M->getContext(), 2, Attribute::NoCapture); Attribute::AttrKind AVs[2] = { Attribute::ReadOnly, Attribute::NoUnwind }; - AS[2] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, - ArrayRef<Attribute::AttrKind>(AVs, 2)); + AS[2] = AttributeSet::get(M->getContext(), AttributeSet::FunctionIndex, AVs); LLVMContext &Context = B.GetInsertBlock()->getContext(); Value *MemCmp = M->getOrInsertFunction("memcmp", @@ -277,7 +271,7 @@ Value *llvm::EmitMemCmp(Value *Ptr1, Value *Ptr2, B.getInt32Ty(), 
B.getInt8PtrTy(), B.getInt8PtrTy(), - TD->getIntPtrType(Context), NULL); + TD->getIntPtrType(Context), nullptr); CallInst *CI = B.CreateCall3(MemCmp, CastToCStr(Ptr1, B), CastToCStr(Ptr2, B), Len, "memcmp"); @@ -313,7 +307,7 @@ Value *llvm::EmitUnaryFloatFnCall(Value *Op, StringRef Name, IRBuilder<> &B, Module *M = B.GetInsertBlock()->getParent()->getParent(); Value *Callee = M->getOrInsertFunction(Name, Op->getType(), - Op->getType(), NULL); + Op->getType(), nullptr); CallInst *CI = B.CreateCall(Callee, Op, Name); CI->setAttributes(Attrs); if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) @@ -334,7 +328,7 @@ Value *llvm::EmitBinaryFloatFnCall(Value *Op1, Value *Op2, StringRef Name, Module *M = B.GetInsertBlock()->getParent()->getParent(); Value *Callee = M->getOrInsertFunction(Name, Op1->getType(), - Op1->getType(), Op2->getType(), NULL); + Op1->getType(), Op2->getType(), nullptr); CallInst *CI = B.CreateCall2(Callee, Op1, Op2, Name); CI->setAttributes(Attrs); if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) @@ -352,7 +346,7 @@ Value *llvm::EmitPutChar(Value *Char, IRBuilder<> &B, const DataLayout *TD, Module *M = B.GetInsertBlock()->getParent()->getParent(); Value *PutChar = M->getOrInsertFunction("putchar", B.getInt32Ty(), - B.getInt32Ty(), NULL); + B.getInt32Ty(), nullptr); CallInst *CI = B.CreateCall(PutChar, B.CreateIntCast(Char, B.getInt32Ty(), @@ -382,7 +376,7 @@ Value *llvm::EmitPutS(Value *Str, IRBuilder<> &B, const DataLayout *TD, AttributeSet::get(M->getContext(), AS), B.getInt32Ty(), B.getInt8PtrTy(), - NULL); + nullptr); CallInst *CI = B.CreateCall(PutS, CastToCStr(Str, B), "puts"); if (const Function *F = dyn_cast<Function>(PutS->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -407,12 +401,12 @@ Value *llvm::EmitFPutC(Value *Char, Value *File, IRBuilder<> &B, AttributeSet::get(M->getContext(), AS), B.getInt32Ty(), B.getInt32Ty(), File->getType(), - NULL); + nullptr); else F = M->getOrInsertFunction("fputc", B.getInt32Ty(), B.getInt32Ty(), - File->getType(), NULL); + File->getType(), nullptr); Char = B.CreateIntCast(Char, B.getInt32Ty(), /*isSigned*/true, "chari"); CallInst *CI = B.CreateCall2(F, Char, File, "fputc"); @@ -442,11 +436,11 @@ Value *llvm::EmitFPutS(Value *Str, Value *File, IRBuilder<> &B, AttributeSet::get(M->getContext(), AS), B.getInt32Ty(), B.getInt8PtrTy(), - File->getType(), NULL); + File->getType(), nullptr); else F = M->getOrInsertFunction(FPutsName, B.getInt32Ty(), B.getInt8PtrTy(), - File->getType(), NULL); + File->getType(), nullptr); CallInst *CI = B.CreateCall2(F, CastToCStr(Str, B), File, "fputs"); if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) @@ -478,13 +472,13 @@ Value *llvm::EmitFWrite(Value *Ptr, Value *Size, Value *File, B.getInt8PtrTy(), TD->getIntPtrType(Context), TD->getIntPtrType(Context), - File->getType(), NULL); + File->getType(), nullptr); else F = M->getOrInsertFunction(FWriteName, TD->getIntPtrType(Context), B.getInt8PtrTy(), TD->getIntPtrType(Context), TD->getIntPtrType(Context), - File->getType(), NULL); + File->getType(), nullptr); CallInst *CI = B.CreateCall4(F, CastToCStr(Ptr, B), Size, ConstantInt::get(TD->getIntPtrType(Context), 1), File); @@ -492,135 +486,3 @@ Value *llvm::EmitFWrite(Value *Ptr, Value *Size, Value *File, CI->setCallingConv(Fn->getCallingConv()); return CI; } - -SimplifyFortifiedLibCalls::~SimplifyFortifiedLibCalls() { } - -bool SimplifyFortifiedLibCalls::fold(CallInst *CI, const DataLayout *TD, - const TargetLibraryInfo 
*TLI) { - // We really need DataLayout for later. - if (!TD) return false; - - this->CI = CI; - Function *Callee = CI->getCalledFunction(); - StringRef Name = Callee->getName(); - FunctionType *FT = Callee->getFunctionType(); - LLVMContext &Context = CI->getParent()->getContext(); - IRBuilder<> B(CI); - - if (Name == "__memcpy_chk") { - // Check if this has the right signature. - if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - FT->getParamType(2) != TD->getIntPtrType(Context) || - FT->getParamType(3) != TD->getIntPtrType(Context)) - return false; - - if (isFoldable(3, 2, false)) { - B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), 1); - replaceCall(CI->getArgOperand(0)); - return true; - } - return false; - } - - // Should be similar to memcpy. - if (Name == "__mempcpy_chk") { - return false; - } - - if (Name == "__memmove_chk") { - // Check if this has the right signature. - if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - FT->getParamType(2) != TD->getIntPtrType(Context) || - FT->getParamType(3) != TD->getIntPtrType(Context)) - return false; - - if (isFoldable(3, 2, false)) { - B.CreateMemMove(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), 1); - replaceCall(CI->getArgOperand(0)); - return true; - } - return false; - } - - if (Name == "__memset_chk") { - // Check if this has the right signature. - if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isIntegerTy() || - FT->getParamType(2) != TD->getIntPtrType(Context) || - FT->getParamType(3) != TD->getIntPtrType(Context)) - return false; - - if (isFoldable(3, 2, false)) { - Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), - false); - B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1); - replaceCall(CI->getArgOperand(0)); - return true; - } - return false; - } - - if (Name == "__strcpy_chk" || Name == "__stpcpy_chk") { - // Check if this has the right signature. - if (FT->getNumParams() != 3 || - FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != Type::getInt8PtrTy(Context) || - FT->getParamType(2) != TD->getIntPtrType(Context)) - return 0; - - - // If a) we don't have any length information, or b) we know this will - // fit then just lower to a plain st[rp]cpy. Otherwise we'll keep our - // st[rp]cpy_chk call which may fail at runtime if the size is too long. - // TODO: It might be nice to get a maximum length out of the possible - // string lengths for varying. - if (isFoldable(2, 1, true)) { - Value *Ret = EmitStrCpy(CI->getArgOperand(0), CI->getArgOperand(1), B, TD, - TLI, Name.substr(2, 6)); - if (!Ret) - return false; - replaceCall(Ret); - return true; - } - return false; - } - - if (Name == "__strncpy_chk" || Name == "__stpncpy_chk") { - // Check if this has the right signature. 
- if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != Type::getInt8PtrTy(Context) || - !FT->getParamType(2)->isIntegerTy() || - FT->getParamType(3) != TD->getIntPtrType(Context)) - return false; - - if (isFoldable(3, 2, false)) { - Value *Ret = EmitStrNCpy(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), B, TD, TLI, - Name.substr(2, 7)); - if (!Ret) - return false; - replaceCall(Ret); - return true; - } - return false; - } - - if (Name == "__strcat_chk") { - return false; - } - - if (Name == "__strncat_chk") { - return false; - } - - return false; -} diff --git a/lib/Transforms/Utils/CMakeLists.txt b/lib/Transforms/Utils/CMakeLists.txt index fcf548f97c5d..6ce22b101825 100644 --- a/lib/Transforms/Utils/CMakeLists.txt +++ b/lib/Transforms/Utils/CMakeLists.txt @@ -1,16 +1,17 @@ add_llvm_library(LLVMTransformUtils - AddDiscriminators.cpp ASanStackFrameLayout.cpp + AddDiscriminators.cpp BasicBlockUtils.cpp BreakCriticalEdges.cpp BuildLibCalls.cpp BypassSlowDivision.cpp - CtorUtils.cpp CloneFunction.cpp CloneModule.cpp CmpInstAnalysis.cpp CodeExtractor.cpp + CtorUtils.cpp DemoteRegToStack.cpp + FlattenCFG.cpp GlobalStatus.cpp InlineFunction.cpp InstructionNamer.cpp @@ -29,10 +30,10 @@ add_llvm_library(LLVMTransformUtils PromoteMemoryToRegister.cpp SSAUpdater.cpp SimplifyCFG.cpp - FlattenCFG.cpp SimplifyIndVar.cpp SimplifyInstructions.cpp SimplifyLibCalls.cpp + SymbolRewriter.cpp UnifyFunctionExitNodes.cpp Utils.cpp ValueMapper.cpp diff --git a/lib/Transforms/Utils/CloneFunction.cpp b/lib/Transforms/Utils/CloneFunction.cpp index 5c8f20d5f884..96a763fac93b 100644 --- a/lib/Transforms/Utils/CloneFunction.cpp +++ b/lib/Transforms/Utils/CloneFunction.cpp @@ -164,14 +164,13 @@ static MDNode* FindSubprogram(const Function *F, DebugInfoFinder &Finder) { // Add an operand to an existing MDNode. The new operand will be added at the // back of the operand list. -static void AddOperand(MDNode *Node, Value *Operand) { - SmallVector<Value*, 16> Operands; - for (unsigned i = 0; i < Node->getNumOperands(); i++) { - Operands.push_back(Node->getOperand(i)); - } - Operands.push_back(Operand); - MDNode *NewNode = MDNode::get(Node->getContext(), Operands); - Node->replaceAllUsesWith(NewNode); +static void AddOperand(DICompileUnit CU, DIArray SPs, Metadata *NewSP) { + SmallVector<Metadata *, 16> NewSPs; + NewSPs.reserve(SPs->getNumOperands() + 1); + for (unsigned I = 0, E = SPs->getNumOperands(); I != E; ++I) + NewSPs.push_back(SPs->getOperand(I)); + NewSPs.push_back(NewSP); + CU.replaceSubprograms(DIArray(MDNode::get(CU->getContext(), NewSPs))); } // Clone the module-level debug info associated with OldFunc. The cloned data @@ -187,7 +186,7 @@ static void CloneDebugInfoMetadata(Function *NewFunc, const Function *OldFunc, // Ensure that OldFunc appears in the map. // (if it's already there it must point to NewFunc anyway) VMap[OldFunc] = NewFunc; - DISubprogram NewSubprogram(MapValue(OldSubprogramMDNode, VMap)); + DISubprogram NewSubprogram(MapMetadata(OldSubprogramMDNode, VMap)); for (DICompileUnit CU : Finder.compile_units()) { DIArray Subprograms(CU.getSubprograms()); @@ -196,7 +195,8 @@ static void CloneDebugInfoMetadata(Function *NewFunc, const Function *OldFunc, // also contain the new one. 
for (unsigned i = 0; i < Subprograms.getNumElements(); i++) { if ((MDNode*)Subprograms.getElement(i) == OldSubprogramMDNode) { - AddOperand(Subprograms, NewSubprogram); + AddOperand(CU, Subprograms, NewSubprogram); + break; } } } diff --git a/lib/Transforms/Utils/CloneModule.cpp b/lib/Transforms/Utils/CloneModule.cpp index 3f75b3e677ee..fae9ff5bce0f 100644 --- a/lib/Transforms/Utils/CloneModule.cpp +++ b/lib/Transforms/Utils/CloneModule.cpp @@ -17,6 +17,7 @@ #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Module.h" #include "llvm/Transforms/Utils/ValueMapper.h" +#include "llvm-c/Core.h" using namespace llvm; /// CloneModule - Return an exact copy of the specified module. This is not as @@ -108,7 +109,7 @@ Module *llvm::CloneModule(const Module *M, ValueToValueMapTy &VMap) { I != E; ++I) { GlobalAlias *GA = cast<GlobalAlias>(VMap[I]); if (const Constant *C = I->getAliasee()) - GA->setAliasee(cast<GlobalObject>(MapValue(C, VMap))); + GA->setAliasee(MapValue(C, VMap)); } // And named metadata.... @@ -117,8 +118,16 @@ Module *llvm::CloneModule(const Module *M, ValueToValueMapTy &VMap) { const NamedMDNode &NMD = *I; NamedMDNode *NewNMD = New->getOrInsertNamedMetadata(NMD.getName()); for (unsigned i = 0, e = NMD.getNumOperands(); i != e; ++i) - NewNMD->addOperand(MapValue(NMD.getOperand(i), VMap)); + NewNMD->addOperand(MapMetadata(NMD.getOperand(i), VMap)); } return New; } + +extern "C" { + +LLVMModuleRef LLVMCloneModule(LLVMModuleRef M) { + return wrap(CloneModule(unwrap(M))); +} + +} diff --git a/lib/Transforms/Utils/CtorUtils.cpp b/lib/Transforms/Utils/CtorUtils.cpp index a3594248de0c..26875e837b8b 100644 --- a/lib/Transforms/Utils/CtorUtils.cpp +++ b/lib/Transforms/Utils/CtorUtils.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/BitVector.h" #include "llvm/Transforms/Utils/CtorUtils.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" @@ -24,41 +25,22 @@ namespace llvm { namespace { -/// Given a specified llvm.global_ctors list, install the -/// specified array. -void installGlobalCtors(GlobalVariable *GCL, - const std::vector<Function *> &Ctors) { - // If we made a change, reassemble the initializer list. - Constant *CSVals[3]; - - StructType *StructTy = - cast<StructType>(GCL->getType()->getElementType()->getArrayElementType()); - - // Create the new init list. - std::vector<Constant *> CAList; - for (Function *F : Ctors) { - Type *Int32Ty = Type::getInt32Ty(GCL->getContext()); - if (F) { - CSVals[0] = ConstantInt::get(Int32Ty, 65535); - CSVals[1] = F; - } else { - CSVals[0] = ConstantInt::get(Int32Ty, 0x7fffffff); - CSVals[1] = Constant::getNullValue(StructTy->getElementType(1)); - } - // FIXME: Only allow the 3-field form in LLVM 4.0. - size_t NumElts = StructTy->getNumElements(); - if (NumElts > 2) - CSVals[2] = Constant::getNullValue(StructTy->getElementType(2)); - CAList.push_back( - ConstantStruct::get(StructTy, makeArrayRef(CSVals, NumElts))); - } - - // Create the array initializer. - Constant *CA = - ConstantArray::get(ArrayType::get(StructTy, CAList.size()), CAList); +/// Given a specified llvm.global_ctors list, remove the listed elements. +void removeGlobalCtors(GlobalVariable *GCL, const BitVector &CtorsToRemove) { + // Filter out the initializer elements to remove. 
+ ConstantArray *OldCA = cast<ConstantArray>(GCL->getInitializer()); + SmallVector<Constant *, 10> CAList; + for (unsigned I = 0, E = OldCA->getNumOperands(); I < E; ++I) + if (!CtorsToRemove.test(I)) + CAList.push_back(OldCA->getOperand(I)); + + // Create the new array initializer. + ArrayType *ATy = + ArrayType::get(OldCA->getType()->getElementType(), CAList.size()); + Constant *CA = ConstantArray::get(ATy, CAList); // If we didn't change the number of elements, don't create a new GV. - if (CA->getType() == GCL->getInitializer()->getType()) { + if (CA->getType() == OldCA->getType()) { GCL->setInitializer(CA); return; } @@ -82,7 +64,7 @@ void installGlobalCtors(GlobalVariable *GCL, /// Given a llvm.global_ctors list that we can understand, /// return a list of the functions and null terminator as a vector. -std::vector<Function*> parseGlobalCtors(GlobalVariable *GV) { +std::vector<Function *> parseGlobalCtors(GlobalVariable *GV) { if (GV->getInitializer()->isNullValue()) return std::vector<Function *>(); ConstantArray *CA = cast<ConstantArray>(GV->getInitializer()); @@ -147,17 +129,15 @@ bool optimizeGlobalCtorsList(Module &M, bool MadeChange = false; // Loop over global ctors, optimizing them when we can. - for (unsigned i = 0; i != Ctors.size(); ++i) { + unsigned NumCtors = Ctors.size(); + BitVector CtorsToRemove(NumCtors); + for (unsigned i = 0; i != Ctors.size() && NumCtors > 0; ++i) { Function *F = Ctors[i]; // Found a null terminator in the middle of the list, prune off the rest of // the list. - if (!F) { - if (i != Ctors.size() - 1) { - Ctors.resize(i + 1); - MadeChange = true; - } - break; - } + if (!F) + continue; + DEBUG(dbgs() << "Optimizing Global Constructor: " << *F << "\n"); // We cannot simplify external ctor functions. @@ -166,9 +146,10 @@ bool optimizeGlobalCtorsList(Module &M, // If we can evaluate the ctor at compile time, do. if (ShouldRemove(F)) { - Ctors.erase(Ctors.begin() + i); + Ctors[i] = nullptr; + CtorsToRemove.set(i); + NumCtors--; MadeChange = true; - --i; continue; } } @@ -176,7 +157,7 @@ bool optimizeGlobalCtorsList(Module &M, if (!MadeChange) return false; - installGlobalCtors(GlobalCtors, Ctors); + removeGlobalCtors(GlobalCtors, CtorsToRemove); return true; } diff --git a/lib/Transforms/Utils/FlattenCFG.cpp b/lib/Transforms/Utils/FlattenCFG.cpp index 51ead40c916e..4eb3e3dd17d2 100644 --- a/lib/Transforms/Utils/FlattenCFG.cpp +++ b/lib/Transforms/Utils/FlattenCFG.cpp @@ -238,9 +238,13 @@ bool FlattenCFGOpt::FlattenParallelAndOr(BasicBlock *BB, IRBuilder<> &Builder, // Do branch inversion. 
BasicBlock *CurrBlock = LastCondBlock; bool EverChanged = false; - while (1) { + for (; CurrBlock != FirstCondBlock; + CurrBlock = CurrBlock->getSinglePredecessor()) { BranchInst *BI = dyn_cast<BranchInst>(CurrBlock->getTerminator()); CmpInst *CI = dyn_cast<CmpInst>(BI->getCondition()); + if (!CI) + continue; + CmpInst::Predicate Predicate = CI->getPredicate(); // Canonicalize icmp_ne -> icmp_eq, fcmp_one -> fcmp_oeq if ((Predicate == CmpInst::ICMP_NE) || (Predicate == CmpInst::FCMP_ONE)) { @@ -248,9 +252,6 @@ bool FlattenCFGOpt::FlattenParallelAndOr(BasicBlock *BB, IRBuilder<> &Builder, BI->swapSuccessors(); EverChanged = true; } - if (CurrBlock == FirstCondBlock) - break; - CurrBlock = CurrBlock->getSinglePredecessor(); } return EverChanged; } diff --git a/lib/Transforms/Utils/GlobalStatus.cpp b/lib/Transforms/Utils/GlobalStatus.cpp index 12057e4b929c..52e2d59557fa 100644 --- a/lib/Transforms/Utils/GlobalStatus.cpp +++ b/lib/Transforms/Utils/GlobalStatus.cpp @@ -35,6 +35,9 @@ bool llvm::isSafeToDestroyConstant(const Constant *C) { if (isa<GlobalValue>(C)) return false; + if (isa<ConstantInt>(C) || isa<ConstantFP>(C)) + return false; + for (const User *U : C->users()) if (const Constant *CU = dyn_cast<Constant>(U)) { if (!isSafeToDestroyConstant(CU)) @@ -45,7 +48,7 @@ bool llvm::isSafeToDestroyConstant(const Constant *C) { } static bool analyzeGlobalAux(const Value *V, GlobalStatus &GS, - SmallPtrSet<const PHINode *, 16> &PhiUsers) { + SmallPtrSetImpl<const PHINode *> &PhiUsers) { for (const Use &U : V->uses()) { const User *UR = U.getUser(); if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(UR)) { @@ -130,7 +133,7 @@ static bool analyzeGlobalAux(const Value *V, GlobalStatus &GS, } else if (const PHINode *PN = dyn_cast<PHINode>(I)) { // PHI nodes we can check just like select or GEP instructions, but we // have to be careful about infinite recursion. - if (PhiUsers.insert(PN)) // Not already visited. + if (PhiUsers.insert(PN).second) // Not already visited.
if (analyzeGlobalAux(I, GS, PhiUsers)) return true; } else if (isa<CmpInst>(I)) { diff --git a/lib/Transforms/Utils/InlineFunction.cpp b/lib/Transforms/Utils/InlineFunction.cpp index f0a9f2b1fcb3..2a86eb598d4d 100644 --- a/lib/Transforms/Utils/InlineFunction.cpp +++ b/lib/Transforms/Utils/InlineFunction.cpp @@ -13,10 +13,16 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CallGraph.h" +#include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/CFG.h" @@ -24,14 +30,28 @@ #include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugInfo.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Support/CommandLine.h" +#include <algorithm> using namespace llvm; +static cl::opt<bool> +EnableNoAliasConversion("enable-noalias-to-md-conversion", cl::init(true), + cl::Hidden, + cl::desc("Convert noalias attributes to metadata during inlining.")); + +static cl::opt<bool> +PreserveAlignmentAssumptions("preserve-alignment-assumptions-during-inlining", + cl::init(true), cl::Hidden, + cl::desc("Convert align attributes to assumptions during inlining.")); + bool llvm::InlineFunction(CallInst *CI, InlineFunctionInfo &IFI, bool InsertLifetime) { return InlineFunction(CallSite(CI), IFI, InsertLifetime); @@ -84,7 +104,7 @@ namespace { /// split the landing pad block after the landingpad instruction and jump /// to there. void forwardResume(ResumeInst *RI, - SmallPtrSet<LandingPadInst*, 16> &InlinedLPads); + SmallPtrSetImpl<LandingPadInst*> &InlinedLPads); /// addIncomingPHIValuesFor - Add incoming-PHI values to the unwind /// destination block for the given basic block, using the values for the @@ -143,7 +163,7 @@ BasicBlock *InvokeInliningInfo::getInnerResumeDest() { /// branch. When there is more than one predecessor, we need to split the /// landing pad block after the landingpad instruction and jump to there. void InvokeInliningInfo::forwardResume(ResumeInst *RI, - SmallPtrSet<LandingPadInst*, 16> &InlinedLPads) { + SmallPtrSetImpl<LandingPadInst*> &InlinedLPads) { BasicBlock *Dest = getInnerResumeDest(); BasicBlock *Src = RI->getParent(); @@ -233,9 +253,7 @@ static void HandleInlinedInvoke(InvokeInst *II, BasicBlock *FirstNewBlock, // Append the clauses from the outer landing pad instruction into the inlined // landing pad instructions. 
LandingPadInst *OuterLPad = Invoke.getLandingPadInst(); - for (SmallPtrSet<LandingPadInst*, 16>::iterator I = InlinedLPads.begin(), - E = InlinedLPads.end(); I != E; ++I) { - LandingPadInst *InlinedLPad = *I; + for (LandingPadInst *InlinedLPad : InlinedLPads) { unsigned OuterNum = OuterLPad->getNumClauses(); InlinedLPad->reserveClauses(OuterNum); for (unsigned OuterIdx = 0; OuterIdx != OuterNum; ++OuterIdx) @@ -260,6 +278,387 @@ static void HandleInlinedInvoke(InvokeInst *II, BasicBlock *FirstNewBlock, InvokeDest->removePredecessor(II->getParent()); } +/// CloneAliasScopeMetadata - When inlining a function that contains noalias +/// scope metadata, this metadata needs to be cloned so that the inlined blocks +/// have different "unique scopes" at every call site. Were this not done, then +/// aliasing scopes from a function inlined into a caller multiple times could +/// not be differentiated (and this would lead to miscompiles because the +/// non-aliasing property communicated by the metadata could have +/// call-site-specific control dependencies). +static void CloneAliasScopeMetadata(CallSite CS, ValueToValueMapTy &VMap) { + const Function *CalledFunc = CS.getCalledFunction(); + SetVector<const MDNode *> MD; + + // Note: We could only clone the metadata if it is already used in the + // caller. I'm omitting that check here because it might confuse + // inter-procedural alias analysis passes. We can revisit this if it becomes + // an efficiency or overhead problem. + + for (Function::const_iterator I = CalledFunc->begin(), IE = CalledFunc->end(); + I != IE; ++I) + for (BasicBlock::const_iterator J = I->begin(), JE = I->end(); J != JE; ++J) { + if (const MDNode *M = J->getMetadata(LLVMContext::MD_alias_scope)) + MD.insert(M); + if (const MDNode *M = J->getMetadata(LLVMContext::MD_noalias)) + MD.insert(M); + } + + if (MD.empty()) + return; + + // Walk the existing metadata, adding the complete (perhaps cyclic) chain to + // the set. + SmallVector<const Metadata *, 16> Queue(MD.begin(), MD.end()); + while (!Queue.empty()) { + const MDNode *M = cast<MDNode>(Queue.pop_back_val()); + for (unsigned i = 0, ie = M->getNumOperands(); i != ie; ++i) + if (const MDNode *M1 = dyn_cast<MDNode>(M->getOperand(i))) + if (MD.insert(M1)) + Queue.push_back(M1); + } + + // Now we have a complete set of all metadata in the chains used to specify + // the noalias scopes and the lists of those scopes. + SmallVector<MDNode *, 16> DummyNodes; + DenseMap<const MDNode *, TrackingMDNodeRef> MDMap; + for (SetVector<const MDNode *>::iterator I = MD.begin(), IE = MD.end(); + I != IE; ++I) { + MDNode *Dummy = MDNode::getTemporary(CalledFunc->getContext(), None); + DummyNodes.push_back(Dummy); + MDMap[*I].reset(Dummy); + } + + // Create new metadata nodes to replace the dummy nodes, replacing old + // metadata references with either a dummy node or an already-created new + // node. + for (SetVector<const MDNode *>::iterator I = MD.begin(), IE = MD.end(); + I != IE; ++I) { + SmallVector<Metadata *, 4> NewOps; + for (unsigned i = 0, ie = (*I)->getNumOperands(); i != ie; ++i) { + const Metadata *V = (*I)->getOperand(i); + if (const MDNode *M = dyn_cast<MDNode>(V)) + NewOps.push_back(MDMap[M]); + else + NewOps.push_back(const_cast<Metadata *>(V)); + } + + MDNode *NewM = MDNode::get(CalledFunc->getContext(), NewOps); + MDNodeFwdDecl *TempM = cast<MDNodeFwdDecl>(MDMap[*I]); + + TempM->replaceAllUsesWith(NewM); + } + + // Now replace the metadata in the new inlined instructions with the + // replacements from the map.
+ for (ValueToValueMapTy::iterator VMI = VMap.begin(), VMIE = VMap.end(); + VMI != VMIE; ++VMI) { + if (!VMI->second) + continue; + + Instruction *NI = dyn_cast<Instruction>(VMI->second); + if (!NI) + continue; + + if (MDNode *M = NI->getMetadata(LLVMContext::MD_alias_scope)) { + MDNode *NewMD = MDMap[M]; + // If the call site also had alias scope metadata (a list of scopes to + // which instructions inside it might belong), propagate those scopes to + // the inlined instructions. + if (MDNode *CSM = + CS.getInstruction()->getMetadata(LLVMContext::MD_alias_scope)) + NewMD = MDNode::concatenate(NewMD, CSM); + NI->setMetadata(LLVMContext::MD_alias_scope, NewMD); + } else if (NI->mayReadOrWriteMemory()) { + if (MDNode *M = + CS.getInstruction()->getMetadata(LLVMContext::MD_alias_scope)) + NI->setMetadata(LLVMContext::MD_alias_scope, M); + } + + if (MDNode *M = NI->getMetadata(LLVMContext::MD_noalias)) { + MDNode *NewMD = MDMap[M]; + // If the call site also had noalias metadata (a list of scopes with + // which instructions inside it don't alias), propagate those scopes to + // the inlined instructions. + if (MDNode *CSM = + CS.getInstruction()->getMetadata(LLVMContext::MD_noalias)) + NewMD = MDNode::concatenate(NewMD, CSM); + NI->setMetadata(LLVMContext::MD_noalias, NewMD); + } else if (NI->mayReadOrWriteMemory()) { + if (MDNode *M = CS.getInstruction()->getMetadata(LLVMContext::MD_noalias)) + NI->setMetadata(LLVMContext::MD_noalias, M); + } + } + + // Now that everything has been replaced, delete the dummy nodes. + for (unsigned i = 0, ie = DummyNodes.size(); i != ie; ++i) + MDNode::deleteTemporary(DummyNodes[i]); +} + +/// AddAliasScopeMetadata - If the inlined function has noalias arguments, then +/// add new alias scopes for each noalias argument, tag the mapped noalias +/// parameters with noalias metadata specifying the new scope, and tag all +/// non-derived loads, stores and memory intrinsics with the new alias scopes. +static void AddAliasScopeMetadata(CallSite CS, ValueToValueMapTy &VMap, + const DataLayout *DL, AliasAnalysis *AA) { + if (!EnableNoAliasConversion) + return; + + const Function *CalledFunc = CS.getCalledFunction(); + SmallVector<const Argument *, 4> NoAliasArgs; + + for (Function::const_arg_iterator I = CalledFunc->arg_begin(), + E = CalledFunc->arg_end(); I != E; ++I) { + if (I->hasNoAliasAttr() && !I->hasNUses(0)) + NoAliasArgs.push_back(I); + } + + if (NoAliasArgs.empty()) + return; + + // To do a good job, if a noalias variable is captured, we need to know if + // the capture point dominates the particular use we're considering. + DominatorTree DT; + DT.recalculate(const_cast<Function&>(*CalledFunc)); + + // noalias indicates that pointer values based on the argument do not alias + // pointer values which are not based on it. So we add a new "scope" for each + // noalias function argument. Accesses using pointers based on that argument + // become part of that alias scope, accesses using pointers not based on that + // argument are tagged as noalias with that scope. + + DenseMap<const Argument *, MDNode *> NewScopes; + MDBuilder MDB(CalledFunc->getContext()); + + // Create a new scope domain for this function. 
+ MDNode *NewDomain = + MDB.createAnonymousAliasScopeDomain(CalledFunc->getName()); + for (unsigned i = 0, e = NoAliasArgs.size(); i != e; ++i) { + const Argument *A = NoAliasArgs[i]; + + std::string Name = CalledFunc->getName(); + if (A->hasName()) { + Name += ": %"; + Name += A->getName(); + } else { + Name += ": argument "; + Name += utostr(i); + } + + // Note: We always create a new anonymous root here. This is true regardless + // of the linkage of the callee because the aliasing "scope" is not just a + // property of the callee, but also of all control dependencies in the caller. + MDNode *NewScope = MDB.createAnonymousAliasScope(NewDomain, Name); + NewScopes.insert(std::make_pair(A, NewScope)); + } + + // Iterate over all new instructions in the map; for all memory-access + // instructions, add the alias scope metadata. + for (ValueToValueMapTy::iterator VMI = VMap.begin(), VMIE = VMap.end(); + VMI != VMIE; ++VMI) { + if (const Instruction *I = dyn_cast<Instruction>(VMI->first)) { + if (!VMI->second) + continue; + + Instruction *NI = dyn_cast<Instruction>(VMI->second); + if (!NI) + continue; + + bool IsArgMemOnlyCall = false, IsFuncCall = false; + SmallVector<const Value *, 2> PtrArgs; + + if (const LoadInst *LI = dyn_cast<LoadInst>(I)) + PtrArgs.push_back(LI->getPointerOperand()); + else if (const StoreInst *SI = dyn_cast<StoreInst>(I)) + PtrArgs.push_back(SI->getPointerOperand()); + else if (const VAArgInst *VAAI = dyn_cast<VAArgInst>(I)) + PtrArgs.push_back(VAAI->getPointerOperand()); + else if (const AtomicCmpXchgInst *CXI = dyn_cast<AtomicCmpXchgInst>(I)) + PtrArgs.push_back(CXI->getPointerOperand()); + else if (const AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(I)) + PtrArgs.push_back(RMWI->getPointerOperand()); + else if (ImmutableCallSite ICS = ImmutableCallSite(I)) { + // If we know that the call does not access memory, then we'll still + // know that about the inlined clone of this call site, and we don't + // need to add metadata. + if (ICS.doesNotAccessMemory()) + continue; + + IsFuncCall = true; + if (AA) { + AliasAnalysis::ModRefBehavior MRB = AA->getModRefBehavior(ICS); + if (MRB == AliasAnalysis::OnlyAccessesArgumentPointees || + MRB == AliasAnalysis::OnlyReadsArgumentPointees) + IsArgMemOnlyCall = true; + } + + for (ImmutableCallSite::arg_iterator AI = ICS.arg_begin(), + AE = ICS.arg_end(); AI != AE; ++AI) { + // We need to check the underlying objects of all arguments, not just + // the pointer arguments, because we might be passing pointers as + // integers, etc. + // However, if we know that the call only accesses pointer arguments, + // then we only need to check the pointer arguments. + if (IsArgMemOnlyCall && !(*AI)->getType()->isPointerTy()) + continue; + + PtrArgs.push_back(*AI); + } + } + + // If we found no pointers, then this instruction is not suitable for + // pairing with an instruction to receive aliasing metadata. + // However, if this is a call, then we might just alias with none of the + // noalias arguments. + if (PtrArgs.empty() && !IsFuncCall) + continue; + + // It is possible that there is only one underlying object, but you + // need to go through several PHIs to see it, and thus it could be + // repeated in the Objects list.
+ SmallPtrSet<const Value *, 4> ObjSet; + SmallVector<Metadata *, 4> Scopes, NoAliases; + + SmallSetVector<const Argument *, 4> NAPtrArgs; + for (unsigned i = 0, ie = PtrArgs.size(); i != ie; ++i) { + SmallVector<Value *, 4> Objects; + GetUnderlyingObjects(const_cast<Value*>(PtrArgs[i]), + Objects, DL, /* MaxLookup = */ 0); + + for (Value *O : Objects) + ObjSet.insert(O); + } + + // Figure out if we're derived from anything that is not a noalias + // argument. + bool CanDeriveViaCapture = false, UsesAliasingPtr = false; + for (const Value *V : ObjSet) { + // Is this value a constant that cannot be derived from any pointer + // value (we need to exclude constant expressions, for example, that + // are formed from arithmetic on global symbols)? + bool IsNonPtrConst = isa<ConstantInt>(V) || isa<ConstantFP>(V) || + isa<ConstantPointerNull>(V) || + isa<ConstantDataVector>(V) || isa<UndefValue>(V); + if (IsNonPtrConst) + continue; + + // If this is anything other than a noalias argument, then we cannot + // completely describe the aliasing properties using alias.scope + // metadata (and, thus, won't add any). + if (const Argument *A = dyn_cast<Argument>(V)) { + if (!A->hasNoAliasAttr()) + UsesAliasingPtr = true; + } else { + UsesAliasingPtr = true; + } + + // If this is not some identified function-local object (which cannot + // directly alias a noalias argument), or some other argument (which, + // by definition, also cannot alias a noalias argument), then we could + // alias a noalias argument that has been captured. + if (!isa<Argument>(V) && + !isIdentifiedFunctionLocal(const_cast<Value*>(V))) + CanDeriveViaCapture = true; + } + + // A function call can always get captured noalias pointers (via other + // parameters, globals, etc.). + if (IsFuncCall && !IsArgMemOnlyCall) + CanDeriveViaCapture = true; + + // First, we want to figure out all of the sets with which we definitely + // don't alias. Iterate over all noalias sets, and add those for which: + // 1. The noalias argument is not in the set of objects from which we + // definitely derive. + // 2. The noalias argument has not yet been captured. + // An arbitrary function that might load pointers could see captured + // noalias arguments via other noalias arguments or globals, and so we + // must always check for prior capture. + for (const Argument *A : NoAliasArgs) { + if (!ObjSet.count(A) && (!CanDeriveViaCapture || + // It might be tempting to skip the + // PointerMayBeCapturedBefore check if + // A->hasNoCaptureAttr() is true, but this is + // incorrect because nocapture only guarantees + // that no copies outlive the function, not + // that the value cannot be locally captured. + !PointerMayBeCapturedBefore(A, + /* ReturnCaptures */ false, + /* StoreCaptures */ false, I, &DT))) + NoAliases.push_back(NewScopes[A]); + } + + if (!NoAliases.empty()) + NI->setMetadata(LLVMContext::MD_noalias, + MDNode::concatenate( + NI->getMetadata(LLVMContext::MD_noalias), + MDNode::get(CalledFunc->getContext(), NoAliases))); + + // Next, we want to figure out all of the sets to which we might belong. + // We might belong to a set if the noalias argument is in the set of + // underlying objects.
If there is some non-noalias argument in our list + // of underlying objects, then we cannot add a scope because the fact + // that some access does not alias with any set of our noalias arguments + // cannot itself guarantee that it does not alias with this access + // (because there is some pointer of unknown origin involved and the + // other access might also depend on this pointer). We also cannot add + // scopes to arbitrary functions unless we know they don't access any + // non-parameter pointer-values. + bool CanAddScopes = !UsesAliasingPtr; + if (CanAddScopes && IsFuncCall) + CanAddScopes = IsArgMemOnlyCall; + + if (CanAddScopes) + for (const Argument *A : NoAliasArgs) { + if (ObjSet.count(A)) + Scopes.push_back(NewScopes[A]); + } + + if (!Scopes.empty()) + NI->setMetadata( + LLVMContext::MD_alias_scope, + MDNode::concatenate(NI->getMetadata(LLVMContext::MD_alias_scope), + MDNode::get(CalledFunc->getContext(), Scopes))); + } + } +} + +/// If the inlined function has non-byval align arguments, then +/// add @llvm.assume-based alignment assumptions to preserve this information. +static void AddAlignmentAssumptions(CallSite CS, InlineFunctionInfo &IFI) { + if (!PreserveAlignmentAssumptions || !IFI.DL) + return; + + // To avoid inserting redundant assumptions, we should check for assumptions + // already in the caller. To do this, we might need a DT of the caller. + DominatorTree DT; + bool DTCalculated = false; + + Function *CalledFunc = CS.getCalledFunction(); + for (Function::arg_iterator I = CalledFunc->arg_begin(), + E = CalledFunc->arg_end(); + I != E; ++I) { + unsigned Align = I->getType()->isPointerTy() ? I->getParamAlignment() : 0; + if (Align && !I->hasByValOrInAllocaAttr() && !I->hasNUses(0)) { + if (!DTCalculated) { + DT.recalculate(const_cast<Function&>(*CS.getInstruction()->getParent() + ->getParent())); + DTCalculated = true; + } + + // If we can already prove the asserted alignment in the context of the + // caller, then don't bother inserting the assumption. + Value *Arg = CS.getArgument(I->getArgNo()); + if (getKnownAlignment(Arg, IFI.DL, + &IFI.ACT->getAssumptionCache(*CalledFunc), + CS.getInstruction(), &DT) >= Align) + continue; + + IRBuilder<>(CS.getInstruction()).CreateAlignmentAssumption(*IFI.DL, Arg, + Align); + } + } +} + /// UpdateCallGraphAfterInlining - Once we have cloned code over from a callee /// into the caller, update the specified callgraph to reflect the changes we /// made. 
Note that it's possible that not all code was copied over, so only @@ -327,31 +726,19 @@ static void UpdateCallGraphAfterInlining(CallSite CS, static void HandleByValArgumentInit(Value *Dst, Value *Src, Module *M, BasicBlock *InsertBlock, InlineFunctionInfo &IFI) { - LLVMContext &Context = Src->getContext(); - Type *VoidPtrTy = Type::getInt8PtrTy(Context); Type *AggTy = cast<PointerType>(Src->getType())->getElementType(); - Type *Tys[3] = { VoidPtrTy, VoidPtrTy, Type::getInt64Ty(Context) }; - Function *MemCpyFn = Intrinsic::getDeclaration(M, Intrinsic::memcpy, Tys); - IRBuilder<> builder(InsertBlock->begin()); - Value *DstCast = builder.CreateBitCast(Dst, VoidPtrTy, "tmp"); - Value *SrcCast = builder.CreateBitCast(Src, VoidPtrTy, "tmp"); + IRBuilder<> Builder(InsertBlock->begin()); Value *Size; if (IFI.DL == nullptr) Size = ConstantExpr::getSizeOf(AggTy); else - Size = ConstantInt::get(Type::getInt64Ty(Context), - IFI.DL->getTypeStoreSize(AggTy)); + Size = Builder.getInt64(IFI.DL->getTypeStoreSize(AggTy)); // Always generate a memcpy of alignment 1 here because we don't know // the alignment of the src pointer. Other optimizations can infer // better alignment. - Value *CallArgs[] = { - DstCast, SrcCast, Size, - ConstantInt::get(Type::getInt32Ty(Context), 1), - ConstantInt::getFalse(Context) // isVolatile - }; - builder.CreateCall(MemCpyFn, CallArgs); + Builder.CreateMemCpy(Dst, Src, Size, /*Align=*/1); } /// HandleByValArgument - When inlining a call site that has a byval argument, @@ -363,6 +750,8 @@ static Value *HandleByValArgument(Value *Arg, Instruction *TheCall, PointerType *ArgTy = cast<PointerType>(Arg->getType()); Type *AggTy = ArgTy->getElementType(); + Function *Caller = TheCall->getParent()->getParent(); + // If the called function is readonly, then it could not mutate the caller's // copy of the byval'd memory. In this case, it is safe to elide the copy and // temporary. @@ -375,8 +764,9 @@ static Value *HandleByValArgument(Value *Arg, Instruction *TheCall, // If the pointer is already known to be sufficiently aligned, or if we can // round it up to a larger alignment, then we don't need a temporary. - if (getOrEnforceKnownAlignment(Arg, ByValAlignment, - IFI.DL) >= ByValAlignment) + if (getOrEnforceKnownAlignment(Arg, ByValAlignment, IFI.DL, + &IFI.ACT->getAssumptionCache(*Caller), + TheCall) >= ByValAlignment) return Arg; // Otherwise, we have to make a memcpy to get a safe alignment. This is bad @@ -393,8 +783,6 @@ static Value *HandleByValArgument(Value *Arg, Instruction *TheCall, // pointer inside the callee). Align = std::max(Align, ByValAlignment); - Function *Caller = TheCall->getParent()->getParent(); - Value *NewAlloca = new AllocaInst(AggTy, nullptr, Align, Arg->getName(), &*Caller->begin()->begin()); IFI.StaticAllocas.push_back(cast<AllocaInst>(NewAlloca)); @@ -472,47 +860,33 @@ static void fixupLineNumbers(Function *Fn, Function::iterator FI, // originates from the call location. This is important for // ((__always_inline__, __nodebug__)) functions which must use caller // location for all instructions in their function body. + + // Don't update static allocas, as they may get moved later. 
+ if (auto *AI = dyn_cast<AllocaInst>(BI)) + if (isa<Constant>(AI->getArraySize())) + continue; + BI->setDebugLoc(TheCallDL); } else { BI->setDebugLoc(updateInlinedAtInfo(DL, TheCallDL, BI->getContext())); if (DbgValueInst *DVI = dyn_cast<DbgValueInst>(BI)) { LLVMContext &Ctx = BI->getContext(); MDNode *InlinedAt = BI->getDebugLoc().getInlinedAt(Ctx); - DVI->setOperand(2, createInlinedVariable(DVI->getVariable(), - InlinedAt, Ctx)); + DVI->setOperand(2, MetadataAsValue::get( + Ctx, createInlinedVariable(DVI->getVariable(), + InlinedAt, Ctx))); + } else if (DbgDeclareInst *DDI = dyn_cast<DbgDeclareInst>(BI)) { + LLVMContext &Ctx = BI->getContext(); + MDNode *InlinedAt = BI->getDebugLoc().getInlinedAt(Ctx); + DDI->setOperand(1, MetadataAsValue::get( + Ctx, createInlinedVariable(DDI->getVariable(), + InlinedAt, Ctx))); } } } } } -/// Returns a musttail call instruction if one immediately precedes the given -/// return instruction with an optional bitcast instruction between them. -static CallInst *getPrecedingMustTailCall(ReturnInst *RI) { - Instruction *Prev = RI->getPrevNode(); - if (!Prev) - return nullptr; - - if (Value *RV = RI->getReturnValue()) { - if (RV != Prev) - return nullptr; - - // Look through the optional bitcast. - if (auto *BI = dyn_cast<BitCastInst>(Prev)) { - RV = BI->getOperand(0); - Prev = BI->getPrevNode(); - if (!Prev || RV != Prev) - return nullptr; - } - } - - if (auto *CI = dyn_cast<CallInst>(Prev)) { - if (CI->isMustTailCall()) - return CI; - } - return nullptr; -} - /// InlineFunction - This function inlines the called function into the basic /// block of the caller. This returns false if it is not possible to inline /// this call. The program is still in a well defined state if this occurs @@ -626,6 +1000,11 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, VMap[I] = ActualArg; } + // Add alignment assumptions if necessary. We do this before the inlined + // instructions are actually cloned into the caller so that we can easily + // check what will be known at the start of the inlined code. + AddAlignmentAssumptions(CS, IFI); + // We want the inliner to prune the code as it copies. We would LOVE to // have no dead or constant instructions leftover after inlining occurs // (which can happen, e.g., because an argument was constant), but we'll be @@ -648,6 +1027,17 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // Update inlined instructions' line number information. fixupLineNumbers(Caller, FirstNewBlock, TheCall); + + // Clone existing noalias metadata if necessary. + CloneAliasScopeMetadata(CS, VMap); + + // Add noalias metadata if necessary. + AddAliasScopeMetadata(CS, VMap, IFI.DL, IFI.AA); + + // FIXME: We could register any cloned assumptions instead of clearing the + // whole function's cache. + if (IFI.ACT) + IFI.ACT->getAssumptionCache(*Caller).clear(); } // If there are any alloca instructions in the block that used to be the entry @@ -765,7 +1155,8 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, for (ReturnInst *RI : Returns) { // Don't insert llvm.lifetime.end calls between a musttail call and a // return. The return kills all local allocas. 
-      if (InlinedMustTailCalls && getPrecedingMustTailCall(RI))
+      if (InlinedMustTailCalls &&
+          RI->getParent()->getTerminatingMustTailCall())
         continue;
       IRBuilder<>(RI).CreateLifetimeEnd(AI, AllocaSize);
     }
@@ -789,7 +1180,7 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI,
     for (ReturnInst *RI : Returns) {
       // Don't insert llvm.stackrestore calls between a musttail call and a
       // return.  The return will restore the stack pointer.
-      if (InlinedMustTailCalls && getPrecedingMustTailCall(RI))
+      if (InlinedMustTailCalls && RI->getParent()->getTerminatingMustTailCall())
         continue;
       IRBuilder<>(RI).CreateCall(StackRestore, SavedPtr);
     }
@@ -812,7 +1203,8 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI,
   // Handle the returns preceded by musttail calls separately.
   SmallVector<ReturnInst *, 8> NormalReturns;
   for (ReturnInst *RI : Returns) {
-    CallInst *ReturnedMustTail = getPrecedingMustTailCall(RI);
+    CallInst *ReturnedMustTail =
+        RI->getParent()->getTerminatingMustTailCall();
     if (!ReturnedMustTail) {
       NormalReturns.push_back(RI);
       continue;
@@ -1016,7 +1408,8 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI,
   // the entries are the same or undef).  If so, remove the PHI so it doesn't
   // block other optimizations.
   if (PHI) {
-    if (Value *V = SimplifyInstruction(PHI, IFI.DL)) {
+    if (Value *V = SimplifyInstruction(PHI, IFI.DL, nullptr, nullptr,
+                                       &IFI.ACT->getAssumptionCache(*Caller))) {
       PHI->replaceAllUsesWith(V);
       PHI->eraseFromParent();
     }
diff --git a/lib/Transforms/Utils/IntegerDivision.cpp b/lib/Transforms/Utils/IntegerDivision.cpp
index 9f91eeb79531..0ae746cc83db 100644
--- a/lib/Transforms/Utils/IntegerDivision.cpp
+++ b/lib/Transforms/Utils/IntegerDivision.cpp
@@ -398,11 +398,13 @@ bool llvm::expandRemainder(BinaryOperator *Rem) {
     Rem->dropAllReferences();
     Rem->eraseFromParent();
-    // If we didn't actually generate a udiv instruction, we're done
-    BinaryOperator *BO = dyn_cast<BinaryOperator>(Builder.GetInsertPoint());
-    if (!BO || BO->getOpcode() != Instruction::URem)
+    // If we didn't actually generate a urem instruction, we're done.
+    // This happens, for example, if the input was constant; in that case
+    // the Builder insertion point was left unchanged.
+    if (Rem == Builder.GetInsertPoint())
       return true;
+    BinaryOperator *BO = dyn_cast<BinaryOperator>(Builder.GetInsertPoint());
     Rem = BO;
   }
@@ -456,11 +458,13 @@ bool llvm::expandDivision(BinaryOperator *Div) {
     Div->dropAllReferences();
     Div->eraseFromParent();
-    // If we didn't actually generate a udiv instruction, we're done
-    BinaryOperator *BO = dyn_cast<BinaryOperator>(Builder.GetInsertPoint());
-    if (!BO || BO->getOpcode() != Instruction::UDiv)
+    // If we didn't actually generate a udiv instruction, we're done.
+    // This happens, for example, if the input was constant; in that case
+    // the Builder insertion point was left unchanged.
+    if (Div == Builder.GetInsertPoint())
       return true;
+    BinaryOperator *BO = dyn_cast<BinaryOperator>(Builder.GetInsertPoint());
     Div = BO;
   }
diff --git a/lib/Transforms/Utils/LCSSA.cpp b/lib/Transforms/Utils/LCSSA.cpp
index 51a3d9c1fced..3f9b702c5b9a 100644
--- a/lib/Transforms/Utils/LCSSA.cpp
+++ b/lib/Transforms/Utils/LCSSA.cpp
@@ -61,7 +61,7 @@ static bool isExitBlock(BasicBlock *BB,
 /// uses.
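processInstruction below implements the core LCSSA move: for a value defined inside the loop and used outside it, insert a PHI node in each exit block and rewrite the outside uses through those PHIs. A minimal single-exit sketch of that idea; Inst, ExitBB and OutsideUse are hypothetical stand-ins:

  #include "llvm/IR/CFG.h"
  #include "llvm/IR/Instructions.h"
  #include "llvm/Transforms/Utils/SSAUpdater.h"
  using namespace llvm;

  // Sketch: give one in-loop definition a loop-closed PHI in a single exit
  // block, then rewrite one use that lives outside the loop through it.
  static void rewriteThroughExitPhi(Instruction &Inst, BasicBlock *ExitBB,
                                    Use &OutsideUse) {
    PHINode *PN = PHINode::Create(Inst.getType(), /*NumReservedValues=*/1,
                                  Inst.getName() + ".lcssa",
                                  ExitBB->getFirstNonPHI());
    // Every edge into the exit block carries the in-loop value.
    for (pred_iterator PI = pred_begin(ExitBB), E = pred_end(ExitBB);
         PI != E; ++PI)
      PN->addIncoming(&Inst, *PI);

    SSAUpdater SSAUpdate;
    SSAUpdate.Initialize(Inst.getType(), Inst.getName());
    SSAUpdate.AddAvailableValue(Inst.getParent(), &Inst);
    SSAUpdate.AddAvailableValue(ExitBB, PN);
    SSAUpdate.RewriteUse(OutsideUse);
  }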
 static bool processInstruction(Loop &L, Instruction &Inst, DominatorTree &DT,
                                const SmallVectorImpl<BasicBlock *> &ExitBlocks,
-                               PredIteratorCache &PredCache) {
+                               PredIteratorCache &PredCache, LoopInfo *LI) {
   SmallVector<Use *, 16> UsesToRewrite;
   BasicBlock *InstBB = Inst.getParent();
@@ -94,6 +94,7 @@ static bool processInstruction(Loop &L, Instruction &Inst, DominatorTree &DT,
   DomTreeNode *DomNode = DT.getNode(DomBB);
   SmallVector<PHINode *, 16> AddedPHIs;
+  SmallVector<PHINode *, 8> PostProcessPHIs;
   SSAUpdater SSAUpdate;
   SSAUpdate.Initialize(Inst.getType(), Inst.getName());
@@ -131,6 +132,18 @@ static bool processInstruction(Loop &L, Instruction &Inst, DominatorTree &DT,
     // Remember that this phi makes the value alive in this block.
     SSAUpdate.AddAvailableValue(ExitBB, PN);
+
+    // LoopSimplify might fail to simplify some loops (e.g. when indirect
+    // branches are involved). In such situations, it might happen that an exit
+    // for Loop L1 is the header of a disjoint Loop L2. Thus, when we create
+    // PHIs in such an exit block, we are also inserting PHIs into L2's header.
+    // This could break LCSSA form for L2 because these inserted PHIs can also
+    // have uses outside of L2. Remember all PHIs in such situations so we can
+    // revisit them later. FIXME: Remove this once LoopSimplify's indirectbr
+    // support gets improved.
+    if (auto *OtherLoop = LI->getLoopFor(ExitBB))
+      if (!L.contains(OtherLoop))
+        PostProcessPHIs.push_back(PN);
   }
   // Rewrite all uses outside the loop in terms of the new PHIs we just
@@ -157,6 +170,25 @@ static bool processInstruction(Loop &L, Instruction &Inst, DominatorTree &DT,
     SSAUpdate.RewriteUse(*UsesToRewrite[i]);
   }
+  // Post-process PHI instructions that were inserted into another disjoint
+  // loop and update their exits properly.
+  for (auto *I : PostProcessPHIs) {
+    if (I->use_empty())
+      continue;
+
+    BasicBlock *PHIBB = I->getParent();
+    Loop *OtherLoop = LI->getLoopFor(PHIBB);
+    SmallVector<BasicBlock *, 8> EBs;
+    OtherLoop->getExitBlocks(EBs);
+    if (EBs.empty())
+      continue;
+
+    // Recurse and re-process each PHI instruction. FIXME: we should really
+    // convert this entire thing to a worklist approach where we process a
+    // vector of instructions...
+    processInstruction(*OtherLoop, *I, DT, EBs, PredCache, LI);
+  }
+
   // Remove PHI nodes that did not have any uses rewritten.
   for (unsigned i = 0, e = AddedPHIs.size(); i != e; ++i) {
     if (AddedPHIs[i]->use_empty())
@@ -180,7 +212,8 @@ blockDominatesAnExit(BasicBlock *BB,
   return false;
 }
-bool llvm::formLCSSA(Loop &L, DominatorTree &DT, ScalarEvolution *SE) {
+bool llvm::formLCSSA(Loop &L, DominatorTree &DT, LoopInfo *LI,
+                     ScalarEvolution *SE) {
   bool Changed = false;
   // Get the set of exiting blocks.
@@ -212,7 +245,7 @@ bool llvm::formLCSSA(Loop &L, DominatorTree &DT, ScalarEvolution *SE) {
           !isa<PHINode>(I->user_back())))
         continue;
-      Changed |= processInstruction(L, *I, DT, ExitBlocks, PredCache);
+      Changed |= processInstruction(L, *I, DT, ExitBlocks, PredCache, LI);
     }
   }
@@ -228,15 +261,15 @@ bool llvm::formLCSSA(Loop &L, DominatorTree &DT, ScalarEvolution *SE) {
 }
 /// Process a loop nest depth first.
-bool llvm::formLCSSARecursively(Loop &L, DominatorTree &DT,
+bool llvm::formLCSSARecursively(Loop &L, DominatorTree &DT, LoopInfo *LI,
                                 ScalarEvolution *SE) {
   bool Changed = false;
   // Recurse depth-first through inner loops.
- for (Loop::iterator LI = L.begin(), LE = L.end(); LI != LE; ++LI) - Changed |= formLCSSARecursively(**LI, DT, SE); + for (Loop::iterator I = L.begin(), E = L.end(); I != E; ++I) + Changed |= formLCSSARecursively(**I, DT, LI, SE); - Changed |= formLCSSA(L, DT, SE); + Changed |= formLCSSA(L, DT, LI, SE); return Changed; } @@ -291,7 +324,7 @@ bool LCSSA::runOnFunction(Function &F) { // Simplify each loop nest in the function. for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) - Changed |= formLCSSARecursively(**I, *DT, SE); + Changed |= formLCSSARecursively(**I, *DT, LI, SE); return Changed; } diff --git a/lib/Transforms/Utils/Local.cpp b/lib/Transforms/Utils/Local.cpp index a5e443fcf46b..08a4b3f3b737 100644 --- a/lib/Transforms/Utils/Local.cpp +++ b/lib/Transforms/Utils/Local.cpp @@ -128,7 +128,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, // Check to see if this branch is going to the same place as the default // dest. If so, eliminate it as an explicit compare. if (i.getCaseSuccessor() == DefaultDest) { - MDNode* MD = SI->getMetadata(LLVMContext::MD_prof); + MDNode *MD = SI->getMetadata(LLVMContext::MD_prof); unsigned NCases = SI->getNumCases(); // Fold the case metadata into the default if there will be any branches // left, unless the metadata doesn't match the switch. @@ -137,7 +137,8 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, SmallVector<uint32_t, 8> Weights; for (unsigned MD_i = 1, MD_e = MD->getNumOperands(); MD_i < MD_e; ++MD_i) { - ConstantInt* CI = dyn_cast<ConstantInt>(MD->getOperand(MD_i)); + ConstantInt *CI = + mdconst::dyn_extract<ConstantInt>(MD->getOperand(MD_i)); assert(CI); Weights.push_back(CI->getValue().getZExtValue()); } @@ -206,10 +207,12 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, BranchInst *NewBr = Builder.CreateCondBr(Cond, FirstCase.getCaseSuccessor(), SI->getDefaultDest()); - MDNode* MD = SI->getMetadata(LLVMContext::MD_prof); + MDNode *MD = SI->getMetadata(LLVMContext::MD_prof); if (MD && MD->getNumOperands() == 3) { - ConstantInt *SICase = dyn_cast<ConstantInt>(MD->getOperand(2)); - ConstantInt *SIDef = dyn_cast<ConstantInt>(MD->getOperand(1)); + ConstantInt *SICase = + mdconst::dyn_extract<ConstantInt>(MD->getOperand(2)); + ConstantInt *SIDef = + mdconst::dyn_extract<ConstantInt>(MD->getOperand(1)); assert(SICase && SIDef); // The TrueWeight should be the weight for the single case of SI. NewBr->setMetadata(LLVMContext::MD_prof, @@ -301,6 +304,14 @@ bool llvm::isInstructionTriviallyDead(Instruction *I, if (II->getIntrinsicID() == Intrinsic::lifetime_start || II->getIntrinsicID() == Intrinsic::lifetime_end) return isa<UndefValue>(II->getArgOperand(1)); + + // Assumptions are dead if their condition is trivially true. + if (II->getIntrinsicID() == Intrinsic::assume) { + if (ConstantInt *Cond = dyn_cast<ConstantInt>(II->getArgOperand(0))) + return !Cond->isZero(); + + return false; + } } if (isAllocLikeFn(I, TLI)) return true; @@ -384,7 +395,7 @@ bool llvm::RecursivelyDeleteDeadPHINode(PHINode *PN, // If we find an instruction more than once, we're on a cycle that // won't prove fruitful. - if (!Visited.insert(I)) { + if (!Visited.insert(I).second) { // Break the cycle and delete the instruction and its operands. 
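The Visited.insert(I).second pattern in the next hunk recurs throughout this patch: SmallPtrSet::insert now returns std::pair<iterator, bool> rather than a bare bool, and .second reports whether the element was newly added. A self-contained sketch of the visited-set walk it enables (hypothetical function, mirroring the markAliveBlocks loop further down):

  #include "llvm/ADT/SmallPtrSet.h"
  #include "llvm/ADT/SmallVector.h"
  #include "llvm/IR/BasicBlock.h"
  #include "llvm/IR/CFG.h"
  using namespace llvm;

  // Sketch: a standard worklist walk over the CFG. insert(...).second is
  // true only the first time a block is seen, so each block is visited once.
  static void visitReachable(BasicBlock *Entry) {
    SmallPtrSet<BasicBlock *, 32> Visited;
    SmallVector<BasicBlock *, 32> Worklist;
    Worklist.push_back(Entry);
    Visited.insert(Entry);
    while (!Worklist.empty()) {
      BasicBlock *BB = Worklist.pop_back_val();
      for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE; ++SI)
        if (Visited.insert(*SI).second) // newly inserted => not yet visited
          Worklist.push_back(*SI);
    }
  }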
I->replaceAllUsesWith(UndefValue::get(I->getType())); (void)RecursivelyDeleteTriviallyDeadInstructions(I, TLI); @@ -931,13 +942,16 @@ static unsigned enforceKnownAlignment(Value *V, unsigned Align, /// and it is more than the alignment of the ultimate object, see if we can /// increase the alignment of the ultimate object, making this check succeed. unsigned llvm::getOrEnforceKnownAlignment(Value *V, unsigned PrefAlign, - const DataLayout *DL) { + const DataLayout *DL, + AssumptionCache *AC, + const Instruction *CxtI, + const DominatorTree *DT) { assert(V->getType()->isPointerTy() && "getOrEnforceKnownAlignment expects a pointer!"); unsigned BitWidth = DL ? DL->getPointerTypeSizeInBits(V->getType()) : 64; APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(V, KnownZero, KnownOne, DL); + computeKnownBits(V, KnownZero, KnownOne, DL, 0, AC, CxtI, DT); unsigned TrailZ = KnownZero.countTrailingOnes(); // Avoid trouble with ridiculously large TrailZ values, such as @@ -982,6 +996,7 @@ static bool LdStHasDebugValue(DIVariable &DIVar, Instruction *I) { bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, StoreInst *SI, DIBuilder &Builder) { DIVariable DIVar(DDI->getVariable()); + DIExpression DIExpr(DDI->getExpression()); assert((!DIVar || DIVar.isVariable()) && "Variable in DbgDeclareInst should be either null or a DIVariable."); if (!DIVar) @@ -999,9 +1014,10 @@ bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, if (SExtInst *SExt = dyn_cast<SExtInst>(SI->getOperand(0))) ExtendedArg = dyn_cast<Argument>(SExt->getOperand(0)); if (ExtendedArg) - DbgVal = Builder.insertDbgValueIntrinsic(ExtendedArg, 0, DIVar, SI); + DbgVal = Builder.insertDbgValueIntrinsic(ExtendedArg, 0, DIVar, DIExpr, SI); else - DbgVal = Builder.insertDbgValueIntrinsic(SI->getOperand(0), 0, DIVar, SI); + DbgVal = Builder.insertDbgValueIntrinsic(SI->getOperand(0), 0, DIVar, + DIExpr, SI); DbgVal->setDebugLoc(DDI->getDebugLoc()); return true; } @@ -1011,6 +1027,7 @@ bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, LoadInst *LI, DIBuilder &Builder) { DIVariable DIVar(DDI->getVariable()); + DIExpression DIExpr(DDI->getExpression()); assert((!DIVar || DIVar.isVariable()) && "Variable in DbgDeclareInst should be either null or a DIVariable."); if (!DIVar) @@ -1020,8 +1037,7 @@ bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, return true; Instruction *DbgVal = - Builder.insertDbgValueIntrinsic(LI->getOperand(0), 0, - DIVar, LI); + Builder.insertDbgValueIntrinsic(LI->getOperand(0), 0, DIVar, DIExpr, LI); DbgVal->setDebugLoc(DDI->getDebugLoc()); return true; } @@ -1035,7 +1051,7 @@ static bool isArray(AllocaInst *AI) { /// LowerDbgDeclare - Lowers llvm.dbg.declare intrinsics into appropriate set /// of llvm.dbg.value intrinsics. bool llvm::LowerDbgDeclare(Function &F) { - DIBuilder DIB(*F.getParent()); + DIBuilder DIB(*F.getParent(), /*AllowUnresolved*/ false); SmallVector<DbgDeclareInst *, 4> Dbgs; for (auto &FI : F) for (BasicBlock::iterator BI : FI) @@ -1061,14 +1077,14 @@ bool llvm::LowerDbgDeclare(Function &F) { else if (LoadInst *LI = dyn_cast<LoadInst>(U)) ConvertDebugDeclareToDebugValue(DDI, LI, DIB); else if (CallInst *CI = dyn_cast<CallInst>(U)) { - // This is a call by-value or some other instruction that - // takes a pointer to the variable. Insert a *value* - // intrinsic that describes the alloca. 
- auto DbgVal = - DIB.insertDbgValueIntrinsic(AI, 0, - DIVariable(DDI->getVariable()), CI); - DbgVal->setDebugLoc(DDI->getDebugLoc()); - } + // This is a call by-value or some other instruction that + // takes a pointer to the variable. Insert a *value* + // intrinsic that describes the alloca. + auto DbgVal = DIB.insertDbgValueIntrinsic( + AI, 0, DIVariable(DDI->getVariable()), + DIExpression(DDI->getExpression()), CI); + DbgVal->setDebugLoc(DDI->getDebugLoc()); + } DDI->eraseFromParent(); } } @@ -1078,10 +1094,11 @@ bool llvm::LowerDbgDeclare(Function &F) { /// FindAllocaDbgDeclare - Finds the llvm.dbg.declare intrinsic describing the /// alloca 'V', if any. DbgDeclareInst *llvm::FindAllocaDbgDeclare(Value *V) { - if (MDNode *DebugNode = MDNode::getIfExists(V->getContext(), V)) - for (User *U : DebugNode->users()) - if (DbgDeclareInst *DDI = dyn_cast<DbgDeclareInst>(U)) - return DDI; + if (auto *L = LocalAsMetadata::getIfExists(V)) + if (auto *MDV = MetadataAsValue::getIfExists(V->getContext(), L)) + for (User *U : MDV->users()) + if (DbgDeclareInst *DDI = dyn_cast<DbgDeclareInst>(U)) + return DDI; return nullptr; } @@ -1092,33 +1109,27 @@ bool llvm::replaceDbgDeclareForAlloca(AllocaInst *AI, Value *NewAllocaAddress, if (!DDI) return false; DIVariable DIVar(DDI->getVariable()); + DIExpression DIExpr(DDI->getExpression()); assert((!DIVar || DIVar.isVariable()) && "Variable in DbgDeclareInst should be either null or a DIVariable."); if (!DIVar) return false; - // Create a copy of the original DIDescriptor for user variable, appending + // Create a copy of the original DIDescriptor for user variable, prepending // "deref" operation to a list of address elements, as new llvm.dbg.declare // will take a value storing address of the memory for variable, not // alloca itself. - Type *Int64Ty = Type::getInt64Ty(AI->getContext()); - SmallVector<Value*, 4> NewDIVarAddress; - if (DIVar.hasComplexAddress()) { - for (unsigned i = 0, n = DIVar.getNumAddrElements(); i < n; ++i) { - NewDIVarAddress.push_back( - ConstantInt::get(Int64Ty, DIVar.getAddrElement(i))); - } - } - NewDIVarAddress.push_back(ConstantInt::get(Int64Ty, DIBuilder::OpDeref)); - DIVariable NewDIVar = Builder.createComplexVariable( - DIVar.getTag(), DIVar.getContext(), DIVar.getName(), - DIVar.getFile(), DIVar.getLineNumber(), DIVar.getType(), - NewDIVarAddress, DIVar.getArgNumber()); + SmallVector<int64_t, 4> NewDIExpr; + NewDIExpr.push_back(dwarf::DW_OP_deref); + if (DIExpr) + for (unsigned i = 0, n = DIExpr.getNumElements(); i < n; ++i) + NewDIExpr.push_back(DIExpr.getElement(i)); // Insert llvm.dbg.declare in the same basic block as the original alloca, // and remove old llvm.dbg.declare. BasicBlock *BB = AI->getParent(); - Builder.insertDeclare(NewAllocaAddress, NewDIVar, BB); + Builder.insertDeclare(NewAllocaAddress, DIVar, + Builder.createExpression(NewDIExpr), BB); DDI->eraseFromParent(); return true; } @@ -1170,7 +1181,7 @@ static void changeToCall(InvokeInst *II) { } static bool markAliveBlocks(BasicBlock *BB, - SmallPtrSet<BasicBlock*, 128> &Reachable) { + SmallPtrSetImpl<BasicBlock*> &Reachable) { SmallVector<BasicBlock*, 128> Worklist; Worklist.push_back(BB); @@ -1183,6 +1194,26 @@ static bool markAliveBlocks(BasicBlock *BB, // instructions into LLVM unreachable insts. The instruction combining pass // canonicalizes unreachable insts into stores to null or undef. for (BasicBlock::iterator BBI = BB->begin(), E = BB->end(); BBI != E;++BBI){ + // Assumptions that are known to be false are equivalent to unreachable. 
+ // Also, if the condition is undefined, then we make the choice most + // beneficial to the optimizer, and choose that to also be unreachable. + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(BBI)) + if (II->getIntrinsicID() == Intrinsic::assume) { + bool MakeUnreachable = false; + if (isa<UndefValue>(II->getArgOperand(0))) + MakeUnreachable = true; + else if (ConstantInt *Cond = + dyn_cast<ConstantInt>(II->getArgOperand(0))) + MakeUnreachable = Cond->isZero(); + + if (MakeUnreachable) { + // Don't insert a call to llvm.trap right before the unreachable. + changeToUnreachable(BBI, false); + Changed = true; + break; + } + } + if (CallInst *CI = dyn_cast<CallInst>(BBI)) { if (CI->doesNotReturn()) { // If we found a call to a no-return function, insert an unreachable @@ -1237,7 +1268,7 @@ static bool markAliveBlocks(BasicBlock *BB, Changed |= ConstantFoldTerminator(BB, true); for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE; ++SI) - if (Reachable.insert(*SI)) + if (Reachable.insert(*SI).second) Worklist.push_back(*SI); } while (!Worklist.empty()); return Changed; @@ -1277,3 +1308,43 @@ bool llvm::removeUnreachableBlocks(Function &F) { return true; } + +void llvm::combineMetadata(Instruction *K, const Instruction *J, ArrayRef<unsigned> KnownIDs) { + SmallVector<std::pair<unsigned, MDNode *>, 4> Metadata; + K->dropUnknownMetadata(KnownIDs); + K->getAllMetadataOtherThanDebugLoc(Metadata); + for (unsigned i = 0, n = Metadata.size(); i < n; ++i) { + unsigned Kind = Metadata[i].first; + MDNode *JMD = J->getMetadata(Kind); + MDNode *KMD = Metadata[i].second; + + switch (Kind) { + default: + K->setMetadata(Kind, nullptr); // Remove unknown metadata + break; + case LLVMContext::MD_dbg: + llvm_unreachable("getAllMetadataOtherThanDebugLoc returned a MD_dbg"); + case LLVMContext::MD_tbaa: + K->setMetadata(Kind, MDNode::getMostGenericTBAA(JMD, KMD)); + break; + case LLVMContext::MD_alias_scope: + case LLVMContext::MD_noalias: + K->setMetadata(Kind, MDNode::intersect(JMD, KMD)); + break; + case LLVMContext::MD_range: + K->setMetadata(Kind, MDNode::getMostGenericRange(JMD, KMD)); + break; + case LLVMContext::MD_fpmath: + K->setMetadata(Kind, MDNode::getMostGenericFPMath(JMD, KMD)); + break; + case LLVMContext::MD_invariant_load: + // Only set the !invariant.load if it is present in both instructions. + K->setMetadata(Kind, JMD); + break; + case LLVMContext::MD_nonnull: + // Only set the !nonnull if it is present in both instructions. 
+ K->setMetadata(Kind, JMD); + break; + } + } +} diff --git a/lib/Transforms/Utils/LoopSimplify.cpp b/lib/Transforms/Utils/LoopSimplify.cpp index ef422914b6b2..c832a4b36f50 100644 --- a/lib/Transforms/Utils/LoopSimplify.cpp +++ b/lib/Transforms/Utils/LoopSimplify.cpp @@ -44,6 +44,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/DependenceAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" @@ -173,8 +174,7 @@ static BasicBlock *rewriteLoopExitBlock(Loop *L, BasicBlock *Exit, Pass *PP) { if (Exit->isLandingPad()) { SmallVector<BasicBlock*, 2> NewBBs; - SplitLandingPadPredecessors(Exit, ArrayRef<BasicBlock*>(&LoopBlocks[0], - LoopBlocks.size()), + SplitLandingPadPredecessors(Exit, LoopBlocks, ".loopexit", ".nonloopexit", PP, NewBBs); NewExitBB = NewBBs[0]; @@ -209,11 +209,12 @@ static void addBlockAndPredsToSet(BasicBlock *InputBB, BasicBlock *StopBlock, /// \brief The first part of loop-nestification is to find a PHI node that tells /// us how to partition the loops. static PHINode *findPHIToPartitionLoops(Loop *L, AliasAnalysis *AA, - DominatorTree *DT) { + DominatorTree *DT, + AssumptionCache *AC) { for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ) { PHINode *PN = cast<PHINode>(I); ++I; - if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, DT)) { + if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, DT, AC)) { // This is a degenerate PHI already, don't modify it! PN->replaceAllUsesWith(V); if (AA) AA->deleteValue(PN); @@ -252,7 +253,8 @@ static PHINode *findPHIToPartitionLoops(Loop *L, AliasAnalysis *AA, /// static Loop *separateNestedLoop(Loop *L, BasicBlock *Preheader, AliasAnalysis *AA, DominatorTree *DT, - LoopInfo *LI, ScalarEvolution *SE, Pass *PP) { + LoopInfo *LI, ScalarEvolution *SE, Pass *PP, + AssumptionCache *AC) { // Don't try to separate loops without a preheader. if (!Preheader) return nullptr; @@ -261,7 +263,7 @@ static Loop *separateNestedLoop(Loop *L, BasicBlock *Preheader, assert(!L->getHeader()->isLandingPad() && "Can't insert backedge to landing pad"); - PHINode *PN = findPHIToPartitionLoops(L, AA, DT); + PHINode *PN = findPHIToPartitionLoops(L, AA, DT, AC); if (!PN) return nullptr; // No known way to partition. // Pull out all predecessors that have varying values in the loop. This @@ -474,8 +476,8 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader, /// explicit if they accepted the analysis directly and then updated it. static bool simplifyOneLoop(Loop *L, SmallVectorImpl<Loop *> &Worklist, AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, - ScalarEvolution *SE, Pass *PP, - const DataLayout *DL) { + ScalarEvolution *SE, Pass *PP, const DataLayout *DL, + AssumptionCache *AC) { bool Changed = false; ReprocessLoop: @@ -496,20 +498,19 @@ ReprocessLoop: } // Delete each unique out-of-loop (and thus dead) predecessor. - for (SmallPtrSet<BasicBlock*, 4>::iterator I = BadPreds.begin(), - E = BadPreds.end(); I != E; ++I) { + for (BasicBlock *P : BadPreds) { DEBUG(dbgs() << "LoopSimplify: Deleting edge from dead predecessor " - << (*I)->getName() << "\n"); + << P->getName() << "\n"); // Inform each successor of each dead pred. 
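The combineMetadata helper added above is intended for passes that merge two equivalent instructions into one. A hedged usage sketch, assuming loads K and J have already been proven equivalent by the caller and that the declaration lives in Local.h alongside this definition; the kind list mirrors the switch above:

  #include "llvm/IR/Instructions.h"
  #include "llvm/IR/LLVMContext.h"
  #include "llvm/Transforms/Utils/Local.h"
  using namespace llvm;

  // Sketch: keep load K, discard load J. Listed kinds are merged pairwise
  // (TBAA widened to the most generic tag, ranges widened to cover both,
  // nonnull / invariant.load kept only if present on both); unlisted kinds
  // are dropped so K stays conservatively correct on either path.
  static void mergeDuplicateLoads(LoadInst *K, LoadInst *J) {
    unsigned KnownIDs[] = {
        LLVMContext::MD_tbaa,        LLVMContext::MD_range,
        LLVMContext::MD_fpmath,      LLVMContext::MD_invariant_load,
        LLVMContext::MD_alias_scope, LLVMContext::MD_noalias,
        LLVMContext::MD_nonnull};
    combineMetadata(K, J, KnownIDs);
    J->replaceAllUsesWith(K);
    J->eraseFromParent();
  }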
- for (succ_iterator SI = succ_begin(*I), SE = succ_end(*I); SI != SE; ++SI) - (*SI)->removePredecessor(*I); + for (succ_iterator SI = succ_begin(P), SE = succ_end(P); SI != SE; ++SI) + (*SI)->removePredecessor(P); // Zap the dead pred's terminator and replace it with unreachable. - TerminatorInst *TI = (*I)->getTerminator(); + TerminatorInst *TI = P->getTerminator(); TI->replaceAllUsesWith(UndefValue::get(TI->getType())); - (*I)->getTerminator()->eraseFromParent(); - new UnreachableInst((*I)->getContext(), *I); + P->getTerminator()->eraseFromParent(); + new UnreachableInst(P->getContext(), P); Changed = true; } } @@ -582,7 +583,8 @@ ReprocessLoop: // this for loops with a giant number of backedges, just factor them into a // common backedge instead. if (L->getNumBackEdges() < 8) { - if (Loop *OuterL = separateNestedLoop(L, Preheader, AA, DT, LI, SE, PP)) { + if (Loop *OuterL = + separateNestedLoop(L, Preheader, AA, DT, LI, SE, PP, AC)) { ++NumNested; // Enqueue the outer loop as it should be processed next in our // depth-first nest walk. @@ -612,7 +614,7 @@ ReprocessLoop: PHINode *PN; for (BasicBlock::iterator I = L->getHeader()->begin(); (PN = dyn_cast<PHINode>(I++)); ) - if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, DT)) { + if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, DT, AC)) { if (AA) AA->deleteValue(PN); if (SE) SE->forgetValue(PN); PN->replaceAllUsesWith(V); @@ -712,7 +714,7 @@ ReprocessLoop: bool llvm::simplifyLoop(Loop *L, DominatorTree *DT, LoopInfo *LI, Pass *PP, AliasAnalysis *AA, ScalarEvolution *SE, - const DataLayout *DL) { + const DataLayout *DL, AssumptionCache *AC) { bool Changed = false; // Worklist maintains our depth-first queue of loops in this nest to process. @@ -730,7 +732,7 @@ bool llvm::simplifyLoop(Loop *L, DominatorTree *DT, LoopInfo *LI, Pass *PP, while (!Worklist.empty()) Changed |= simplifyOneLoop(Worklist.pop_back_val(), Worklist, AA, DT, LI, - SE, PP, DL); + SE, PP, DL, AC); return Changed; } @@ -749,10 +751,13 @@ namespace { LoopInfo *LI; ScalarEvolution *SE; const DataLayout *DL; + AssumptionCache *AC; bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + // We need loop information to identify the loops... AU.addRequired<DominatorTreeWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); @@ -773,11 +778,12 @@ namespace { char LoopSimplify::ID = 0; INITIALIZE_PASS_BEGIN(LoopSimplify, "loop-simplify", - "Canonicalize natural loops", true, false) + "Canonicalize natural loops", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfo) INITIALIZE_PASS_END(LoopSimplify, "loop-simplify", - "Canonicalize natural loops", true, false) + "Canonicalize natural loops", false, false) // Publicly exposed interface to pass... char &llvm::LoopSimplifyID = LoopSimplify::ID; @@ -794,10 +800,11 @@ bool LoopSimplify::runOnFunction(Function &F) { SE = getAnalysisIfAvailable<ScalarEvolution>(); DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); DL = DLP ? &DLP->getDataLayout() : nullptr; + AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); // Simplify each loop nest in the function. 
for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) - Changed |= simplifyLoop(*I, DT, LI, this, AA, SE, DL); + Changed |= simplifyLoop(*I, DT, LI, this, AA, SE, DL, AC); return Changed; } diff --git a/lib/Transforms/Utils/LoopUnroll.cpp b/lib/Transforms/Utils/LoopUnroll.cpp index ab1c25a75e26..57459206e528 100644 --- a/lib/Transforms/Utils/LoopUnroll.cpp +++ b/lib/Transforms/Utils/LoopUnroll.cpp @@ -19,6 +19,7 @@ #include "llvm/Transforms/Utils/UnrollLoop.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/LoopPass.h" @@ -111,7 +112,7 @@ FoldBlockIntoPredecessor(BasicBlock *BB, LoopInfo* LI, LPPassManager *LPM, if (LPM) { if (ScalarEvolution *SE = LPM->getAnalysisIfAvailable<ScalarEvolution>()) { if (Loop *L = LI->getLoopFor(BB)) { - if (ForgottenLoops.insert(L)) + if (ForgottenLoops.insert(L).second) SE->forgetLoop(L); } } @@ -153,8 +154,8 @@ FoldBlockIntoPredecessor(BasicBlock *BB, LoopInfo* LI, LPPassManager *LPM, /// This utility preserves LoopInfo. If DominatorTree or ScalarEvolution are /// available from the Pass it must also preserve those analyses. bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, - bool AllowRuntime, unsigned TripMultiple, - LoopInfo *LI, Pass *PP, LPPassManager *LPM) { + bool AllowRuntime, unsigned TripMultiple, LoopInfo *LI, + Pass *PP, LPPassManager *LPM, AssumptionCache *AC) { BasicBlock *Preheader = L->getLoopPreheader(); if (!Preheader) { DEBUG(dbgs() << " Can't unroll; loop preheader-insertion failed.\n"); @@ -222,11 +223,10 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, // Notify ScalarEvolution that the loop will be substantially changed, // if not outright eliminated. - if (PP) { - ScalarEvolution *SE = PP->getAnalysisIfAvailable<ScalarEvolution>(); - if (SE) - SE->forgetLoop(L); - } + ScalarEvolution *SE = + PP ? PP->getAnalysisIfAvailable<ScalarEvolution>() : nullptr; + if (SE) + SE->forgetLoop(L); // If we know the trip count, we know the multiple... unsigned BreakoutTrip = 0; @@ -300,15 +300,45 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, for (unsigned It = 1; It != Count; ++It) { std::vector<BasicBlock*> NewBlocks; + SmallDenseMap<const Loop *, Loop *, 4> NewLoops; + NewLoops[L] = L; for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) { ValueToValueMapTy VMap; BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It)); Header->getParent()->getBasicBlockList().push_back(New); - // Loop over all of the PHI nodes in the block, changing them to use the - // incoming values from the previous block. + // Tell LI about New. + if (*BB == Header) { + assert(LI->getLoopFor(*BB) == L && "Header should not be in a sub-loop"); + L->addBasicBlockToLoop(New, LI->getBase()); + } else { + // Figure out which loop New is in. + const Loop *OldLoop = LI->getLoopFor(*BB); + assert(OldLoop && "Should (at least) be in the loop being unrolled!"); + + Loop *&NewLoop = NewLoops[OldLoop]; + if (!NewLoop) { + // Found a new sub-loop. + assert(*BB == OldLoop->getHeader() && + "Header should be first in RPO"); + + Loop *NewLoopParent = NewLoops.lookup(OldLoop->getParentLoop()); + assert(NewLoopParent && + "Expected parent loop before sub-loop in RPO"); + NewLoop = new Loop; + NewLoopParent->addChildLoop(NewLoop); + + // Forget the old loop, since its inputs may have changed. 
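The NewLoops map in this hunk keeps LoopInfo consistent while sub-loop blocks are cloned: visiting blocks in RPO guarantees a sub-loop's header is cloned first, and a parent loop before any of its children, so a fresh Loop can be created lazily under an already-cloned parent. A compressed sketch of that lookup, assuming NewLoops was seeded with the {L, L} self-mapping as above:

  #include "llvm/ADT/DenseMap.h"
  #include "llvm/Analysis/LoopInfo.h"
  #include <cassert>
  using namespace llvm;

  // Sketch: lazily create the clone of OldLoop. RPO order means the
  // parent's clone already exists when a child loop is first seen.
  static Loop *
  getOrCreateClonedLoop(const Loop *OldLoop,
                        SmallDenseMap<const Loop *, Loop *, 4> &NewLoops) {
    Loop *&NewLoop = NewLoops[OldLoop];
    if (!NewLoop) {
      Loop *NewParent = NewLoops.lookup(OldLoop->getParentLoop());
      assert(NewParent && "parent loop must have been cloned already");
      NewLoop = new Loop();
      NewParent->addChildLoop(NewLoop);
    }
    return NewLoop;
  }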
+ if (SE) + SE->forgetLoop(OldLoop); + } + NewLoop->addBasicBlockToLoop(New, LI->getBase()); + } + if (*BB == Header) + // Loop over all of the PHI nodes in the block, changing them to use + // the incoming values from the previous block. for (unsigned i = 0, e = OrigPHINode.size(); i != e; ++i) { PHINode *NewPHI = cast<PHINode>(VMap[OrigPHINode[i]]); Value *InVal = NewPHI->getIncomingValueForBlock(LatchBlock); @@ -325,8 +355,6 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, VI != VE; ++VI) LastValueMap[VI->first] = VI->second; - L->addBasicBlockToLoop(New, LI->getBase()); - // Add phi entries for newly created values to all exit blocks. for (succ_iterator SI = succ_begin(*BB), SE = succ_end(*BB); SI != SE; ++SI) { @@ -442,6 +470,10 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, } } + // FIXME: We could register any cloned assumptions instead of clearing the + // whole function's cache. + AC->clear(); + DominatorTree *DT = nullptr; if (PP) { // FIXME: Reconstruct dom info, because it is not preserved properly. @@ -453,7 +485,6 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, } // Simplify any new induction variables in the partially unrolled loop. - ScalarEvolution *SE = PP->getAnalysisIfAvailable<ScalarEvolution>(); if (SE && !CompletelyUnroll) { SmallVector<WeakVH, 16> DeadInsts; simplifyLoopIVs(L, SE, LPM, DeadInsts); @@ -502,8 +533,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, if (OuterL) { DataLayoutPass *DLP = PP->getAnalysisIfAvailable<DataLayoutPass>(); const DataLayout *DL = DLP ? &DLP->getDataLayout() : nullptr; - ScalarEvolution *SE = PP->getAnalysisIfAvailable<ScalarEvolution>(); - simplifyLoop(OuterL, DT, LI, PP, /*AliasAnalysis*/ nullptr, SE, DL); + simplifyLoop(OuterL, DT, LI, PP, /*AliasAnalysis*/ nullptr, SE, DL, AC); // LCSSA must be performed on the outermost affected loop. 
The unrolled
     // loop's last loop latch is guaranteed to be in the outermost loop after
@@ -513,7 +543,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount,
       while (OuterL->getParentLoop() != LatchLoop)
         OuterL = OuterL->getParentLoop();
-      formLCSSARecursively(*OuterL, *DT, SE);
+      formLCSSARecursively(*OuterL, *DT, LI, SE);
     }
   }
diff --git a/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/lib/Transforms/Utils/LoopUnrollRuntime.cpp
index a96c46ad63e0..f12cd61d463a 100644
--- a/lib/Transforms/Utils/LoopUnrollRuntime.cpp
+++ b/lib/Transforms/Utils/LoopUnrollRuntime.cpp
@@ -28,6 +28,7 @@
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionExpander.h"
 #include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Metadata.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
@@ -57,7 +58,7 @@ STATISTIC(NumRuntimeUnrolled,
 static void ConnectProlog(Loop *L, Value *TripCount, unsigned Count,
                           BasicBlock *LastPrologBB, BasicBlock *PrologEnd,
                           BasicBlock *OrigPH, BasicBlock *NewPH,
-                          ValueToValueMapTy &LVMap, Pass *P) {
+                          ValueToValueMapTy &VMap, Pass *P) {
   BasicBlock *Latch = L->getLoopLatch();
   assert(Latch && "Loop must have a latch");
@@ -86,7 +87,7 @@ static void ConnectProlog(Loop *L, Value *TripCount, unsigned Count,
       Value *V = PN->getIncomingValueForBlock(Latch);
       if (Instruction *I = dyn_cast<Instruction>(V)) {
         if (L->contains(I)) {
-          V = LVMap[I];
+          V = VMap[I];
         }
       }
       // Adding a value to the new PHI node from the last prolog block
@@ -127,76 +128,123 @@ static void ConnectProlog(Loop *L, Value *TripCount, unsigned Count,
 }
 /// Create a clone of the blocks in a loop and connect them together.
-/// This function doesn't create a clone of the loop structure.
+/// If UnrollProlog is true, the loop structure is not cloned; otherwise a
+/// new loop is created that contains all of the cloned blocks, with an
+/// induction variable counting NewIter down to 0.
 ///
-/// There are two value maps that are defined and used.  VMap is
-/// for the values in the current loop instance.  LVMap contains
-/// the values from the last loop instance.  We need the LVMap values
-/// to update the initial values for the current loop instance.
-///
-static void CloneLoopBlocks(Loop *L,
-                            bool FirstCopy,
-                            BasicBlock *InsertTop,
-                            BasicBlock *InsertBot,
+static void CloneLoopBlocks(Loop *L, Value *NewIter, const bool UnrollProlog,
+                            BasicBlock *InsertTop, BasicBlock *InsertBot,
                             std::vector<BasicBlock *> &NewBlocks,
-                            LoopBlocksDFS &LoopBlocks,
-                            ValueToValueMapTy &VMap,
-                            ValueToValueMapTy &LVMap,
+                            LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap,
                             LoopInfo *LI) {
-
   BasicBlock *Preheader = L->getLoopPreheader();
   BasicBlock *Header = L->getHeader();
   BasicBlock *Latch = L->getLoopLatch();
   Function *F = Header->getParent();
   LoopBlocksDFS::RPOIterator BlockBegin = LoopBlocks.beginRPO();
   LoopBlocksDFS::RPOIterator BlockEnd = LoopBlocks.endRPO();
+  Loop *NewLoop = 0;
+  Loop *ParentLoop = L->getParentLoop();
+  if (!UnrollProlog) {
+    NewLoop = new Loop();
+    if (ParentLoop)
+      ParentLoop->addChildLoop(NewLoop);
+    else
+      LI->addTopLevelLoop(NewLoop);
+  }
+
   // For each block in the original loop, create a new copy,
   // and update the value map with the newly created values.
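The cloning loop that follows uses the standard two-phase pattern: CloneBasicBlock populates VMap with old-to-new value mappings while the blocks are copied, and only afterwards are the cloned instructions' operands remapped. A generic sketch of that pattern, with cloneRegion a hypothetical name:

  #include "llvm/ADT/ArrayRef.h"
  #include "llvm/IR/BasicBlock.h"
  #include "llvm/IR/Function.h"
  #include "llvm/Transforms/Utils/Cloning.h"
  #include "llvm/Transforms/Utils/ValueMapper.h"
  using namespace llvm;

  // Sketch: clone a list of blocks, then fix up cross-references. Operands
  // pointing at values that were never cloned (e.g. values defined outside
  // the region) are left untouched via RF_IgnoreMissingEntries.
  static void cloneRegion(ArrayRef<BasicBlock *> Blocks, Function *F) {
    ValueToValueMapTy VMap;
    std::vector<BasicBlock *> NewBlocks;
    for (BasicBlock *BB : Blocks) {
      BasicBlock *New = CloneBasicBlock(BB, VMap, ".copy", F);
      VMap[BB] = New; // let branches to BB be retargeted to New
      NewBlocks.push_back(New);
    }
    for (BasicBlock *NewBB : NewBlocks)
      for (Instruction &I : *NewBB)
        RemapInstruction(&I, VMap,
                         RF_NoModuleLevelChanges | RF_IgnoreMissingEntries);
  }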
for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) { - BasicBlock *NewBB = CloneBasicBlock(*BB, VMap, ".unr", F); + BasicBlock *NewBB = CloneBasicBlock(*BB, VMap, ".prol", F); NewBlocks.push_back(NewBB); - if (Loop *ParentLoop = L->getParentLoop()) + if (NewLoop) + NewLoop->addBasicBlockToLoop(NewBB, LI->getBase()); + else if (ParentLoop) ParentLoop->addBasicBlockToLoop(NewBB, LI->getBase()); VMap[*BB] = NewBB; if (Header == *BB) { // For the first block, add a CFG connection to this newly - // created block + // created block. InsertTop->getTerminator()->setSuccessor(0, NewBB); - // Change the incoming values to the ones defined in the - // previously cloned loop. - for (BasicBlock::iterator I = Header->begin(); isa<PHINode>(I); ++I) { - PHINode *NewPHI = cast<PHINode>(VMap[I]); - if (FirstCopy) { - // We replace the first phi node with the value from the preheader - VMap[I] = NewPHI->getIncomingValueForBlock(Preheader); - NewBB->getInstList().erase(NewPHI); - } else { - // Update VMap with values from the previous block - unsigned idx = NewPHI->getBasicBlockIndex(Latch); - Value *InVal = NewPHI->getIncomingValue(idx); - if (Instruction *I = dyn_cast<Instruction>(InVal)) - if (L->contains(I)) - InVal = LVMap[InVal]; - NewPHI->setIncomingValue(idx, InVal); - NewPHI->setIncomingBlock(idx, InsertTop); - } - } } - if (Latch == *BB) { + // For the last block, if UnrollProlog is true, create a direct jump to + // InsertBot. If not, create a loop back to cloned head. VMap.erase((*BB)->getTerminator()); - NewBB->getTerminator()->eraseFromParent(); - BranchInst::Create(InsertBot, NewBB); + BasicBlock *FirstLoopBB = cast<BasicBlock>(VMap[Header]); + BranchInst *LatchBR = cast<BranchInst>(NewBB->getTerminator()); + if (UnrollProlog) { + LatchBR->eraseFromParent(); + BranchInst::Create(InsertBot, NewBB); + } else { + PHINode *NewIdx = PHINode::Create(NewIter->getType(), 2, "prol.iter", + FirstLoopBB->getFirstNonPHI()); + IRBuilder<> Builder(LatchBR); + Value *IdxSub = + Builder.CreateSub(NewIdx, ConstantInt::get(NewIdx->getType(), 1), + NewIdx->getName() + ".sub"); + Value *IdxCmp = + Builder.CreateIsNotNull(IdxSub, NewIdx->getName() + ".cmp"); + BranchInst::Create(FirstLoopBB, InsertBot, IdxCmp, NewBB); + NewIdx->addIncoming(NewIter, InsertTop); + NewIdx->addIncoming(IdxSub, NewBB); + LatchBR->eraseFromParent(); + } } } - // LastValueMap is updated with the values for the current loop - // which are used the next time this function is called. - for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end(); - VI != VE; ++VI) { - LVMap[VI->first] = VI->second; + + // Change the incoming values to the ones defined in the preheader or + // cloned loop. + for (BasicBlock::iterator I = Header->begin(); isa<PHINode>(I); ++I) { + PHINode *NewPHI = cast<PHINode>(VMap[I]); + if (UnrollProlog) { + VMap[I] = NewPHI->getIncomingValueForBlock(Preheader); + cast<BasicBlock>(VMap[Header])->getInstList().erase(NewPHI); + } else { + unsigned idx = NewPHI->getBasicBlockIndex(Preheader); + NewPHI->setIncomingBlock(idx, InsertTop); + BasicBlock *NewLatch = cast<BasicBlock>(VMap[Latch]); + idx = NewPHI->getBasicBlockIndex(Latch); + Value *InVal = NewPHI->getIncomingValue(idx); + NewPHI->setIncomingBlock(idx, NewLatch); + if (VMap[InVal]) + NewPHI->setIncomingValue(idx, VMap[InVal]); + } + } + if (NewLoop) { + // Add unroll disable metadata to disable future unrolling for this loop. + SmallVector<Metadata *, 4> MDs; + // Reserve first location for self reference to the LoopID metadata node. 
+ MDs.push_back(nullptr); + MDNode *LoopID = NewLoop->getLoopID(); + if (LoopID) { + // First remove any existing loop unrolling metadata. + for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { + bool IsUnrollMetadata = false; + MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i)); + if (MD) { + const MDString *S = dyn_cast<MDString>(MD->getOperand(0)); + IsUnrollMetadata = S && S->getString().startswith("llvm.loop.unroll."); + } + if (!IsUnrollMetadata) + MDs.push_back(LoopID->getOperand(i)); + } + } + + LLVMContext &Context = NewLoop->getHeader()->getContext(); + SmallVector<Metadata *, 1> DisableOperands; + DisableOperands.push_back(MDString::get(Context, "llvm.loop.unroll.disable")); + MDNode *DisableNode = MDNode::get(Context, DisableOperands); + MDs.push_back(DisableNode); + + MDNode *NewLoopID = MDNode::get(Context, MDs); + // Set operand 0 to refer to the loop id itself. + NewLoopID->replaceOperandWith(0, NewLoopID); + NewLoop->setLoopID(NewLoopID); } } @@ -212,18 +260,16 @@ static void CloneLoopBlocks(Loop *L, /// instruction in SimplifyCFG.cpp. Then, the backend decides how code for /// the switch instruction is generated. /// -/// extraiters = tripcount % loopfactor -/// if (extraiters == 0) jump Loop: -/// if (extraiters == loopfactor) jump L1 -/// if (extraiters == loopfactor-1) jump L2 -/// ... -/// L1: LoopBody; -/// L2: LoopBody; -/// ... -/// if tripcount < loopfactor jump End -/// Loop: -/// ... -/// End: +/// extraiters = tripcount % loopfactor +/// if (extraiters == 0) jump Loop: +/// else jump Prol +/// Prol: LoopBody; +/// extraiters -= 1 // Omitted if unroll factor is 2. +/// if (extraiters != 0) jump Prol: // Omitted if unroll factor is 2. +/// if (tripcount < loopfactor) jump End +/// Loop: +/// ... +/// End: /// bool llvm::UnrollRuntimeLoopProlog(Loop *L, unsigned Count, LoopInfo *LI, LPPassManager *LPM) { @@ -250,6 +296,10 @@ bool llvm::UnrollRuntimeLoopProlog(Loop *L, unsigned Count, LoopInfo *LI, if (isa<SCEVCouldNotCompute>(BECount) || !BECount->getType()->isIntegerTy()) return false; + // If BECount is INT_MAX, we can't compute trip-count without overflow. + if (BECount->isAllOnesValue()) + return false; + // Add 1 since the backedge count doesn't include the first loop iteration const SCEV *TripCountSC = SE->getAddExpr(BECount, SE->getConstant(BECount->getType(), 1)); @@ -284,26 +334,21 @@ bool llvm::UnrollRuntimeLoopProlog(Loop *L, unsigned Count, LoopInfo *LI, IRBuilder<> B(PreHeaderBR); Value *ModVal = B.CreateAnd(TripCount, Count - 1, "xtraiter"); - // Check if for no extra iterations, then jump to unrolled loop. We have to - // check that the trip count computation didn't overflow when adding one to - // the backedge taken count. + // Check if for no extra iterations, then jump to cloned/unrolled loop. + // We have to check that the trip count computation didn't overflow when + // adding one to the backedge taken count. 
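The two compares emitted below encode simple modular arithmetic. A worked sketch in plain C++ of what the generated IR computes, assuming a power-of-two unroll factor as the caller guarantees:

  #include <cstdint>

  // Sketch of the prolog trip-count arithmetic. TripCount = BECount + 1
  // wraps to 0 when BECount is the all-ones value, which is exactly the
  // case the generated 'lcmp.overflow' compare catches (and the reason
  // for the early isAllOnesValue() bail-out above).
  static bool needsPrologOrOverflowed(uint32_t BECount, uint32_t Count) {
    uint32_t TripCount = BECount + 1;            // wraps to 0 on overflow
    uint32_t XtraIter = TripCount & (Count - 1); // TripCount % Count
    return XtraIter != 0 || TripCount == 0;
  }
  // e.g. BECount = 6, Count = 4: TripCount = 7, XtraIter = 3, so the prolog
  // runs 3 iterations and the unrolled body then executes the remaining 4.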
Value *LCmp = B.CreateIsNotNull(ModVal, "lcmp.mod"); Value *OverflowCheck = B.CreateIsNull(TripCount, "lcmp.overflow"); Value *BranchVal = B.CreateOr(OverflowCheck, LCmp, "lcmp.or"); - // Branch to either the extra iterations or the unrolled loop + // Branch to either the extra iterations or the cloned/unrolled loop // We will fix up the true branch label when adding loop body copies BranchInst::Create(PEnd, PEnd, BranchVal, PreHeaderBR); assert(PreHeaderBR->isUnconditional() && PreHeaderBR->getSuccessor(0) == PEnd && "CFG edges in Preheader are not correct"); PreHeaderBR->eraseFromParent(); - - ValueToValueMapTy LVMap; Function *F = Header->getParent(); - // These variables are used to update the CFG links in each iteration - BasicBlock *CompareBB = nullptr; - BasicBlock *LastLoopBB = PH; // Get an ordered list of blocks in the loop to help with the ordering of the // cloned blocks in the prolog code LoopBlocksDFS LoopBlocks(L); @@ -314,62 +359,39 @@ bool llvm::UnrollRuntimeLoopProlog(Loop *L, unsigned Count, LoopInfo *LI, // and generate a condition that branches to the copy depending on the // number of 'left over' iterations. // - for (unsigned leftOverIters = Count-1; leftOverIters > 0; --leftOverIters) { - std::vector<BasicBlock*> NewBlocks; - ValueToValueMapTy VMap; - - // Clone all the basic blocks in the loop, but we don't clone the loop - // This function adds the appropriate CFG connections. - CloneLoopBlocks(L, (leftOverIters == Count-1), LastLoopBB, PEnd, NewBlocks, - LoopBlocks, VMap, LVMap, LI); - LastLoopBB = cast<BasicBlock>(VMap[Latch]); - - // Insert the cloned blocks into function just before the original loop - F->getBasicBlockList().splice(PEnd, F->getBasicBlockList(), - NewBlocks[0], F->end()); - - // Generate the code for the comparison which determines if the loop - // prolog code needs to be executed. - if (leftOverIters == Count-1) { - // There is no compare block for the fall-thru case when for the last - // left over iteration - CompareBB = NewBlocks[0]; - } else { - // Create a new block for the comparison - BasicBlock *NewBB = BasicBlock::Create(CompareBB->getContext(), "unr.cmp", - F, CompareBB); - if (Loop *ParentLoop = L->getParentLoop()) { - // Add the new block to the parent loop, if needed - ParentLoop->addBasicBlockToLoop(NewBB, LI->getBase()); - } - - // The comparison w/ the extra iteration value and branch - Type *CountTy = TripCount->getType(); - Value *BranchVal = new ICmpInst(*NewBB, ICmpInst::ICMP_EQ, ModVal, - ConstantInt::get(CountTy, leftOverIters), - "un.tmp"); - // Branch to either the extra iterations or the unrolled loop - BranchInst::Create(NewBlocks[0], CompareBB, - BranchVal, NewBB); - CompareBB = NewBB; - PH->getTerminator()->setSuccessor(0, NewBB); - VMap[NewPH] = CompareBB; - } - - // Rewrite the cloned instruction operands to use the values - // created when the clone is created. - for (unsigned i = 0, e = NewBlocks.size(); i != e; ++i) { - for (BasicBlock::iterator I = NewBlocks[i]->begin(), - E = NewBlocks[i]->end(); I != E; ++I) { - RemapInstruction(I, VMap, - RF_NoModuleLevelChanges|RF_IgnoreMissingEntries); - } + std::vector<BasicBlock *> NewBlocks; + ValueToValueMapTy VMap; + + // If unroll count is 2 and we can't overflow in tripcount computation (which + // is BECount + 1), then we don't need a loop for prologue, and we can unroll + // it. We can be sure that we don't overflow only if tripcount is a constant. + bool UnrollPrologue = (Count == 2 && isa<ConstantInt>(TripCount)); + + // Clone all the basic blocks in the loop. 
If Count is 2, we don't clone
+  // the loop, otherwise we create a cloned loop to execute the extra
+  // iterations. This function adds the appropriate CFG connections.
+  CloneLoopBlocks(L, ModVal, UnrollPrologue, PH, PEnd, NewBlocks, LoopBlocks,
+                  VMap, LI);
+
+  // Insert the cloned blocks into the function just before the original loop.
+  F->getBasicBlockList().splice(PEnd, F->getBasicBlockList(), NewBlocks[0],
+                                F->end());
+
+  // Rewrite the cloned instruction operands to use the values
+  // created when the blocks were cloned.
+  for (unsigned i = 0, e = NewBlocks.size(); i != e; ++i) {
+    for (BasicBlock::iterator I = NewBlocks[i]->begin(),
+                              E = NewBlocks[i]->end();
+         I != E; ++I) {
+      RemapInstruction(I, VMap,
+                       RF_NoModuleLevelChanges | RF_IgnoreMissingEntries);
     }
   }
   // Connect the prolog code to the original loop and update the
   // PHI functions.
-  ConnectProlog(L, TripCount, Count, LastLoopBB, PEnd, PH, NewPH, LVMap,
+  BasicBlock *LastLoopBB = cast<BasicBlock>(VMap[Latch]);
+  ConnectProlog(L, TripCount, Count, LastLoopBB, PEnd, PH, NewPH, VMap,
                 LPM->getAsPass());
   NumRuntimeUnrolled++;
   return true;
diff --git a/lib/Transforms/Utils/LowerSwitch.cpp b/lib/Transforms/Utils/LowerSwitch.cpp
index d6e5bb626805..35cd917330ab 100644
--- a/lib/Transforms/Utils/LowerSwitch.cpp
+++ b/lib/Transforms/Utils/LowerSwitch.cpp
@@ -131,18 +131,39 @@ static raw_ostream& operator<<(raw_ostream &O,
   return O << "]";
 }
-static void fixPhis(BasicBlock *Succ,
-                    BasicBlock *OrigBlock,
-                    BasicBlock *NewNode) {
-  for (BasicBlock::iterator I = Succ->begin(),
-                            E = Succ->getFirstNonPHI();
-       I != E; ++I) {
+// \brief Update the first occurrence of the "switch statement" BB in the PHI
+// node with the "new" BB. The other occurrences will:
+//
+// 1) Be updated by subsequent calls to this function. Switch statements may
+// have more than one outgoing edge into the same BB if they all have the same
+// value. When the switch statement is converted these incoming edges are now
+// coming from multiple BBs.
+// 2) Be removed if subsequent incoming values now share the same case, i.e.,
+// multiple outgoing edges are condensed into one. This is necessary to keep
+// the number of phi values equal to the number of branches to SuccBB.
+static void fixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB,
+                    unsigned NumMergedCases) {
+  for (BasicBlock::iterator I = SuccBB->begin(), IE = SuccBB->getFirstNonPHI();
+       I != IE; ++I) {
     PHINode *PN = cast<PHINode>(I);
-    for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) {
-      if (PN->getIncomingBlock(I) == OrigBlock)
-        PN->setIncomingBlock(I, NewNode);
+    // Only update the first occurrence.
+    unsigned Idx = 0, E = PN->getNumIncomingValues();
+    unsigned LocalNumMergedCases = NumMergedCases;
+    for (; Idx != E; ++Idx) {
+      if (PN->getIncomingBlock(Idx) == OrigBB) {
+        PN->setIncomingBlock(Idx, NewBB);
+        break;
+      }
     }
+
+    // Remove additional occurrences coming from condensed cases and keep the
+    // number of incoming values equal to the number of branches to SuccBB.
+    for (++Idx; LocalNumMergedCases > 0 && Idx < E; ++Idx)
+      if (PN->getIncomingBlock(Idx) == OrigBB) {
+        PN->removeIncomingValue(Idx);
+        LocalNumMergedCases--;
+      }
   }
 }
@@ -165,7 +186,11 @@ BasicBlock *LowerSwitch::switchConvert(CaseItr Begin, CaseItr End,
   // emitting the code that checks if the value actually falls in the range
   // because the bounds already tell us so.
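To make NumMergedCases concrete: if cases 3, 4 and 5 of a switch all branch to the same successor, that successor's PHIs carry three incoming entries for the switch block; once those cases are condensed into a single range check there is only one edge left, so the first entry is retargeted and the other two must go. A standalone sketch of that fix-up, equivalent in spirit to fixPhis above (hypothetical helper name):

  #include "llvm/IR/Instructions.h"
  using namespace llvm;

  // Sketch: retarget the first incoming edge from OrigBB to NewBB, then
  // drop up to NumMerged leftover entries that still name OrigBB.
  static void retargetFirstAndDropMerged(PHINode *PN, BasicBlock *OrigBB,
                                         BasicBlock *NewBB,
                                         unsigned NumMerged) {
    int Idx = PN->getBasicBlockIndex(OrigBB);
    if (Idx < 0)
      return;
    PN->setIncomingBlock(Idx, NewBB);
    while (NumMerged-- > 0) {
      int Dup = PN->getBasicBlockIndex(OrigBB); // next leftover entry, if any
      if (Dup < 0)
        break;
      PN->removeIncomingValue(unsigned(Dup));
    }
  }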
if (Begin->Low == LowerBound && Begin->High == UpperBound) { - fixPhis(Begin->BB, OrigBlock, Predecessor); + unsigned NumMergedCases = 0; + if (LowerBound && UpperBound) + NumMergedCases = + UpperBound->getSExtValue() - LowerBound->getSExtValue(); + fixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases); return Begin->BB; } return newLeafBlock(*Begin, Val, OrigBlock, Default); diff --git a/lib/Transforms/Utils/Mem2Reg.cpp b/lib/Transforms/Utils/Mem2Reg.cpp index 189caa7d145a..00cf4e6c01c8 100644 --- a/lib/Transforms/Utils/Mem2Reg.cpp +++ b/lib/Transforms/Utils/Mem2Reg.cpp @@ -14,6 +14,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" @@ -38,6 +39,7 @@ namespace { bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.setPreservesCFG(); // This is a cluster of orthogonal Transforms @@ -51,6 +53,7 @@ namespace { char PromotePass::ID = 0; INITIALIZE_PASS_BEGIN(PromotePass, "mem2reg", "Promote Memory to Register", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_END(PromotePass, "mem2reg", "Promote Memory to Register", false, false) @@ -63,6 +66,8 @@ bool PromotePass::runOnFunction(Function &F) { bool Changed = false; DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + AssumptionCache &AC = + getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); while (1) { Allocas.clear(); @@ -76,7 +81,7 @@ bool PromotePass::runOnFunction(Function &F) { if (Allocas.empty()) break; - PromoteMemToReg(Allocas, DT); + PromoteMemToReg(Allocas, DT, nullptr, &AC); NumPromoted += Allocas.size(); Changed = true; } diff --git a/lib/Transforms/Utils/ModuleUtils.cpp b/lib/Transforms/Utils/ModuleUtils.cpp index d9dbbca1c366..35c701eeedc9 100644 --- a/lib/Transforms/Utils/ModuleUtils.cpp +++ b/lib/Transforms/Utils/ModuleUtils.cpp @@ -78,7 +78,7 @@ void llvm::appendToGlobalDtors(Module &M, Function *F, int Priority) { } GlobalVariable * -llvm::collectUsedGlobalVariables(Module &M, SmallPtrSet<GlobalValue *, 8> &Set, +llvm::collectUsedGlobalVariables(Module &M, SmallPtrSetImpl<GlobalValue *> &Set, bool CompilerUsed) { const char *Name = CompilerUsed ? "llvm.compiler.used" : "llvm.used"; GlobalVariable *GV = M.getGlobalVariable(Name); diff --git a/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/lib/Transforms/Utils/PromoteMemoryToRegister.cpp index 06d73feb1cc8..dabadb794d4e 100644 --- a/lib/Transforms/Utils/PromoteMemoryToRegister.cpp +++ b/lib/Transforms/Utils/PromoteMemoryToRegister.cpp @@ -238,6 +238,9 @@ struct PromoteMem2Reg { /// An AliasSetTracker object to update. If null, don't update it. AliasSetTracker *AST; + /// A cache of @llvm.assume intrinsics used by SimplifyInstruction. + AssumptionCache *AC; + /// Reverse mapping of Allocas. 
DenseMap<AllocaInst *, unsigned> AllocaLookup; @@ -279,9 +282,10 @@ struct PromoteMem2Reg { public: PromoteMem2Reg(ArrayRef<AllocaInst *> Allocas, DominatorTree &DT, - AliasSetTracker *AST) + AliasSetTracker *AST, AssumptionCache *AC) : Allocas(Allocas.begin(), Allocas.end()), DT(DT), - DIB(*DT.getRoot()->getParent()->getParent()), AST(AST) {} + DIB(*DT.getRoot()->getParent()->getParent(), /*AllowUnresolved*/ false), + AST(AST), AC(AC) {} void run(); @@ -302,8 +306,8 @@ private: void DetermineInsertionPoint(AllocaInst *AI, unsigned AllocaNum, AllocaInfo &Info); void ComputeLiveInBlocks(AllocaInst *AI, AllocaInfo &Info, - const SmallPtrSet<BasicBlock *, 32> &DefBlocks, - SmallPtrSet<BasicBlock *, 32> &LiveInBlocks); + const SmallPtrSetImpl<BasicBlock *> &DefBlocks, + SmallPtrSetImpl<BasicBlock *> &LiveInBlocks); void RenamePass(BasicBlock *BB, BasicBlock *Pred, RenamePassData::ValVector &IncVals, std::vector<RenamePassData> &Worklist); @@ -412,7 +416,8 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, // Record debuginfo for the store and remove the declaration's // debuginfo. if (DbgDeclareInst *DDI = Info.DbgDeclare) { - DIBuilder DIB(*AI->getParent()->getParent()->getParent()); + DIBuilder DIB(*AI->getParent()->getParent()->getParent(), + /*AllowUnresolved*/ false); ConvertDebugDeclareToDebugValue(DDI, Info.OnlyStore, DIB); DDI->eraseFromParent(); LBI.deleteValue(DDI); @@ -495,7 +500,8 @@ static void promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, StoreInst *SI = cast<StoreInst>(AI->user_back()); // Record debuginfo for the store before removing it. if (DbgDeclareInst *DDI = Info.DbgDeclare) { - DIBuilder DIB(*AI->getParent()->getParent()->getParent()); + DIBuilder DIB(*AI->getParent()->getParent()->getParent(), + /*AllowUnresolved*/ false); ConvertDebugDeclareToDebugValue(DDI, SI, DIB); } SI->eraseFromParent(); @@ -685,7 +691,7 @@ void PromoteMem2Reg::run() { PHINode *PN = I->second; // If this PHI node merges one value and/or undefs, get the value. - if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, &DT)) { + if (Value *V = SimplifyInstruction(PN, nullptr, nullptr, &DT, AC)) { if (AST && PN->getType()->isPointerTy()) AST->deleteValue(PN); PN->replaceAllUsesWith(V); @@ -766,8 +772,8 @@ void PromoteMem2Reg::run() { /// inserted phi nodes would be dead). void PromoteMem2Reg::ComputeLiveInBlocks( AllocaInst *AI, AllocaInfo &Info, - const SmallPtrSet<BasicBlock *, 32> &DefBlocks, - SmallPtrSet<BasicBlock *, 32> &LiveInBlocks) { + const SmallPtrSetImpl<BasicBlock *> &DefBlocks, + SmallPtrSetImpl<BasicBlock *> &LiveInBlocks) { // To determine liveness, we must iterate through the predecessors of blocks // where the def is live. Blocks are added to the worklist if we need to @@ -816,7 +822,7 @@ void PromoteMem2Reg::ComputeLiveInBlocks( // The block really is live in here, insert it into the set. If already in // the set, then it has already been processed. 
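ComputeLiveInBlocks is one of several interfaces in this patch (collectUsedGlobalVariables and markAliveBlocks are others) that move from SmallPtrSet<T, N> & to SmallPtrSetImpl<T> &, dropping the inline-capacity N from the signature. A minimal sketch of why that helps, with hypothetical function names:

  #include "llvm/ADT/SmallPtrSet.h"
  #include "llvm/IR/BasicBlock.h"
  using namespace llvm;

  // Sketch: taking SmallPtrSetImpl<T>& decouples the callee from the
  // caller's choice of inline capacity, so one definition serves every
  // SmallPtrSet instantiation.
  static void collectInteresting(SmallPtrSetImpl<BasicBlock *> &Out,
                                 BasicBlock *BB) {
    Out.insert(BB);
  }

  void exampleCaller(BasicBlock *A, BasicBlock *B) {
    SmallPtrSet<BasicBlock *, 8> Small;   // both sizes bind to the same
    SmallPtrSet<BasicBlock *, 128> Large; // SmallPtrSetImpl<BasicBlock *> &
    collectInteresting(Small, A);
    collectInteresting(Large, B);
  }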
- if (!LiveInBlocks.insert(BB)) + if (!LiveInBlocks.insert(BB).second) continue; // Since the value is live into BB, it is either defined in a predecessor or @@ -857,10 +863,8 @@ void PromoteMem2Reg::DetermineInsertionPoint(AllocaInst *AI, unsigned AllocaNum, less_second> IDFPriorityQueue; IDFPriorityQueue PQ; - for (SmallPtrSet<BasicBlock *, 32>::const_iterator I = DefBlocks.begin(), - E = DefBlocks.end(); - I != E; ++I) { - if (DomTreeNode *Node = DT.getNode(*I)) + for (BasicBlock *BB : DefBlocks) { + if (DomTreeNode *Node = DT.getNode(BB)) PQ.push(std::make_pair(Node, DomLevels[Node])); } @@ -898,7 +902,7 @@ void PromoteMem2Reg::DetermineInsertionPoint(AllocaInst *AI, unsigned AllocaNum, if (SuccLevel > RootLevel) continue; - if (!Visited.insert(SuccNode)) + if (!Visited.insert(SuccNode).second) continue; BasicBlock *SuccBB = SuccNode->getBlock(); @@ -1003,7 +1007,7 @@ NextIteration: } // Don't revisit blocks. - if (!Visited.insert(BB)) + if (!Visited.insert(BB).second) return; for (BasicBlock::iterator II = BB->begin(); !isa<TerminatorInst>(II);) { @@ -1060,17 +1064,17 @@ NextIteration: ++I; for (; I != E; ++I) - if (VisitedSuccs.insert(*I)) + if (VisitedSuccs.insert(*I).second) Worklist.push_back(RenamePassData(*I, Pred, IncomingVals)); goto NextIteration; } void llvm::PromoteMemToReg(ArrayRef<AllocaInst *> Allocas, DominatorTree &DT, - AliasSetTracker *AST) { + AliasSetTracker *AST, AssumptionCache *AC) { // If there is nothing to do, bail out... if (Allocas.empty()) return; - PromoteMem2Reg(Allocas, DT, AST).run(); + PromoteMem2Reg(Allocas, DT, AST, AC).run(); } diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp index 24bb63bb60a5..f6867c24ee58 100644 --- a/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/lib/Transforms/Utils/SimplifyCFG.cpp @@ -43,6 +43,8 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> #include <map> #include <set> @@ -68,12 +70,24 @@ static cl::opt<bool> HoistCondStores( cl::desc("Hoist conditional stores if an unconditional store precedes")); STATISTIC(NumBitMaps, "Number of switch instructions turned into bitmaps"); +STATISTIC(NumLinearMaps, "Number of switch instructions turned into linear mapping"); STATISTIC(NumLookupTables, "Number of switch instructions turned into lookup tables"); STATISTIC(NumLookupTablesHoles, "Number of switch instructions turned into lookup tables (holes checked)"); +STATISTIC(NumTableCmpReuses, "Number of reused switch table lookup compares"); STATISTIC(NumSinkCommons, "Number of common instructions sunk down to the end block"); STATISTIC(NumSpeculations, "Number of speculative executed instructions"); namespace { + // The first field contains the value that the switch produces when a certain + // case group is selected, and the second field is a vector containing the cases + // composing the case group. + typedef SmallVector<std::pair<Constant *, SmallVector<ConstantInt *, 4>>, 2> + SwitchCaseResultVectorTy; + // The first field contains the phi node that generates a result of the switch + // and the second field contains the value generated for a certain case in the switch + // for that PHI. + typedef SmallVector<std::pair<PHINode *, Constant *>, 4> SwitchCaseResultsTy; + /// ValueEqualityComparisonCase - Represents a case of a switch. 
struct ValueEqualityComparisonCase { ConstantInt *Value; @@ -92,7 +106,9 @@ namespace { class SimplifyCFGOpt { const TargetTransformInfo &TTI; + unsigned BonusInstThreshold; const DataLayout *const DL; + AssumptionCache *AC; Value *isValueEqualityComparison(TerminatorInst *TI); BasicBlock *GetValueEqualityComparisonCases(TerminatorInst *TI, std::vector<ValueEqualityComparisonCase> &Cases); @@ -111,8 +127,9 @@ class SimplifyCFGOpt { bool SimplifyCondBranch(BranchInst *BI, IRBuilder <>&Builder); public: - SimplifyCFGOpt(const TargetTransformInfo &TTI, const DataLayout *DL) - : TTI(TTI), DL(DL) {} + SimplifyCFGOpt(const TargetTransformInfo &TTI, unsigned BonusInstThreshold, + const DataLayout *DL, AssumptionCache *AC) + : TTI(TTI), BonusInstThreshold(BonusInstThreshold), DL(DL), AC(AC) {} bool run(BasicBlock *BB); }; } @@ -256,7 +273,7 @@ static unsigned ComputeSpeculationCost(const User *I, const DataLayout *DL) { /// V plus its non-dominating operands. If that cost is greater than /// CostRemaining, false is returned and CostRemaining is undefined. static bool DominatesMergePoint(Value *V, BasicBlock *BB, - SmallPtrSet<Instruction*, 4> *AggressiveInsts, + SmallPtrSetImpl<Instruction*> *AggressiveInsts, unsigned &CostRemaining, const DataLayout *DL) { Instruction *I = dyn_cast<Instruction>(V); @@ -341,114 +358,177 @@ static ConstantInt *GetConstantInt(Value *V, const DataLayout *DL) { return nullptr; } -/// GatherConstantCompares - Given a potentially 'or'd or 'and'd together -/// collection of icmp eq/ne instructions that compare a value against a -/// constant, return the value being compared, and stick the constant into the -/// Values vector. -static Value * -GatherConstantCompares(Value *V, std::vector<ConstantInt*> &Vals, Value *&Extra, - const DataLayout *DL, bool isEQ, unsigned &UsedICmps) { - Instruction *I = dyn_cast<Instruction>(V); - if (!I) return nullptr; - - // If this is an icmp against a constant, handle this as one of the cases. - if (ICmpInst *ICI = dyn_cast<ICmpInst>(I)) { - if (ConstantInt *C = GetConstantInt(I->getOperand(1), DL)) { - Value *RHSVal; - ConstantInt *RHSC; - - if (ICI->getPredicate() == (isEQ ? ICmpInst::ICMP_EQ:ICmpInst::ICMP_NE)) { - // (x & ~2^x) == y --> x == y || x == y|2^x - // This undoes a transformation done by instcombine to fuse 2 compares. - if (match(ICI->getOperand(0), - m_And(m_Value(RHSVal), m_ConstantInt(RHSC)))) { - APInt Not = ~RHSC->getValue(); - if (Not.isPowerOf2()) { - Vals.push_back(C); - Vals.push_back( - ConstantInt::get(C->getContext(), C->getValue() | Not)); - UsedICmps++; - return RHSVal; - } - } +namespace { + +/// Given a chain of or (||) or and (&&) comparison of a value against a +/// constant, this will try to recover the information required for a switch +/// structure. +/// It will depth-first traverse the chain of comparison, seeking for patterns +/// like %a == 12 or %a < 4 and combine them to produce a set of integer +/// representing the different cases for the switch. +/// Note that if the chain is composed of '||' it will build the set of elements +/// that matches the comparisons (i.e. any of this value validate the chain) +/// while for a chain of '&&' it will build the set elements that make the test +/// fail. 
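[editor's note: illustrative sketch, not part of the patch; function names invented]
Source-level picture of what the gatherer recovers. For an or-chain it
collects the values that satisfy the test; for an and-chain of inequalities,
the values that make it fail. Either way the branch can be rebuilt as a
switch:

  int before(unsigned x) { return (x == 1 || x == 2 || x == 7) ? 10 : 20; }

  int after(unsigned x) { // what SimplifyCFG effectively produces
    switch (x) {
    case 1:
    case 2:
    case 7:
      return 10;
    default:
      return 20;
    }
  }

Here CompValue is x, Vals is {1, 2, 7} and UsedICmps is 3.
[end note]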
+struct ConstantComparesGatherer { + + Value *CompValue; /// Value found for the switch comparison + Value *Extra; /// Extra clause to be checked before the switch + SmallVector<ConstantInt *, 8> Vals; /// Set of integers to match in switch + unsigned UsedICmps; /// Number of comparisons matched in the and/or chain + + /// Construct and compute the result for the comparison instruction Cond + ConstantComparesGatherer(Instruction *Cond, const DataLayout *DL) + : CompValue(nullptr), Extra(nullptr), UsedICmps(0) { + gather(Cond, DL); + } + + /// Prevent copy + ConstantComparesGatherer(const ConstantComparesGatherer &) + LLVM_DELETED_FUNCTION; + ConstantComparesGatherer & + operator=(const ConstantComparesGatherer &) LLVM_DELETED_FUNCTION; + +private: - UsedICmps++; - Vals.push_back(C); - return I->getOperand(0); + /// Try to set the current value used for the comparison, it succeeds only if + /// it wasn't set before or if the new value is the same as the old one + bool setValueOnce(Value *NewVal) { + if(CompValue && CompValue != NewVal) return false; + CompValue = NewVal; + return (CompValue != nullptr); + } + + /// Try to match Instruction "I" as a comparison against a constant and + /// populates the array Vals with the set of values that match (or do not + /// match depending on isEQ). + /// Return false on failure. On success, the Value the comparison matched + /// against is placed in CompValue. + /// If CompValue is already set, the function is expected to fail if a match + /// is found but the value compared to is different. + bool matchInstruction(Instruction *I, const DataLayout *DL, bool isEQ) { + // If this is an icmp against a constant, handle this as one of the cases. + ICmpInst *ICI; + ConstantInt *C; + if (!((ICI = dyn_cast<ICmpInst>(I)) && + (C = GetConstantInt(I->getOperand(1), DL)))) { + return false; + } + + Value *RHSVal; + ConstantInt *RHSC; + + // Pattern match a special case + // (x & ~2^x) == y --> x == y || x == y|2^x + // This undoes a transformation done by instcombine to fuse 2 compares. + if (ICI->getPredicate() == (isEQ ? ICmpInst::ICMP_EQ:ICmpInst::ICMP_NE)) { + if (match(ICI->getOperand(0), + m_And(m_Value(RHSVal), m_ConstantInt(RHSC)))) { + APInt Not = ~RHSC->getValue(); + if (Not.isPowerOf2()) { + // If we already have a value for the switch, it has to match! + if(!setValueOnce(RHSVal)) + return false; + + Vals.push_back(C); + Vals.push_back(ConstantInt::get(C->getContext(), + C->getValue() | Not)); + UsedICmps++; + return true; + } } - // If we have "x ult 3" comparison, for example, then we can add 0,1,2 to - // the set. - ConstantRange Span = - ConstantRange::makeICmpRegion(ICI->getPredicate(), C->getValue()); - - // Shift the range if the compare is fed by an add. This is the range - // compare idiom as emitted by instcombine. - bool hasAdd = - match(I->getOperand(0), m_Add(m_Value(RHSVal), m_ConstantInt(RHSC))); - if (hasAdd) - Span = Span.subtract(RHSC->getValue()); - - // If this is an and/!= check then we want to optimize "x ugt 2" into - // x != 0 && x != 1. - if (!isEQ) - Span = Span.inverse(); - - // If there are a ton of values, we don't want to make a ginormous switch. - if (Span.getSetSize().ugt(8) || Span.isEmptySet()) - return nullptr; - - for (APInt Tmp = Span.getLower(); Tmp != Span.getUpper(); ++Tmp) - Vals.push_back(ConstantInt::get(V->getContext(), Tmp)); + // If we already have a value for the switch, it has to match! + if(!setValueOnce(ICI->getOperand(0))) + return false; + UsedICmps++; - return hasAdd ? 
RHSVal : I->getOperand(0); + Vals.push_back(C); + return ICI->getOperand(0); } - return nullptr; - } - // Otherwise, we can only handle an | or &, depending on isEQ. - if (I->getOpcode() != (isEQ ? Instruction::Or : Instruction::And)) - return nullptr; + // If we have "x ult 3", for example, then we can add 0,1,2 to the set. + ConstantRange Span = ConstantRange::makeICmpRegion(ICI->getPredicate(), + C->getValue()); - unsigned NumValsBeforeLHS = Vals.size(); - unsigned UsedICmpsBeforeLHS = UsedICmps; - if (Value *LHS = GatherConstantCompares(I->getOperand(0), Vals, Extra, DL, - isEQ, UsedICmps)) { - unsigned NumVals = Vals.size(); - unsigned UsedICmpsBeforeRHS = UsedICmps; - if (Value *RHS = GatherConstantCompares(I->getOperand(1), Vals, Extra, DL, - isEQ, UsedICmps)) { - if (LHS == RHS) - return LHS; - Vals.resize(NumVals); - UsedICmps = UsedICmpsBeforeRHS; + // Shift the range if the compare is fed by an add. This is the range + // compare idiom as emitted by instcombine. + Value *CandidateVal = I->getOperand(0); + if(match(I->getOperand(0), m_Add(m_Value(RHSVal), m_ConstantInt(RHSC)))) { + Span = Span.subtract(RHSC->getValue()); + CandidateVal = RHSVal; } - // The RHS of the or/and can't be folded in and we haven't used "Extra" yet, - // set it and return success. - if (Extra == nullptr || Extra == I->getOperand(1)) { - Extra = I->getOperand(1); - return LHS; + // If this is an and/!= check, then we are looking to build the set of + // value that *don't* pass the and chain. I.e. to turn "x ugt 2" into + // x != 0 && x != 1. + if (!isEQ) + Span = Span.inverse(); + + // If there are a ton of values, we don't want to make a ginormous switch. + if (Span.getSetSize().ugt(8) || Span.isEmptySet()) { + return false; } - Vals.resize(NumValsBeforeLHS); - UsedICmps = UsedICmpsBeforeLHS; - return nullptr; + // If we already have a value for the switch, it has to match! + if(!setValueOnce(CandidateVal)) + return false; + + // Add all values from the range to the set + for (APInt Tmp = Span.getLower(); Tmp != Span.getUpper(); ++Tmp) + Vals.push_back(ConstantInt::get(I->getContext(), Tmp)); + + UsedICmps++; + return true; + } - // If the LHS can't be folded in, but Extra is available and RHS can, try to - // use LHS as Extra. - if (Extra == nullptr || Extra == I->getOperand(0)) { - Value *OldExtra = Extra; - Extra = I->getOperand(0); - if (Value *RHS = GatherConstantCompares(I->getOperand(1), Vals, Extra, DL, - isEQ, UsedICmps)) - return RHS; - assert(Vals.size() == NumValsBeforeLHS); - Extra = OldExtra; + /// gather - Given a potentially 'or'd or 'and'd together collection of icmp + /// eq/ne/lt/gt instructions that compare a value against a constant, extract + /// the value being compared, and stick the list constants into the Vals + /// vector. + /// One "Extra" case is allowed to differ from the other. + void gather(Value *V, const DataLayout *DL) { + Instruction *I = dyn_cast<Instruction>(V); + bool isEQ = (I->getOpcode() == Instruction::Or); + + // Keep a stack (SmallVector for efficiency) for depth-first traversal + SmallVector<Value *, 8> DFT; + + // Initialize + DFT.push_back(V); + + while(!DFT.empty()) { + V = DFT.pop_back_val(); + + if (Instruction *I = dyn_cast<Instruction>(V)) { + // If it is a || (or && depending on isEQ), process the operands. + if (I->getOpcode() == (isEQ ? 
Instruction::Or : Instruction::And)) { + DFT.push_back(I->getOperand(1)); + DFT.push_back(I->getOperand(0)); + continue; + } + + // Try to match the current instruction + if (matchInstruction(I, DL, isEQ)) + // Match succeed, continue the loop + continue; + } + + // One element of the sequence of || (or &&) could not be match as a + // comparison against the same value as the others. + // We allow only one "Extra" case to be checked before the switch + if (!Extra) { + Extra = V; + continue; + } + // Failed to parse a proper sequence, abort now + CompValue = nullptr; + break; + } } +}; - return nullptr; } static void EraseTerminatorInstAndDCECond(TerminatorInst *TI) { @@ -628,13 +708,12 @@ SimplifyEqualityComparisonWithOnlyPredecessor(TerminatorInst *TI, // Collect branch weights into a vector. SmallVector<uint32_t, 8> Weights; - MDNode* MD = SI->getMetadata(LLVMContext::MD_prof); + MDNode *MD = SI->getMetadata(LLVMContext::MD_prof); bool HasWeight = MD && (MD->getNumOperands() == 2 + SI->getNumCases()); if (HasWeight) for (unsigned MD_i = 1, MD_e = MD->getNumOperands(); MD_i < MD_e; ++MD_i) { - ConstantInt* CI = dyn_cast<ConstantInt>(MD->getOperand(MD_i)); - assert(CI); + ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(MD_i)); Weights.push_back(CI->getValue().getZExtValue()); } for (SwitchInst::CaseIt i = SI->case_end(), e = SI->case_begin(); i != e;) { @@ -723,7 +802,7 @@ static int ConstantIntSortPredicate(ConstantInt *const *P1, } static inline bool HasBranchWeights(const Instruction* I) { - MDNode* ProfMD = I->getMetadata(LLVMContext::MD_prof); + MDNode *ProfMD = I->getMetadata(LLVMContext::MD_prof); if (ProfMD && ProfMD->getOperand(0)) if (MDString* MDS = dyn_cast<MDString>(ProfMD->getOperand(0))) return MDS->getString().equals("branch_weights"); @@ -736,10 +815,10 @@ static inline bool HasBranchWeights(const Instruction* I) { /// metadata. static void GetBranchWeights(TerminatorInst *TI, SmallVectorImpl<uint64_t> &Weights) { - MDNode* MD = TI->getMetadata(LLVMContext::MD_prof); + MDNode *MD = TI->getMetadata(LLVMContext::MD_prof); assert(MD); for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) { - ConstantInt *CI = cast<ConstantInt>(MD->getOperand(i)); + ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(i)); Weights.push_back(CI->getValue().getZExtValue()); } @@ -995,6 +1074,8 @@ static bool isSafeToHoistInvoke(BasicBlock *BB1, BasicBlock *BB2, return true; } +static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I); + /// HoistThenElseCodeToIf - Given a conditional branch that goes to BB1 and /// BB2, hoist any common code in the two blocks up into the branch block. The /// caller of this function guarantees that BI's block dominates BB1 and BB2. @@ -1040,6 +1121,14 @@ static bool HoistThenElseCodeToIf(BranchInst *BI, const DataLayout *DL) { if (!I2->use_empty()) I2->replaceAllUsesWith(I1); I1->intersectOptionalDataWith(I2); + unsigned KnownIDs[] = { + LLVMContext::MD_tbaa, + LLVMContext::MD_range, + LLVMContext::MD_fpmath, + LLVMContext::MD_invariant_load, + LLVMContext::MD_nonnull + }; + combineMetadata(I1, I2, KnownIDs); I2->eraseFromParent(); Changed = true; @@ -1072,6 +1161,12 @@ HoistTerminator: if (BB1V == BB2V) continue; + // Check for passingValueIsAlwaysUndefined here because we would rather + // eliminate undefined control flow then converting it to a select. 
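[editor's note: illustrative sketch, not part of the patch; names invented]
Why the new check bails out rather than letting the phi become a select: if
one incoming value makes the successor's code undefined, the profitable fix
is to delete that edge, not to fold the value into a select.

  int use(int *p, bool c) {
    int *q;
    if (c)
      q = p;
    else
      q = nullptr; // this incoming value...
    return *q;     // ...is dereferenced, so the else path is undefined
  }

Hoisting would build "select c, p, null" for q and hide that the else edge is
dead; returning early keeps the undefined control flow visible for removal.
[end note]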
+ if (passingValueIsAlwaysUndefined(BB1V, PN) || + passingValueIsAlwaysUndefined(BB2V, PN)) + return Changed; + if (isa<ConstantExpr>(BB1V) && !isSafeToSpeculativelyExecute(BB1V, DL)) return Changed; if (isa<ConstantExpr>(BB2V) && !isSafeToSpeculativelyExecute(BB2V, DL)) @@ -1149,14 +1244,13 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { return false; // Gather the PHI nodes in BBEnd. - std::map<Value*, std::pair<Value*, PHINode*> > MapValueFromBB1ToBB2; + SmallDenseMap<std::pair<Value *, Value *>, PHINode *> JointValueMap; Instruction *FirstNonPhiInBBEnd = nullptr; - for (BasicBlock::iterator I = BBEnd->begin(), E = BBEnd->end(); - I != E; ++I) { + for (BasicBlock::iterator I = BBEnd->begin(), E = BBEnd->end(); I != E; ++I) { if (PHINode *PN = dyn_cast<PHINode>(I)) { Value *BB1V = PN->getIncomingValueForBlock(BB1); Value *BB2V = PN->getIncomingValueForBlock(BB2); - MapValueFromBB1ToBB2[BB1V] = std::make_pair(BB2V, PN); + JointValueMap[std::make_pair(BB1V, BB2V)] = PN; } else { FirstNonPhiInBBEnd = &*I; break; @@ -1165,13 +1259,13 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { if (!FirstNonPhiInBBEnd) return false; - // This does very trivial matching, with limited scanning, to find identical // instructions in the two blocks. We scan backward for obviously identical // instructions in an identical order. BasicBlock::InstListType::reverse_iterator RI1 = BB1->getInstList().rbegin(), - RE1 = BB1->getInstList().rend(), RI2 = BB2->getInstList().rbegin(), - RE2 = BB2->getInstList().rend(); + RE1 = BB1->getInstList().rend(), + RI2 = BB2->getInstList().rbegin(), + RE2 = BB2->getInstList().rend(); // Skip debug info. while (RI1 != RE1 && isa<DbgInfoIntrinsic>(&*RI1)) ++RI1; if (RI1 == RE1) @@ -1194,6 +1288,7 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { return Changed; Instruction *I1 = &*RI1, *I2 = &*RI2; + auto InstPair = std::make_pair(I1, I2); // I1 and I2 should have a single use in the same PHI node, and they // perform the same operation. // Cannot move control-flow-involving, volatile loads, vaarg, etc. @@ -1204,11 +1299,11 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { I1->mayHaveSideEffects() || I2->mayHaveSideEffects() || I1->mayReadOrWriteMemory() || I2->mayReadOrWriteMemory() || !I1->hasOneUse() || !I2->hasOneUse() || - MapValueFromBB1ToBB2.find(I1) == MapValueFromBB1ToBB2.end() || - MapValueFromBB1ToBB2[I1].first != I2) + !JointValueMap.count(InstPair)) return Changed; // Check whether we should swap the operands of ICmpInst. + // TODO: Add support of communativity. ICmpInst *ICmp1 = dyn_cast<ICmpInst>(I1), *ICmp2 = dyn_cast<ICmpInst>(I2); bool SwapOpnds = false; if (ICmp1 && ICmp2 && @@ -1229,16 +1324,13 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { // with a PHI node after sinking. We only handle the case where there is // a single pair of different operands. Value *DifferentOp1 = nullptr, *DifferentOp2 = nullptr; - unsigned Op1Idx = 0; + unsigned Op1Idx = ~0U; for (unsigned I = 0, E = I1->getNumOperands(); I != E; ++I) { if (I1->getOperand(I) == I2->getOperand(I)) continue; - // Early exit if we have more-than one pair of different operands or - // the different operand is already in MapValueFromBB1ToBB2. - // Early exit if we need a PHI node to replace a constant. - if (DifferentOp1 || - MapValueFromBB1ToBB2.find(I1->getOperand(I)) != - MapValueFromBB1ToBB2.end() || + // Early exit if we have more-than one pair of different operands or if + // we need a PHI node to replace a constant. 
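[editor's note: illustrative sketch, not part of the patch; names invented]
The single-differing-operand-pair rule being enforced here, at the source
level:

  int sink(bool c, int x, int a, int b) {
    int t;
    if (c)
      t = x + a; // identical instruction; exactly one operand differs
    else
      t = x + b;
    return t;    // sunk as: a.sink = phi [a, b]; t = x + a.sink
  }

Keying JointValueMap on the (value-from-BB1, value-from-BB2) pair means that
a later sunk instruction differing in the same (a, b) pair reuses the same
".sink" phi instead of minting a new one, which the old one-sided
MapValueFromBB1ToBB2 could not express.
[end note]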
+ if (Op1Idx != ~0U || isa<Constant>(I1->getOperand(I)) || isa<Constant>(I2->getOperand(I))) { // If we can't sink the instructions, undo the swapping. @@ -1251,24 +1343,27 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { DifferentOp2 = I2->getOperand(I); } - // We insert the pair of different operands to MapValueFromBB1ToBB2 and - // remove (I1, I2) from MapValueFromBB1ToBB2. - if (DifferentOp1) { - PHINode *NewPN = PHINode::Create(DifferentOp1->getType(), 2, - DifferentOp1->getName() + ".sink", - BBEnd->begin()); - MapValueFromBB1ToBB2[DifferentOp1] = std::make_pair(DifferentOp2, NewPN); + DEBUG(dbgs() << "SINK common instructions " << *I1 << "\n"); + DEBUG(dbgs() << " " << *I2 << "\n"); + + // We insert the pair of different operands to JointValueMap and + // remove (I1, I2) from JointValueMap. + if (Op1Idx != ~0U) { + auto &NewPN = JointValueMap[std::make_pair(DifferentOp1, DifferentOp2)]; + if (!NewPN) { + NewPN = + PHINode::Create(DifferentOp1->getType(), 2, + DifferentOp1->getName() + ".sink", BBEnd->begin()); + NewPN->addIncoming(DifferentOp1, BB1); + NewPN->addIncoming(DifferentOp2, BB2); + DEBUG(dbgs() << "Create PHI node " << *NewPN << "\n";); + } // I1 should use NewPN instead of DifferentOp1. I1->setOperand(Op1Idx, NewPN); - NewPN->addIncoming(DifferentOp1, BB1); - NewPN->addIncoming(DifferentOp2, BB2); - DEBUG(dbgs() << "Create PHI node " << *NewPN << "\n";); } - PHINode *OldPN = MapValueFromBB1ToBB2[I1].second; - MapValueFromBB1ToBB2.erase(I1); + PHINode *OldPN = JointValueMap[InstPair]; + JointValueMap.erase(InstPair); - DEBUG(dbgs() << "SINK common instructions " << *I1 << "\n";); - DEBUG(dbgs() << " " << *I2 << "\n";); // We need to update RE1 and RE2 if we are going to sink the first // instruction in the basic block down. bool UpdateRE1 = (I1 == BB1->begin()), UpdateRE2 = (I2 == BB2->begin()); @@ -1281,6 +1376,8 @@ static bool SinkThenElseCodeToEnd(BranchInst *BI1) { if (!I2->use_empty()) I2->replaceAllUsesWith(I1); I1->intersectOptionalDataWith(I2); + // TODO: Use combineMetadata here to preserve what metadata we can + // (analogous to the hoisting case above). I2->eraseFromParent(); if (UpdateRE1) @@ -1486,6 +1583,11 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, if (ThenV == OrigV) continue; + // Don't convert to selects if we could remove undefined behavior instead. 
+ if (passingValueIsAlwaysUndefined(OrigV, PN) || + passingValueIsAlwaysUndefined(ThenV, PN)) + return false; + HaveRewritablePHIs = true; ConstantExpr *OrigCE = dyn_cast<ConstantExpr>(OrigV); ConstantExpr *ThenCE = dyn_cast<ConstantExpr>(ThenV); @@ -1934,8 +2036,10 @@ static bool ExtractBranchMetadata(BranchInst *BI, "Looking for probabilities on unconditional branch?"); MDNode *ProfileData = BI->getMetadata(LLVMContext::MD_prof); if (!ProfileData || ProfileData->getNumOperands() != 3) return false; - ConstantInt *CITrue = dyn_cast<ConstantInt>(ProfileData->getOperand(1)); - ConstantInt *CIFalse = dyn_cast<ConstantInt>(ProfileData->getOperand(2)); + ConstantInt *CITrue = + mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(1)); + ConstantInt *CIFalse = + mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2)); if (!CITrue || !CIFalse) return false; ProbTrue = CITrue->getValue().getZExtValue(); ProbFalse = CIFalse->getValue().getZExtValue(); @@ -1963,7 +2067,8 @@ static bool checkCSEInPredecessor(Instruction *Inst, BasicBlock *PB) { /// FoldBranchToCommonDest - If this basic block is simple enough, and if a /// predecessor branches to us and one of our successors, fold the block into /// the predecessor and use logical operations to pick the right destination. -bool llvm::FoldBranchToCommonDest(BranchInst *BI, const DataLayout *DL) { +bool llvm::FoldBranchToCommonDest(BranchInst *BI, const DataLayout *DL, + unsigned BonusInstThreshold) { BasicBlock *BB = BI->getParent(); Instruction *Cond = nullptr; @@ -2000,33 +2105,6 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, const DataLayout *DL) { Cond->getParent() != BB || !Cond->hasOneUse()) return false; - // Only allow this if the condition is a simple instruction that can be - // executed unconditionally. It must be in the same block as the branch, and - // must be at the front of the block. - BasicBlock::iterator FrontIt = BB->front(); - - // Ignore dbg intrinsics. - while (isa<DbgInfoIntrinsic>(FrontIt)) ++FrontIt; - - // Allow a single instruction to be hoisted in addition to the compare - // that feeds the branch. We later ensure that any values that _it_ uses - // were also live in the predecessor, so that we don't unnecessarily create - // register pressure or inhibit out-of-order execution. - Instruction *BonusInst = nullptr; - if (&*FrontIt != Cond && - FrontIt->hasOneUse() && FrontIt->user_back() == Cond && - isSafeToSpeculativelyExecute(FrontIt, DL)) { - BonusInst = &*FrontIt; - ++FrontIt; - - // Ignore dbg intrinsics. - while (isa<DbgInfoIntrinsic>(FrontIt)) ++FrontIt; - } - - // Only a single bonus inst is allowed. - if (&*FrontIt != Cond) - return false; - // Make sure the instruction after the condition is the cond branch. BasicBlock::iterator CondIt = Cond; ++CondIt; @@ -2036,6 +2114,31 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, const DataLayout *DL) { if (&*CondIt != BI) return false; + // Only allow this transformation if computing the condition doesn't involve + // too many instructions and these involved instructions can be executed + // unconditionally. We denote all involved instructions except the condition + // as "bonus instructions", and only allow this transformation when the + // number of the bonus instructions does not exceed a certain threshold. + unsigned NumBonusInsts = 0; + for (auto I = BB->begin(); Cond != I; ++I) { + // Ignore dbg intrinsics. 
+    if (isa<DbgInfoIntrinsic>(I))
+      continue;
+    if (!I->hasOneUse() || !isSafeToSpeculativelyExecute(I, DL))
+      return false;
+    // I has only one use and can be executed unconditionally.
+    Instruction *User = dyn_cast<Instruction>(I->user_back());
+    if (User == nullptr || User->getParent() != BB)
+      return false;
+    // I is used in the same BB. Since BI uses Cond and doesn't have more slots
+    // to use any other instruction, User must be an instruction between next(I)
+    // and Cond.
+    ++NumBonusInsts;
+    // Early exit once we reach the limit.
+    if (NumBonusInsts > BonusInstThreshold)
+      return false;
+  }
+
   // Cond is known to be a compare or binary operator.  Check to make sure that
   // neither operand is a potentially-trapping constant expression.
   if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Cond->getOperand(0)))
@@ -2086,49 +2189,6 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, const DataLayout *DL) {
       continue;
     }
 
-    // Ensure that any values used in the bonus instruction are also used
-    // by the terminator of the predecessor.  This means that those values
-    // must already have been resolved, so we won't be inhibiting the
-    // out-of-order core by speculating them earlier. We also allow
-    // instructions that are used by the terminator's condition because it
-    // exposes more merging opportunities.
-    bool UsedByBranch = (BonusInst && BonusInst->hasOneUse() &&
-                         BonusInst->user_back() == Cond);
-
-    if (BonusInst && !UsedByBranch) {
-      // Collect the values used by the bonus inst
-      SmallPtrSet<Value*, 4> UsedValues;
-      for (Instruction::op_iterator OI = BonusInst->op_begin(),
-           OE = BonusInst->op_end(); OI != OE; ++OI) {
-        Value *V = *OI;
-        if (!isa<Constant>(V) && !isa<Argument>(V))
-          UsedValues.insert(V);
-      }
-
-      SmallVector<std::pair<Value*, unsigned>, 4> Worklist;
-      Worklist.push_back(std::make_pair(PBI->getOperand(0), 0));
-
-      // Walk up to four levels back up the use-def chain of the predecessor's
-      // terminator to see if all those values were used.  The choice of four
-      // levels is arbitrary, to provide a compile-time-cost bound.
-      while (!Worklist.empty()) {
-        std::pair<Value*, unsigned> Pair = Worklist.back();
-        Worklist.pop_back();
-
-        if (Pair.second >= 4) continue;
-        UsedValues.erase(Pair.first);
-        if (UsedValues.empty()) break;
-
-        if (Instruction *I = dyn_cast<Instruction>(Pair.first)) {
-          for (Instruction::op_iterator OI = I->op_begin(), OE = I->op_end();
-               OI != OE; ++OI)
-            Worklist.push_back(std::make_pair(OI->get(), Pair.second+1));
-        }
-      }
-
-      if (!UsedValues.empty()) return false;
-    }
-
     DEBUG(dbgs() << "FOLDING BRANCH TO COMMON DEST:\n" << *PBI << *BB);
 
     IRBuilder<> Builder(PBI);
@@ -2148,30 +2208,41 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, const DataLayout *DL) {
       PBI->swapSuccessors();
     }
 
-    // If we have a bonus inst, clone it into the predecessor block.
-    Instruction *NewBonus = nullptr;
-    if (BonusInst) {
-      NewBonus = BonusInst->clone();
+    // If we have bonus instructions, clone them into the predecessor block.
+    // Note that there may be multiple predecessor blocks, so we cannot move
+    // bonus instructions to a predecessor block.
+    ValueToValueMapTy VMap; // maps original values to cloned values
+    // We already make sure Cond is the last instruction before BI. Therefore,
+    // all instructions before Cond other than DbgInfoIntrinsic are bonus
+    // instructions.
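[editor's note: illustrative sketch, not part of the patch; names invented]
What a bonus instruction looks like at the source level (BonusInstThreshold
is presumably wired up as an option in the SimplifyCFG pass, outside this
hunk):

  bool f(unsigned a, bool other) {
    bool t = (a & 1) != 0; // the 'and' feeding the compare is the bonus
                           // instruction; the compare itself is Cond
    return other && t;     // foldable: each predecessor clones the 'and'
  }

A condition fed by a long chain of speculatable single-use instructions is
now rejected once the chain exceeds the threshold, because every predecessor
of BB has to receive its own clone of the whole chain.
[end note]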
+ for (auto BonusInst = BB->begin(); Cond != BonusInst; ++BonusInst) { + if (isa<DbgInfoIntrinsic>(BonusInst)) + continue; + Instruction *NewBonusInst = BonusInst->clone(); + RemapInstruction(NewBonusInst, VMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingEntries); + VMap[BonusInst] = NewBonusInst; // If we moved a load, we cannot any longer claim any knowledge about // its potential value. The previous information might have been valid // only given the branch precondition. // For an analogous reason, we must also drop all the metadata whose // semantics we don't understand. - NewBonus->dropUnknownMetadata(LLVMContext::MD_dbg); + NewBonusInst->dropUnknownMetadata(LLVMContext::MD_dbg); - PredBlock->getInstList().insert(PBI, NewBonus); - NewBonus->takeName(BonusInst); - BonusInst->setName(BonusInst->getName()+".old"); + PredBlock->getInstList().insert(PBI, NewBonusInst); + NewBonusInst->takeName(BonusInst); + BonusInst->setName(BonusInst->getName() + ".old"); } // Clone Cond into the predecessor basic block, and or/and the // two conditions together. Instruction *New = Cond->clone(); - if (BonusInst) New->replaceUsesOfWith(BonusInst, NewBonus); + RemapInstruction(New, VMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingEntries); PredBlock->getInstList().insert(PBI, New); New->takeName(Cond); - Cond->setName(New->getName()+".old"); + Cond->setName(New->getName() + ".old"); if (BI->isConditional()) { Instruction *NewCond = @@ -2649,7 +2720,7 @@ static bool SimplifyIndirectBrOnSelect(IndirectBrInst *IBI, SelectInst *SI) { /// the PHI, merging the third icmp into the switch. static bool TryToSimplifyUncondBranchWithICmpInIt( ICmpInst *ICI, IRBuilder<> &Builder, const TargetTransformInfo &TTI, - const DataLayout *DL) { + unsigned BonusInstThreshold, const DataLayout *DL, AssumptionCache *AC) { BasicBlock *BB = ICI->getParent(); // If the block has any PHIs in it or the icmp has multiple uses, it is too @@ -2682,7 +2753,7 @@ static bool TryToSimplifyUncondBranchWithICmpInIt( ICI->eraseFromParent(); } // BB is now empty, so it is likely to simplify away. - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; } // Ok, the block is reachable from the default dest. If the constant we're @@ -2698,7 +2769,7 @@ static bool TryToSimplifyUncondBranchWithICmpInIt( ICI->replaceAllUsesWith(V); ICI->eraseFromParent(); // BB is now empty, so it is likely to simplify away. - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; } // The use of the icmp has to be in the 'end' block, by the only PHI node in @@ -2759,24 +2830,17 @@ static bool SimplifyBranchOnICmpChain(BranchInst *BI, const DataLayout *DL, Instruction *Cond = dyn_cast<Instruction>(BI->getCondition()); if (!Cond) return false; - // Change br (X == 0 | X == 1), T, F into a switch instruction. // If this is a bunch of seteq's or'd together, or if it's a bunch of // 'setne's and'ed together, collect them. 
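[editor's note: illustrative sketch, not part of the patch; names invented]
The one "Extra" clause that the comparison gatherer tolerates, at the source
level:

  bool h(unsigned x, bool y) {
    return x == 1 || x == 2 || y; // y is not a compare of x against a constant
  }

Here CompValue = x, Vals = {1, 2} and ExtraCase = y: the rewritten code tests
y with an ordinary branch first, then switches on x for the remaining cases.
[end note]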
- Value *CompVal = nullptr; - std::vector<ConstantInt*> Values; - bool TrueWhenEqual = true; - Value *ExtraCase = nullptr; - unsigned UsedICmps = 0; - - if (Cond->getOpcode() == Instruction::Or) { - CompVal = GatherConstantCompares(Cond, Values, ExtraCase, DL, true, - UsedICmps); - } else if (Cond->getOpcode() == Instruction::And) { - CompVal = GatherConstantCompares(Cond, Values, ExtraCase, DL, false, - UsedICmps); - TrueWhenEqual = false; - } + + // Try to gather values from a chain of and/or to be turned into a switch + ConstantComparesGatherer ConstantCompare(Cond, DL); + // Unpack the result + SmallVectorImpl<ConstantInt*> &Values = ConstantCompare.Vals; + Value *CompVal = ConstantCompare.CompValue; + unsigned UsedICmps = ConstantCompare.UsedICmps; + Value *ExtraCase = ConstantCompare.Extra; // If we didn't have a multiply compared value, fail. if (!CompVal) return false; @@ -2785,6 +2849,8 @@ static bool SimplifyBranchOnICmpChain(BranchInst *BI, const DataLayout *DL, if (UsedICmps <= 1) return false; + bool TrueWhenEqual = (Cond->getOpcode() == Instruction::Or); + // There might be duplicate constants in the list, which the switch // instruction can't handle, remove them now. array_pod_sort(Values.begin(), Values.end(), ConstantIntSortPredicate); @@ -2954,7 +3020,7 @@ bool SimplifyCFGOpt::SimplifyReturn(ReturnInst *RI, IRBuilder<> &Builder) { } // If we eliminated all predecessors of the block, delete the block now. - if (pred_begin(BB) == pred_end(BB)) + if (pred_empty(BB)) // We know there are no successors, so just nuke the block. BB->eraseFromParent(); @@ -3127,7 +3193,7 @@ bool SimplifyCFGOpt::SimplifyUnreachable(UnreachableInst *UI) { } // If this block is now dead, remove it. - if (pred_begin(BB) == pred_end(BB) && + if (pred_empty(BB) && BB != &BB->getParent()->getEntryBlock()) { // We know there are no successors, so just nuke the block. BB->eraseFromParent(); @@ -3208,11 +3274,12 @@ static bool TurnSwitchRangeIntoICmp(SwitchInst *SI, IRBuilder<> &Builder) { /// EliminateDeadSwitchCases - Compute masked bits for the condition of a switch /// and use it to remove dead cases. -static bool EliminateDeadSwitchCases(SwitchInst *SI) { +static bool EliminateDeadSwitchCases(SwitchInst *SI, const DataLayout *DL, + AssumptionCache *AC) { Value *Cond = SI->getCondition(); unsigned Bits = Cond->getType()->getIntegerBitWidth(); APInt KnownZero(Bits, 0), KnownOne(Bits, 0); - computeKnownBits(Cond, KnownZero, KnownOne); + computeKnownBits(Cond, KnownZero, KnownOne, DL, 0, AC, SI); // Gather dead cases. SmallVector<ConstantInt*, 8> DeadCases; @@ -3419,6 +3486,21 @@ GetCaseResults(SwitchInst *SI, continue; } else if (Constant *C = ConstantFold(I, ConstantPool, DL)) { // Instruction is side-effect free and constant. + + // If the instruction has uses outside this block or a phi node slot for + // the block, it is not safe to bypass the instruction since it would then + // no longer dominate all its uses. 
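[editor's note: loose analogy, not part of the patch]
The hazard the new use-scan guards against, sketched in pseudo-IR comments:

  // case.dest:
  //   v = x + 1;       // side-effect free, constant-foldable per case
  //   goto common;
  // common:
  //   r = phi(v, ...); // folding v into the phi result: fine
  // later.bb:
  //   u = v * 2;       // any other use of v: bypassing case.dest is unsound

GetCaseResults constant-folds v away when computing the case's phi result;
that is only sound when the phis being rewritten (and users inside case.dest
itself) are the sole consumers, which is what the loop below verifies.
[end note]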
+ for (auto &Use : I->uses()) { + User *User = Use.getUser(); + if (Instruction *I = dyn_cast<Instruction>(User)) + if (I->getParent() == CaseDest) + continue; + if (PHINode *Phi = dyn_cast<PHINode>(User)) + if (Phi->getIncomingBlock(Use) == CaseDest) + continue; + return false; + } + ConstantPool.insert(std::make_pair(I, C)); } else { break; @@ -3444,12 +3526,6 @@ GetCaseResults(SwitchInst *SI, if (!ConstVal) return false; - // Note: If the constant comes from constant-propagating the case value - // through the CaseDest basic block, it will be safe to remove the - // instructions in that block. They cannot be used (except in the phi nodes - // we visit) outside CaseDest, because that block does not dominate its - // successor. If it did, we would not be in this phi node. - // Be conservative about which kinds of constants we support. if (!ValidLookupTableConstant(ConstVal)) return false; @@ -3460,6 +3536,163 @@ GetCaseResults(SwitchInst *SI, return Res.size() > 0; } +// MapCaseToResult - Helper function used to +// add CaseVal to the list of cases that generate Result. +static void MapCaseToResult(ConstantInt *CaseVal, + SwitchCaseResultVectorTy &UniqueResults, + Constant *Result) { + for (auto &I : UniqueResults) { + if (I.first == Result) { + I.second.push_back(CaseVal); + return; + } + } + UniqueResults.push_back(std::make_pair(Result, + SmallVector<ConstantInt*, 4>(1, CaseVal))); +} + +// InitializeUniqueCases - Helper function that initializes a map containing +// results for the PHI node of the common destination block for a switch +// instruction. Returns false if multiple PHI nodes have been found or if +// there is not a common destination block for the switch. +static bool InitializeUniqueCases( + SwitchInst *SI, const DataLayout *DL, PHINode *&PHI, + BasicBlock *&CommonDest, + SwitchCaseResultVectorTy &UniqueResults, + Constant *&DefaultResult) { + for (auto &I : SI->cases()) { + ConstantInt *CaseVal = I.getCaseValue(); + + // Resulting value at phi nodes for this case value. + SwitchCaseResultsTy Results; + if (!GetCaseResults(SI, CaseVal, I.getCaseSuccessor(), &CommonDest, Results, + DL)) + return false; + + // Only one value per case is permitted + if (Results.size() > 1) + return false; + MapCaseToResult(CaseVal, UniqueResults, Results.begin()->second); + + // Check the PHI consistency. + if (!PHI) + PHI = Results[0].first; + else if (PHI != Results[0].first) + return false; + } + // Find the default result value. + SmallVector<std::pair<PHINode *, Constant *>, 1> DefaultResults; + BasicBlock *DefaultDest = SI->getDefaultDest(); + GetCaseResults(SI, nullptr, SI->getDefaultDest(), &CommonDest, DefaultResults, + DL); + // If the default value is not found abort unless the default destination + // is unreachable. + DefaultResult = + DefaultResults.size() == 1 ? DefaultResults.begin()->second : nullptr; + if ((!DefaultResult && + !isa<UnreachableInst>(DefaultDest->getFirstNonPHIOrDbg()))) + return false; + + return true; +} + +// ConvertTwoCaseSwitch - Helper function that checks if it is possible to +// transform a switch with only two cases (or two cases + default) +// that produces a result into a value select. 
+// Example: +// switch (a) { +// case 10: %0 = icmp eq i32 %a, 10 +// return 10; %1 = select i1 %0, i32 10, i32 4 +// case 20: ----> %2 = icmp eq i32 %a, 20 +// return 2; %3 = select i1 %2, i32 2, i32 %1 +// default: +// return 4; +// } +static Value * +ConvertTwoCaseSwitch(const SwitchCaseResultVectorTy &ResultVector, + Constant *DefaultResult, Value *Condition, + IRBuilder<> &Builder) { + assert(ResultVector.size() == 2 && + "We should have exactly two unique results at this point"); + // If we are selecting between only two cases transform into a simple + // select or a two-way select if default is possible. + if (ResultVector[0].second.size() == 1 && + ResultVector[1].second.size() == 1) { + ConstantInt *const FirstCase = ResultVector[0].second[0]; + ConstantInt *const SecondCase = ResultVector[1].second[0]; + + bool DefaultCanTrigger = DefaultResult; + Value *SelectValue = ResultVector[1].first; + if (DefaultCanTrigger) { + Value *const ValueCompare = + Builder.CreateICmpEQ(Condition, SecondCase, "switch.selectcmp"); + SelectValue = Builder.CreateSelect(ValueCompare, ResultVector[1].first, + DefaultResult, "switch.select"); + } + Value *const ValueCompare = + Builder.CreateICmpEQ(Condition, FirstCase, "switch.selectcmp"); + return Builder.CreateSelect(ValueCompare, ResultVector[0].first, SelectValue, + "switch.select"); + } + + return nullptr; +} + +// RemoveSwitchAfterSelectConversion - Helper function to cleanup a switch +// instruction that has been converted into a select, fixing up PHI nodes and +// basic blocks. +static void RemoveSwitchAfterSelectConversion(SwitchInst *SI, PHINode *PHI, + Value *SelectValue, + IRBuilder<> &Builder) { + BasicBlock *SelectBB = SI->getParent(); + while (PHI->getBasicBlockIndex(SelectBB) >= 0) + PHI->removeIncomingValue(SelectBB); + PHI->addIncoming(SelectValue, SelectBB); + + Builder.CreateBr(PHI->getParent()); + + // Remove the switch. + for (unsigned i = 0, e = SI->getNumSuccessors(); i < e; ++i) { + BasicBlock *Succ = SI->getSuccessor(i); + + if (Succ == PHI->getParent()) + continue; + Succ->removePredecessor(SelectBB); + } + SI->eraseFromParent(); +} + +/// SwitchToSelect - If the switch is only used to initialize one or more +/// phi nodes in a common successor block with only two different +/// constant values, replace the switch with select. +static bool SwitchToSelect(SwitchInst *SI, IRBuilder<> &Builder, + const DataLayout *DL, AssumptionCache *AC) { + Value *const Cond = SI->getCondition(); + PHINode *PHI = nullptr; + BasicBlock *CommonDest = nullptr; + Constant *DefaultResult; + SwitchCaseResultVectorTy UniqueResults; + // Collect all the cases that will deliver the same value from the switch. + if (!InitializeUniqueCases(SI, DL, PHI, CommonDest, UniqueResults, + DefaultResult)) + return false; + // Selects choose between maximum two values. + if (UniqueResults.size() != 2) + return false; + assert(PHI != nullptr && "PHI for value select not found"); + + Builder.SetInsertPoint(SI); + Value *SelectValue = ConvertTwoCaseSwitch( + UniqueResults, + DefaultResult, Cond, Builder); + if (SelectValue) { + RemoveSwitchAfterSelectConversion(SI, PHI, SelectValue, Builder); + return true; + } + // The switch couldn't be converted into a select. + return false; +} + namespace { /// SwitchLookupTable - This class represents a lookup table that can be used /// to replace a switch. @@ -3493,6 +3726,11 @@ namespace { // store that single value and return it for each lookup. 
SingleValueKind, + // For tables where there is a linear relationship between table index + // and values. We calculate the result with a simple multiplication + // and addition instead of a table lookup. + LinearMapKind, + // For small tables with integer elements, we can pack them into a bitmap // that fits into a target-legal register. Values are retrieved by // shift and mask operations. @@ -3510,6 +3748,10 @@ namespace { ConstantInt *BitMap; IntegerType *BitMapElementTy; + // For LinearMapKind, these are the constants used to derive the value. + ConstantInt *LinearOffset; + ConstantInt *LinearMultiplier; + // For ArrayKind, this is the array. GlobalVariable *Array; }; @@ -3522,7 +3764,7 @@ SwitchLookupTable::SwitchLookupTable(Module &M, Constant *DefaultValue, const DataLayout *DL) : SingleValue(nullptr), BitMap(nullptr), BitMapElementTy(nullptr), - Array(nullptr) { + LinearOffset(nullptr), LinearMultiplier(nullptr), Array(nullptr) { assert(Values.size() && "Can't build lookup table without values!"); assert(TableSize >= Values.size() && "Can't fit values in table!"); @@ -3567,6 +3809,43 @@ SwitchLookupTable::SwitchLookupTable(Module &M, return; } + // Check if we can derive the value with a linear transformation from the + // table index. + if (isa<IntegerType>(ValueType)) { + bool LinearMappingPossible = true; + APInt PrevVal; + APInt DistToPrev; + assert(TableSize >= 2 && "Should be a SingleValue table."); + // Check if there is the same distance between two consecutive values. + for (uint64_t I = 0; I < TableSize; ++I) { + ConstantInt *ConstVal = dyn_cast<ConstantInt>(TableContents[I]); + if (!ConstVal) { + // This is an undef. We could deal with it, but undefs in lookup tables + // are very seldom. It's probably not worth the additional complexity. + LinearMappingPossible = false; + break; + } + APInt Val = ConstVal->getValue(); + if (I != 0) { + APInt Dist = Val - PrevVal; + if (I == 1) { + DistToPrev = Dist; + } else if (Dist != DistToPrev) { + LinearMappingPossible = false; + break; + } + } + PrevVal = Val; + } + if (LinearMappingPossible) { + LinearOffset = cast<ConstantInt>(TableContents[0]); + LinearMultiplier = ConstantInt::get(M.getContext(), DistToPrev); + Kind = LinearMapKind; + ++NumLinearMaps; + return; + } + } + // If the type is integer and the table fits in a register, build a bitmap. if (WouldFitInRegister(DL, TableSize, ValueType)) { IntegerType *IT = cast<IntegerType>(ValueType); @@ -3602,6 +3881,16 @@ Value *SwitchLookupTable::BuildLookup(Value *Index, IRBuilder<> &Builder) { switch (Kind) { case SingleValueKind: return SingleValue; + case LinearMapKind: { + // Derive the result value from the input value. + Value *Result = Builder.CreateIntCast(Index, LinearMultiplier->getType(), + false, "switch.idx.cast"); + if (!LinearMultiplier->isOne()) + Result = Builder.CreateMul(Result, LinearMultiplier, "switch.idx.mult"); + if (!LinearOffset->isZero()) + Result = Builder.CreateAdd(Result, LinearOffset, "switch.offset"); + return Result; + } case BitMapKind: { // Type of the bitmap (e.g. i59). IntegerType *MapTy = BitMap->getType(); @@ -3673,9 +3962,8 @@ static bool ShouldBuildLookupTable(SwitchInst *SI, bool AllTablesFitInRegister = true; bool HasIllegalType = false; - for (SmallDenseMap<PHINode*, Type*>::const_iterator I = ResultTypes.begin(), - E = ResultTypes.end(); I != E; ++I) { - Type *Ty = I->second; + for (const auto &I : ResultTypes) { + Type *Ty = I.second; // Saturate this flag to true. 
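[editor's note: illustrative sketch, not part of the patch; names invented]
The LinearMapKind added above, demonstrated standalone: a table whose entries
form an arithmetic progression needs no memory at all.

  #include <cassert>
  #include <cstdint>

  static const uint32_t Table[4] = {3, 5, 7, 9}; // DistToPrev == 2 throughout

  uint32_t viaTable(uint32_t Idx) { return Table[Idx]; }

  uint32_t viaLinearMap(uint32_t Idx) {
    const uint32_t LinearOffset = 3, LinearMultiplier = 2;
    return LinearOffset + Idx * LinearMultiplier; // mul/add replace the load
  }

  int main() {
    for (uint32_t I = 0; I != 4; ++I)
      assert(viaTable(I) == viaLinearMap(I));
    return 0;
  }
[end note]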
HasIllegalType = HasIllegalType || !TTI.isTypeLegal(Ty); @@ -3705,6 +3993,89 @@ static bool ShouldBuildLookupTable(SwitchInst *SI, return SI->getNumCases() * 10 >= TableSize * 4; } +/// Try to reuse the switch table index compare. Following pattern: +/// \code +/// if (idx < tablesize) +/// r = table[idx]; // table does not contain default_value +/// else +/// r = default_value; +/// if (r != default_value) +/// ... +/// \endcode +/// Is optimized to: +/// \code +/// cond = idx < tablesize; +/// if (cond) +/// r = table[idx]; +/// else +/// r = default_value; +/// if (cond) +/// ... +/// \endcode +/// Jump threading will then eliminate the second if(cond). +static void reuseTableCompare(User *PhiUser, BasicBlock *PhiBlock, + BranchInst *RangeCheckBranch, Constant *DefaultValue, + const SmallVectorImpl<std::pair<ConstantInt*, Constant*> >& Values) { + + ICmpInst *CmpInst = dyn_cast<ICmpInst>(PhiUser); + if (!CmpInst) + return; + + // We require that the compare is in the same block as the phi so that jump + // threading can do its work afterwards. + if (CmpInst->getParent() != PhiBlock) + return; + + Constant *CmpOp1 = dyn_cast<Constant>(CmpInst->getOperand(1)); + if (!CmpOp1) + return; + + Value *RangeCmp = RangeCheckBranch->getCondition(); + Constant *TrueConst = ConstantInt::getTrue(RangeCmp->getType()); + Constant *FalseConst = ConstantInt::getFalse(RangeCmp->getType()); + + // Check if the compare with the default value is constant true or false. + Constant *DefaultConst = ConstantExpr::getICmp(CmpInst->getPredicate(), + DefaultValue, CmpOp1, true); + if (DefaultConst != TrueConst && DefaultConst != FalseConst) + return; + + // Check if the compare with the case values is distinct from the default + // compare result. + for (auto ValuePair : Values) { + Constant *CaseConst = ConstantExpr::getICmp(CmpInst->getPredicate(), + ValuePair.second, CmpOp1, true); + if (!CaseConst || CaseConst == DefaultConst) + return; + assert((CaseConst == TrueConst || CaseConst == FalseConst) && + "Expect true or false as compare result."); + } + + // Check if the branch instruction dominates the phi node. It's a simple + // dominance check, but sufficient for our needs. + // Although this check is invariant in the calling loops, it's better to do it + // at this late stage. Practically we do it at most once for a switch. + BasicBlock *BranchBlock = RangeCheckBranch->getParent(); + for (auto PI = pred_begin(PhiBlock), E = pred_end(PhiBlock); PI != E; ++PI) { + BasicBlock *Pred = *PI; + if (Pred != BranchBlock && Pred->getUniquePredecessor() != BranchBlock) + return; + } + + if (DefaultConst == FalseConst) { + // The compare yields the same result. We can replace it. + CmpInst->replaceAllUsesWith(RangeCmp); + ++NumTableCmpReuses; + } else { + // The compare yields the same result, just inverted. We can replace it. + Value *InvertedTableCmp = BinaryOperator::CreateXor(RangeCmp, + ConstantInt::get(RangeCmp->getType(), 1), "inverted.cmp", + RangeCheckBranch); + CmpInst->replaceAllUsesWith(InvertedTableCmp); + ++NumTableCmpReuses; + } +} + /// SwitchToLookupTable - If the switch is only used to initialize one or more /// phi nodes in a common successor block with different constant values, /// replace the switch with lookup tables. @@ -3759,16 +4130,17 @@ static bool SwitchToLookupTable(SwitchInst *SI, return false; // Append the result from this case to the list for each phi. 
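[editor's note: illustrative sketch, not part of the patch; name invented]
The profitability test that closes ShouldBuildLookupTable above is a 40%
density requirement:

  #include <cstdint>

  bool denseEnough(uint64_t NumCases, uint64_t TableSize) {
    return NumCases * 10 >= TableSize * 4; // cases cover >= 40% of the slots
  }
  // denseEnough(3, 10) == false: cases {0, 3, 9} span a 10-slot table
  // denseEnough(4, 4)  == true:  cases {0, 1, 2, 3} fill the table
[end note]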
- for (ResultsTy::iterator I = Results.begin(), E = Results.end(); I!=E; ++I) { - if (!ResultLists.count(I->first)) - PHIs.push_back(I->first); - ResultLists[I->first].push_back(std::make_pair(CaseVal, I->second)); + for (const auto &I : Results) { + PHINode *PHI = I.first; + Constant *Value = I.second; + if (!ResultLists.count(PHI)) + PHIs.push_back(PHI); + ResultLists[PHI].push_back(std::make_pair(CaseVal, Value)); } } // Keep track of the result types. - for (size_t I = 0, E = PHIs.size(); I != E; ++I) { - PHINode *PHI = PHIs[I]; + for (PHINode *PHI : PHIs) { ResultTypes[PHI] = ResultLists[PHI][0].second->getType(); } @@ -3780,11 +4152,9 @@ static bool SwitchToLookupTable(SwitchInst *SI, // If the table has holes, we need a constant result for the default case // or a bitmask that fits in a register. SmallVector<std::pair<PHINode*, Constant*>, 4> DefaultResultsList; - bool HasDefaultResults = false; - if (TableHasHoles) { - HasDefaultResults = GetCaseResults(SI, nullptr, SI->getDefaultDest(), + bool HasDefaultResults = GetCaseResults(SI, nullptr, SI->getDefaultDest(), &CommonDest, DefaultResultsList, DL); - } + bool NeedMask = (TableHasHoles && !HasDefaultResults); if (NeedMask) { // As an extra penalty for the validity test we require more cases. @@ -3794,9 +4164,9 @@ static bool SwitchToLookupTable(SwitchInst *SI, return false; } - for (size_t I = 0, E = DefaultResultsList.size(); I != E; ++I) { - PHINode *PHI = DefaultResultsList[I].first; - Constant *Result = DefaultResultsList[I].second; + for (const auto &I : DefaultResultsList) { + PHINode *PHI = I.first; + Constant *Result = I.second; DefaultResults[PHI] = Result; } @@ -3827,14 +4197,19 @@ static bool SwitchToLookupTable(SwitchInst *SI, // lookup table BB. Otherwise, check if the condition value is within the case // range. If it is so, branch to the new BB. Otherwise branch to SI's default // destination. + BranchInst *RangeCheckBranch = nullptr; + const bool GeneratingCoveredLookupTable = MaxTableSize == TableSize; if (GeneratingCoveredLookupTable) { Builder.CreateBr(LookupBB); - SI->getDefaultDest()->removePredecessor(SI->getParent()); + // We cached PHINodes in PHIs, to avoid accessing deleted PHINodes later, + // do not delete PHINodes here. + SI->getDefaultDest()->removePredecessor(SI->getParent(), + true/*DontDeleteUselessPHIs*/); } else { Value *Cmp = Builder.CreateICmpULT(TableIndex, ConstantInt::get( MinCaseVal->getType(), TableSize)); - Builder.CreateCondBr(Cmp, LookupBB, SI->getDefaultDest()); + RangeCheckBranch = Builder.CreateCondBr(Cmp, LookupBB, SI->getDefaultDest()); } // Populate the BB that does the lookups. @@ -3851,9 +4226,12 @@ static bool SwitchToLookupTable(SwitchInst *SI, CommonDest->getParent(), CommonDest); + // Make the mask's bitwidth at least 8bit and a power-of-2 to avoid + // unnecessary illegal types. + uint64_t TableSizePowOf2 = NextPowerOf2(std::max(7ULL, TableSize - 1ULL)); + APInt MaskInt(TableSizePowOf2, 0); + APInt One(TableSizePowOf2, 1); // Build bitmask; fill in a 1 bit for every case. 
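[editor's note: illustrative sketch, not part of the patch; name invented]
Worked sizes for the new mask-width computation (llvm::NextPowerOf2 returns
the power of two strictly greater than its argument):

  #include <algorithm>
  #include <cstdint>
  #include "llvm/Support/MathExtras.h"

  uint64_t maskBits(uint64_t TableSize) {
    return llvm::NextPowerOf2(std::max(UINT64_C(7), TableSize - 1));
  }
  // maskBits(5)  == 8    where the old APInt MaskInt(TableSize, 0) gave i5
  // maskBits(8)  == 8    (i8 before)
  // maskBits(9)  == 16   (i9 before)
  // maskBits(20) == 32   (i20 before)

Every mask is at least 8 bits wide and a power of two, so the shift-and-test
on it happens in a type the target can handle directly.
[end note]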
- APInt MaskInt(TableSize, 0); - APInt One(TableSize, 1); const ResultListTy &ResultList = ResultLists[PHIs[0]]; for (size_t I = 0, E = ResultList.size(); I != E; ++I) { uint64_t Idx = (ResultList[I].first->getValue() - @@ -3882,11 +4260,11 @@ static bool SwitchToLookupTable(SwitchInst *SI, bool ReturnedEarly = false; for (size_t I = 0, E = PHIs.size(); I != E; ++I) { PHINode *PHI = PHIs[I]; + const ResultListTy &ResultList = ResultLists[PHI]; // If using a bitmask, use any value to fill the lookup table holes. Constant *DV = NeedMask ? ResultLists[PHI][0].second : DefaultResults[PHI]; - SwitchLookupTable Table(Mod, TableSize, MinCaseVal, ResultLists[PHI], - DV, DL); + SwitchLookupTable Table(Mod, TableSize, MinCaseVal, ResultList, DV, DL); Value *Result = Table.BuildLookup(TableIndex, Builder); @@ -3899,6 +4277,16 @@ static bool SwitchToLookupTable(SwitchInst *SI, break; } + // Do a small peephole optimization: re-use the switch table compare if + // possible. + if (!TableHasHoles && HasDefaultResults && RangeCheckBranch) { + BasicBlock *PhiBlock = PHI->getParent(); + // Search for compare instructions which use the phi. + for (auto *User : PHI->users()) { + reuseTableCompare(User, PhiBlock, RangeCheckBranch, DV, ResultList); + } + } + PHI->addIncoming(Result, LookupBB); } @@ -3929,12 +4317,12 @@ bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { // see if that predecessor totally determines the outcome of this switch. if (BasicBlock *OnlyPred = BB->getSinglePredecessor()) if (SimplifyEqualityComparisonWithOnlyPredecessor(SI, OnlyPred, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; Value *Cond = SI->getCondition(); if (SelectInst *Select = dyn_cast<SelectInst>(Cond)) if (SimplifySwitchOnSelect(SI, Select)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; // If the block only contains the switch, see if we can fold the block // away into any preds. @@ -3944,22 +4332,25 @@ bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { ++BBI; if (SI == &*BBI) if (FoldValueComparisonIntoPredecessors(SI, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; } // Try to transform the switch into an icmp and a branch. if (TurnSwitchRangeIntoICmp(SI, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; // Remove unreachable cases. 
- if (EliminateDeadSwitchCases(SI)) - return SimplifyCFG(BB, TTI, DL) | true; + if (EliminateDeadSwitchCases(SI, DL, AC)) + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; + + if (SwitchToSelect(SI, Builder, DL, AC)) + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; if (ForwardSwitchConditionToPHI(SI)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; if (SwitchToLookupTable(SI, Builder, TTI, DL)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; return false; } @@ -3972,7 +4363,7 @@ bool SimplifyCFGOpt::SimplifyIndirectBr(IndirectBrInst *IBI) { SmallPtrSet<Value *, 8> Succs; for (unsigned i = 0, e = IBI->getNumDestinations(); i != e; ++i) { BasicBlock *Dest = IBI->getDestination(i); - if (!Dest->hasAddressTaken() || !Succs.insert(Dest)) { + if (!Dest->hasAddressTaken() || !Succs.insert(Dest).second) { Dest->removePredecessor(BB); IBI->removeDestination(i); --i; --e; @@ -3996,7 +4387,7 @@ bool SimplifyCFGOpt::SimplifyIndirectBr(IndirectBrInst *IBI) { if (SelectInst *SI = dyn_cast<SelectInst>(IBI->getAddress())) { if (SimplifyIndirectBrOnSelect(IBI, SI)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; } return Changed; } @@ -4008,7 +4399,7 @@ bool SimplifyCFGOpt::SimplifyUncondBranch(BranchInst *BI, IRBuilder<> &Builder){ return true; // If the Terminator is the only non-phi instruction, simplify the block. - BasicBlock::iterator I = BB->getFirstNonPHIOrDbgOrLifetime(); + BasicBlock::iterator I = BB->getFirstNonPHIOrDbg(); if (I->isTerminator() && BB != &BB->getParent()->getEntryBlock() && TryToSimplifyUncondBranchFromEmptyBlock(BB)) return true; @@ -4020,7 +4411,8 @@ bool SimplifyCFGOpt::SimplifyUncondBranch(BranchInst *BI, IRBuilder<> &Builder){ for (++I; isa<DbgInfoIntrinsic>(I); ++I) ; if (I->isTerminator() && - TryToSimplifyUncondBranchWithICmpInIt(ICI, Builder, TTI, DL)) + TryToSimplifyUncondBranchWithICmpInIt(ICI, Builder, TTI, + BonusInstThreshold, DL, AC)) return true; } @@ -4028,8 +4420,8 @@ bool SimplifyCFGOpt::SimplifyUncondBranch(BranchInst *BI, IRBuilder<> &Builder){ // branches to us and our successor, fold the comparison into the // predecessor and use logical operations to update the incoming value // for PHI nodes in common successor. - if (FoldBranchToCommonDest(BI, DL)) - return SimplifyCFG(BB, TTI, DL) | true; + if (FoldBranchToCommonDest(BI, DL, BonusInstThreshold)) + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; return false; } @@ -4044,7 +4436,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { // switch. if (BasicBlock *OnlyPred = BB->getSinglePredecessor()) if (SimplifyEqualityComparisonWithOnlyPredecessor(BI, OnlyPred, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; // This block must be empty, except for the setcond inst, if it exists. // Ignore dbg intrinsics. @@ -4054,14 +4446,14 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { ++I; if (&*I == BI) { if (FoldValueComparisonIntoPredecessors(BI, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; } else if (&*I == cast<Instruction>(BI->getCondition())){ ++I; // Ignore dbg intrinsics. 
while (isa<DbgInfoIntrinsic>(I)) ++I; if (&*I == BI && FoldValueComparisonIntoPredecessors(BI, Builder)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; } } @@ -4072,8 +4464,8 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { // If this basic block is ONLY a compare and a branch, and if a predecessor // branches to us and one of our successors, fold the comparison into the // predecessor and use logical operations to pick the right destination. - if (FoldBranchToCommonDest(BI, DL)) - return SimplifyCFG(BB, TTI, DL) | true; + if (FoldBranchToCommonDest(BI, DL, BonusInstThreshold)) + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; // We have a conditional branch to two blocks that are only reachable // from BI. We know that the condbr dominates the two blocks, so see if @@ -4082,7 +4474,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (BI->getSuccessor(0)->getSinglePredecessor()) { if (BI->getSuccessor(1)->getSinglePredecessor()) { if (HoistThenElseCodeToIf(BI, DL)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; } else { // If Successor #1 has multiple preds, we may be able to conditionally // execute Successor #0 if it branches to Successor #1. @@ -4090,7 +4482,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (Succ0TI->getNumSuccessors() == 1 && Succ0TI->getSuccessor(0) == BI->getSuccessor(1)) if (SpeculativelyExecuteBB(BI, BI->getSuccessor(0), DL)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; } } else if (BI->getSuccessor(1)->getSinglePredecessor()) { // If Successor #0 has multiple preds, we may be able to conditionally @@ -4099,7 +4491,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (Succ1TI->getNumSuccessors() == 1 && Succ1TI->getSuccessor(0) == BI->getSuccessor(0)) if (SpeculativelyExecuteBB(BI, BI->getSuccessor(1), DL)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; } // If this is a branch on a phi node in the current block, thread control @@ -4107,14 +4499,14 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (PHINode *PN = dyn_cast<PHINode>(BI->getCondition())) if (PN->getParent() == BI->getParent()) if (FoldCondBranchOnPHI(BI, DL)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; // Scan predecessor blocks for conditional branches. for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) if (BranchInst *PBI = dyn_cast<BranchInst>((*PI)->getTerminator())) if (PBI != BI && PBI->isConditional()) if (SimplifyCondBranchToCondBranch(PBI, BI)) - return SimplifyCFG(BB, TTI, DL) | true; + return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true; return false; } @@ -4195,7 +4587,7 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) { // Remove basic blocks that have no predecessors (except the entry block)... // or that just have themself as a predecessor. These are unreachable. - if ((pred_begin(BB) == pred_end(BB) && + if ((pred_empty(BB) && BB != &BB->getParent()->getEntryBlock()) || BB->getSinglePredecessor() == BB) { DEBUG(dbgs() << "Removing BB: \n" << *BB); @@ -4258,6 +4650,7 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) { /// of the CFG. 
It returns true if a modification was made. /// bool llvm::SimplifyCFG(BasicBlock *BB, const TargetTransformInfo &TTI, - const DataLayout *DL) { - return SimplifyCFGOpt(TTI, DL).run(BB); + unsigned BonusInstThreshold, const DataLayout *DL, + AssumptionCache *AC) { + return SimplifyCFGOpt(TTI, BonusInstThreshold, DL, AC).run(BB); } diff --git a/lib/Transforms/Utils/SimplifyIndVar.cpp b/lib/Transforms/Utils/SimplifyIndVar.cpp index b284e6f6c1f6..f8aa1d3eec12 100644 --- a/lib/Transforms/Utils/SimplifyIndVar.cpp +++ b/lib/Transforms/Utils/SimplifyIndVar.cpp @@ -40,7 +40,7 @@ STATISTIC(NumElimRem , "Number of IV remainder operations eliminated"); STATISTIC(NumElimCmp , "Number of IV comparisons eliminated"); namespace { - /// SimplifyIndvar - This is a utility for simplifying induction variables + /// This is a utility for simplifying induction variables /// based on ScalarEvolution. It is the primary instrument of the /// IndvarSimplify pass, but it may also be directly invoked to cleanup after /// other loop passes that preserve SCEV. @@ -80,13 +80,14 @@ namespace { void eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand); void eliminateIVRemainder(BinaryOperator *Rem, Value *IVOperand, bool IsSigned); + bool strengthenOverflowingOperation(BinaryOperator *OBO, Value *IVOperand); Instruction *splitOverflowIntrinsic(Instruction *IVUser, const DominatorTree *DT); }; } -/// foldIVUser - Fold an IV operand into its use. This removes increments of an +/// Fold an IV operand into its use. This removes increments of an /// aligned IV when used by a instruction that ignores the low bits. /// /// IVOperand is guaranteed SCEVable, but UseInst may not be. @@ -152,7 +153,7 @@ Value *SimplifyIndvar::foldIVUser(Instruction *UseInst, Instruction *IVOperand) return IVSrc; } -/// eliminateIVComparison - SimplifyIVUsers helper for eliminating useless +/// SimplifyIVUsers helper for eliminating useless /// comparisons against an induction variable. void SimplifyIndvar::eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand) { unsigned IVOperIdx = 0; @@ -188,7 +189,7 @@ void SimplifyIndvar::eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand) { DeadInsts.push_back(ICmp); } -/// eliminateIVRemainder - SimplifyIVUsers helper for eliminating useless +/// SimplifyIVUsers helper for eliminating useless /// remainder operations operating on an induction variable. void SimplifyIndvar::eliminateIVRemainder(BinaryOperator *Rem, Value *IVOperand, @@ -239,7 +240,7 @@ void SimplifyIndvar::eliminateIVRemainder(BinaryOperator *Rem, DeadInsts.push_back(Rem); } -/// eliminateIVUser - Eliminate an operation that consumes a simple IV and has +/// Eliminate an operation that consumes a simple IV and has /// no observable side-effect given the range of IV values. /// IVOperand is guaranteed SCEVable, but UseInst may not be. bool SimplifyIndvar::eliminateIVUser(Instruction *UseInst, @@ -271,6 +272,120 @@ bool SimplifyIndvar::eliminateIVUser(Instruction *UseInst, return true; } +/// Annotate BO with nsw / nuw if it provably does not signed-overflow / +/// unsigned-overflow. Returns true if anything changed, false otherwise. +bool SimplifyIndvar::strengthenOverflowingOperation(BinaryOperator *BO, + Value *IVOperand) { + + // Currently we only handle instructions of the form "add <indvar> <value>" + // and "sub <indvar> <value>". 
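// Illustrative only, not part of the patch: a source-level shape this
// strengthening targets, assuming a constant trip count so SCEV can bound
// the IV. With i in [0, 999], "i + 4" stays in [4, 1003] for a 32-bit int,
// so the add can carry both nuw and nsw.
inline long sumShiftedDemo(const int *P) {
  long S = 0;
  for (int i = 0; i < 1000; ++i)
    S += P[i + 4]; // "add i32 %i, 4" -> "add nuw nsw i32 %i, 4"
  return S;
}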
+ unsigned Op = BO->getOpcode(); + if (!(Op == Instruction::Add || Op == Instruction::Sub)) + return false; + + // If BO is already both nuw and nsw then there is nothing left to do + if (BO->hasNoUnsignedWrap() && BO->hasNoSignedWrap()) + return false; + + IntegerType *IT = cast<IntegerType>(IVOperand->getType()); + Value *OtherOperand = nullptr; + int OtherOperandIdx = -1; + if (BO->getOperand(0) == IVOperand) { + OtherOperand = BO->getOperand(1); + OtherOperandIdx = 1; + } else { + assert(BO->getOperand(1) == IVOperand && "only other use!"); + OtherOperand = BO->getOperand(0); + OtherOperandIdx = 0; + } + + bool Changed = false; + const SCEV *OtherOpSCEV = SE->getSCEV(OtherOperand); + if (OtherOpSCEV == SE->getCouldNotCompute()) + return false; + + if (Op == Instruction::Sub) { + // If the subtraction is of the form "sub <indvar>, <op>", then pretend it + // is "add <indvar>, -<op>" and continue, else bail out. + if (OtherOperandIdx != 1) + return false; + + OtherOpSCEV = SE->getNegativeSCEV(OtherOpSCEV); + } + + const SCEV *IVOpSCEV = SE->getSCEV(IVOperand); + const SCEV *ZeroSCEV = SE->getConstant(IVOpSCEV->getType(), 0); + + if (!BO->hasNoSignedWrap()) { + // Upgrade the add to an "add nsw" if we can prove that it will never + // sign-overflow or sign-underflow. + + const SCEV *SignedMax = + SE->getConstant(APInt::getSignedMaxValue(IT->getBitWidth())); + const SCEV *SignedMin = + SE->getConstant(APInt::getSignedMinValue(IT->getBitWidth())); + + // The addition "IVOperand + OtherOp" does not sign-overflow if the result + // is sign-representable in 2's complement in the given bit-width. + // + // If OtherOp is SLT 0, then for an IVOperand in [SignedMin - OtherOp, + // SignedMax], "IVOperand + OtherOp" is in [SignedMin, SignedMax + OtherOp]. + // Everything in [SignedMin, SignedMax + OtherOp] is representable since + // SignedMax + OtherOp is at least -1. + // + // If OtherOp is SGE 0, then for an IVOperand in [SignedMin, SignedMax - + // OtherOp], "IVOperand + OtherOp" is in [SignedMin + OtherOp, SignedMax]. + // Everything in [SignedMin + OtherOp, SignedMax] is representable since + // SignedMin + OtherOp is at most -1. + // + // It follows that for all values of IVOperand in [SignedMin - smin(0, + // OtherOp), SignedMax - smax(0, OtherOp)] the result of the add is + // representable (i.e. there is no sign-overflow). + + const SCEV *UpperDelta = SE->getSMaxExpr(ZeroSCEV, OtherOpSCEV); + const SCEV *UpperLimit = SE->getMinusSCEV(SignedMax, UpperDelta); + + bool NeverSignedOverflows = + SE->isKnownPredicate(ICmpInst::ICMP_SLE, IVOpSCEV, UpperLimit); + + if (NeverSignedOverflows) { + const SCEV *LowerDelta = SE->getSMinExpr(ZeroSCEV, OtherOpSCEV); + const SCEV *LowerLimit = SE->getMinusSCEV(SignedMin, LowerDelta); + + bool NeverSignedUnderflows = + SE->isKnownPredicate(ICmpInst::ICMP_SGE, IVOpSCEV, LowerLimit); + if (NeverSignedUnderflows) { + BO->setHasNoSignedWrap(true); + Changed = true; + } + } + } + + if (!BO->hasNoUnsignedWrap()) { + // Upgrade the add computing "IVOperand + OtherOp" to an "add nuw" if we can + // prove that it will never unsigned-overflow (i.e. the result will always + // be representable in the given bit-width). + // + // "IVOperand + OtherOp" is unsigned-representable in 2's complement iff it + // does not produce a carry. "IVOperand + OtherOp" produces no carry iff + // IVOperand ULE (UnsignedMax - OtherOp). 
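// A worked instance of the ULE bound above at 8 bits, checked at compile
// time (demo only, not part of the patch): with OtherOp == 10, UnsignedMax -
// OtherOp == 245, so "IV + 10" is carry-free exactly for IV <= 245.
static_assert(static_cast<unsigned char>(245 + 10) == 255,
              "245 is the largest IV value with no unsigned wrap");
static_assert(static_cast<unsigned char>(246 + 10) == 0,
              "246 > 245 fails the ULE test and wraps to zero");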
+ + const SCEV *UnsignedMax = + SE->getConstant(APInt::getMaxValue(IT->getBitWidth())); + const SCEV *UpperLimit = SE->getMinusSCEV(UnsignedMax, OtherOpSCEV); + + bool NeverUnsignedOverflows = + SE->isKnownPredicate(ICmpInst::ICMP_ULE, IVOpSCEV, UpperLimit); + + if (NeverUnsignedOverflows) { + BO->setHasNoUnsignedWrap(true); + Changed = true; + } + } + + return Changed; +} + /// \brief Split sadd.with.overflow into add + sadd.with.overflow to allow /// analysis and optimization. /// @@ -334,8 +449,7 @@ Instruction *SimplifyIndvar::splitOverflowIntrinsic(Instruction *IVUser, return AddInst; } -/// pushIVUsers - Add all uses of Def to the current IV's worklist. -/// +/// Add all uses of Def to the current IV's worklist. static void pushIVUsers( Instruction *Def, SmallPtrSet<Instruction*,16> &Simplified, @@ -348,12 +462,12 @@ static void pushIVUsers( // Also ensure unique worklist users. // If Def is a LoopPhi, it may not be in the Simplified set, so check for // self edges first. - if (UI != Def && Simplified.insert(UI)) + if (UI != Def && Simplified.insert(UI).second) SimpleIVUsers.push_back(std::make_pair(UI, Def)); } } -/// isSimpleIVUser - Return true if this instruction generates a simple SCEV +/// Return true if this instruction generates a simple SCEV /// expression in terms of that IV. /// /// This is similar to IVUsers' isInteresting() but processes each instruction @@ -374,7 +488,7 @@ static bool isSimpleIVUser(Instruction *I, const Loop *L, ScalarEvolution *SE) { return false; } -/// simplifyUsers - Iteratively perform simplification on a worklist of users +/// Iteratively perform simplification on a worklist of users /// of the specified induction variable. Each successive simplification may push /// more users which may themselves be candidates for simplification. /// @@ -431,6 +545,16 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) { pushIVUsers(IVOperand, Simplified, SimpleIVUsers); continue; } + + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(UseOper.first)) { + if (isa<OverflowingBinaryOperator>(BO) && + strengthenOverflowingOperation(BO, IVOperand)) { + // re-queue uses of the now modified binary operator and fall + // through to the checks that remain. + pushIVUsers(IVOperand, Simplified, SimpleIVUsers); + } + } + CastInst *Cast = dyn_cast<CastInst>(UseOper.first); if (V && Cast) { V->visitCast(Cast); @@ -446,7 +570,7 @@ namespace llvm { void IVVisitor::anchor() { } -/// simplifyUsersOfIV - Simplify instructions that use this induction variable +/// Simplify instructions that use this induction variable /// by using ScalarEvolution to analyze the IV's recurrence. bool simplifyUsersOfIV(PHINode *CurrIV, ScalarEvolution *SE, LPPassManager *LPM, SmallVectorImpl<WeakVH> &Dead, IVVisitor *V) @@ -457,7 +581,7 @@ bool simplifyUsersOfIV(PHINode *CurrIV, ScalarEvolution *SE, LPPassManager *LPM, return SIV.hasChanged(); } -/// simplifyLoopIVs - Simplify users of induction variables within this +/// Simplify users of induction variables within this /// loop. This does not actually change or add IVs. 
bool simplifyLoopIVs(Loop *L, ScalarEvolution *SE, LPPassManager *LPM, SmallVectorImpl<WeakVH> &Dead) { diff --git a/lib/Transforms/Utils/SimplifyInstructions.cpp b/lib/Transforms/Utils/SimplifyInstructions.cpp index 33b36378027d..cc97098d010a 100644 --- a/lib/Transforms/Utils/SimplifyInstructions.cpp +++ b/lib/Transforms/Utils/SimplifyInstructions.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" @@ -41,6 +42,7 @@ namespace { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); + AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetLibraryInfo>(); } @@ -52,6 +54,8 @@ namespace { DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); const DataLayout *DL = DLP ? &DLP->getDataLayout() : nullptr; const TargetLibraryInfo *TLI = &getAnalysis<TargetLibraryInfo>(); + AssumptionCache *AC = + &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); SmallPtrSet<const Instruction*, 8> S1, S2, *ToSimplify = &S1, *Next = &S2; bool Changed = false; @@ -68,7 +72,7 @@ namespace { continue; // Don't waste time simplifying unused instructions. if (!I->use_empty()) - if (Value *V = SimplifyInstruction(I, DL, TLI, DT)) { + if (Value *V = SimplifyInstruction(I, DL, TLI, DT, AC)) { // Mark all uses for resimplification next time round the loop. for (User *U : I->users()) Next->insert(cast<Instruction>(U)); @@ -101,6 +105,7 @@ namespace { char InstSimplifier::ID = 0; INITIALIZE_PASS_BEGIN(InstSimplifier, "instsimplify", "Remove redundant instructions", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) INITIALIZE_PASS_END(InstSimplifier, "instsimplify", "Remove redundant instructions", false, false) diff --git a/lib/Transforms/Utils/SimplifyLibCalls.cpp b/lib/Transforms/Utils/SimplifyLibCalls.cpp index 3b61bb575a8d..5b4647ddcb5e 100644 --- a/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -27,65 +27,43 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/CommandLine.h" #include "llvm/Target/TargetLibraryInfo.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" using namespace llvm; +using namespace PatternMatch; static cl::opt<bool> -ColdErrorCalls("error-reporting-is-cold", cl::init(true), - cl::Hidden, cl::desc("Treat error-reporting calls as cold")); - -/// This class is the abstract base class for the set of optimizations that -/// corresponds to one library call. -namespace { -class LibCallOptimization { -protected: - Function *Caller; - const DataLayout *DL; - const TargetLibraryInfo *TLI; - const LibCallSimplifier *LCS; - LLVMContext* Context; -public: - LibCallOptimization() { } - virtual ~LibCallOptimization() {} - - /// callOptimizer - This pure virtual method is implemented by base classes to - /// do various optimizations. If this returns null then no transformation was - /// performed. If it returns CI, then it transformed the call and CI is to be - /// deleted. If it returns something else, replace CI with the new value and - /// delete CI. 
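// The InstSimplifier loop in the SimplifyInstructions.cpp hunk above uses a
// two-set worklist: each value that simplifies queues its users into 'Next',
// then the sets swap roles for the following round. A minimal generic sketch
// of that ping-pong (names hypothetical, demo only, not part of the patch):
#include <set>
#include <utility>
#include <vector>
template <typename T, typename SimplifyFn, typename UsersFn>
void simplifyToFixpoint(std::vector<T> Seed, SimplifyFn TrySimplify,
                        UsersFn UsersOf) {
  std::set<T> S1(Seed.begin(), Seed.end()), S2;
  std::set<T> *ToSimplify = &S1, *Next = &S2;
  while (!ToSimplify->empty()) {
    for (const T &V : *ToSimplify)
      if (TrySimplify(V))            // V folded; its users may fold next round
        for (const T &U : UsersOf(V))
          Next->insert(U);
    ToSimplify->clear();
    std::swap(ToSimplify, Next);     // reuse the cleared set for round n+1
  }
}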
- virtual Value *callOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) - =0; - - /// ignoreCallingConv - Returns false if this transformation could possibly - /// change the calling convention. - virtual bool ignoreCallingConv() { return false; } - - Value *optimizeCall(CallInst *CI, const DataLayout *DL, - const TargetLibraryInfo *TLI, - const LibCallSimplifier *LCS, IRBuilder<> &B) { - Caller = CI->getParent()->getParent(); - this->DL = DL; - this->TLI = TLI; - this->LCS = LCS; - if (CI->getCalledFunction()) - Context = &CI->getCalledFunction()->getContext(); + ColdErrorCalls("error-reporting-is-cold", cl::init(true), cl::Hidden, + cl::desc("Treat error-reporting calls as cold")); - // We never change the calling convention. - if (!ignoreCallingConv() && CI->getCallingConv() != llvm::CallingConv::C) - return nullptr; +static cl::opt<bool> + EnableUnsafeFPShrink("enable-double-float-shrink", cl::Hidden, + cl::init(false), + cl::desc("Enable unsafe double to float " + "shrinking for math lib calls")); - return callOptimizer(CI->getCalledFunction(), CI, B); - } -}; //===----------------------------------------------------------------------===// // Helper Functions //===----------------------------------------------------------------------===// +static bool ignoreCallingConv(LibFunc::Func Func) { + switch (Func) { + case LibFunc::abs: + case LibFunc::labs: + case LibFunc::llabs: + case LibFunc::strlen: + return true; + default: + return false; + } + llvm_unreachable("All cases should be covered in the switch."); +} + /// isOnlyUsedInZeroEqualityComparison - Return true if it only matters that the /// value is equal or not-equal to zero. static bool isOnlyUsedInZeroEqualityComparison(Value *V) { @@ -138,1908 +116,1739 @@ static bool hasUnaryFloatFn(const TargetLibraryInfo *TLI, Type *Ty, } } -//===----------------------------------------------------------------------===// -// Fortified Library Call Optimizations -//===----------------------------------------------------------------------===// - -struct FortifiedLibCallOptimization : public LibCallOptimization { -protected: - virtual bool isFoldable(unsigned SizeCIOp, unsigned SizeArgOp, - bool isString) const = 0; -}; - -struct InstFortifiedLibCallOptimization : public FortifiedLibCallOptimization { - CallInst *CI; - - bool isFoldable(unsigned SizeCIOp, unsigned SizeArgOp, - bool isString) const override { - if (CI->getArgOperand(SizeCIOp) == CI->getArgOperand(SizeArgOp)) - return true; - if (ConstantInt *SizeCI = - dyn_cast<ConstantInt>(CI->getArgOperand(SizeCIOp))) { - if (SizeCI->isAllOnesValue()) - return true; - if (isString) { - uint64_t Len = GetStringLength(CI->getArgOperand(SizeArgOp)); - // If the length is 0 we don't know how long it is and so we can't - // remove the check. - if (Len == 0) return false; - return SizeCI->getZExtValue() >= Len; - } - if (ConstantInt *Arg = dyn_cast<ConstantInt>( - CI->getArgOperand(SizeArgOp))) - return SizeCI->getZExtValue() >= Arg->getZExtValue(); - } +/// \brief Returns whether \p F matches the signature expected for the +/// string/memory copying library function \p Func. +/// Acceptable functions are st[rp][n]?cpy, memove, memcpy, and memset. +/// Their fortified (_chk) counterparts are also accepted. +static bool checkStringCopyLibFuncSignature(Function *F, LibFunc::Func Func, + const DataLayout *DL) { + FunctionType *FT = F->getFunctionType(); + LLVMContext &Context = F->getContext(); + Type *PCharTy = Type::getInt8PtrTy(Context); + Type *SizeTTy = DL ? 
DL->getIntPtrType(Context) : nullptr; + unsigned NumParams = FT->getNumParams(); + + // All string libfuncs return the same type as the first parameter. + if (FT->getReturnType() != FT->getParamType(0)) return false; - } -}; - -struct MemCpyChkOpt : public InstFortifiedLibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - this->CI = CI; - FunctionType *FT = Callee->getFunctionType(); - LLVMContext &Context = CI->getParent()->getContext(); - - // Check if this has the right signature. - if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - FT->getParamType(2) != DL->getIntPtrType(Context) || - FT->getParamType(3) != DL->getIntPtrType(Context)) - return nullptr; - if (isFoldable(3, 2, false)) { - B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), 1); - return CI->getArgOperand(0); - } - return nullptr; + switch (Func) { + default: + llvm_unreachable("Can't check signature for non-string-copy libfunc."); + case LibFunc::stpncpy_chk: + case LibFunc::strncpy_chk: + --NumParams; // fallthrough + case LibFunc::stpncpy: + case LibFunc::strncpy: { + if (NumParams != 3 || FT->getParamType(0) != FT->getParamType(1) || + FT->getParamType(0) != PCharTy || !FT->getParamType(2)->isIntegerTy()) + return false; + break; + } + case LibFunc::strcpy_chk: + case LibFunc::stpcpy_chk: + --NumParams; // fallthrough + case LibFunc::stpcpy: + case LibFunc::strcpy: { + if (NumParams != 2 || FT->getParamType(0) != FT->getParamType(1) || + FT->getParamType(0) != PCharTy) + return false; + break; + } + case LibFunc::memmove_chk: + case LibFunc::memcpy_chk: + --NumParams; // fallthrough + case LibFunc::memmove: + case LibFunc::memcpy: { + if (NumParams != 3 || !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy() || FT->getParamType(2) != SizeTTy) + return false; + break; } -}; - -struct MemMoveChkOpt : public InstFortifiedLibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - this->CI = CI; - FunctionType *FT = Callee->getFunctionType(); - LLVMContext &Context = CI->getParent()->getContext(); + case LibFunc::memset_chk: + --NumParams; // fallthrough + case LibFunc::memset: { + if (NumParams != 3 || !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isIntegerTy() || FT->getParamType(2) != SizeTTy) + return false; + break; + } + } + // If this is a fortified libcall, the last parameter is a size_t. + if (NumParams == FT->getNumParams() - 1) + return FT->getParamType(FT->getNumParams() - 1) == SizeTTy; + return true; +} - // Check if this has the right signature. 
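// The checkStringCopyLibFuncSignature switch above handles each fortified
// (_chk) variant by peeling its trailing size_t parameter: decrement the
// visible count, fall through to the plain function's check, then confirm
// the peeled slot really was size_t. A minimal sketch of that shape
// (hypothetical enum and flags, demo only, not part of the patch):
enum class CopyFn { Memcpy, MemcpyChk };
inline bool hasCopyShape(int NumParams, bool LastIsSizeT, CopyFn F) {
  int N = NumParams;
  switch (F) {
  case CopyFn::MemcpyChk:
    --N; // peel the object-size argument
    [[fallthrough]];
  case CopyFn::Memcpy:
    if (N != 3) // dest, src, len
      return false;
    break;
  }
  // A fortified variant carries exactly one extra parameter: it must be
  // size_t; the plain variant passes unconditionally at this point.
  return N == NumParams || LastIsSizeT;
}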
- if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - FT->getParamType(2) != DL->getIntPtrType(Context) || - FT->getParamType(3) != DL->getIntPtrType(Context)) - return nullptr; +//===----------------------------------------------------------------------===// +// String and Memory Library Call Optimizations +//===----------------------------------------------------------------------===// - if (isFoldable(3, 2, false)) { - B.CreateMemMove(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), 1); - return CI->getArgOperand(0); - } +Value *LibCallSimplifier::optimizeStrCat(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Verify the "strcat" function prototype. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2|| + FT->getReturnType() != B.getInt8PtrTy() || + FT->getParamType(0) != FT->getReturnType() || + FT->getParamType(1) != FT->getReturnType()) return nullptr; - } -}; -struct MemSetChkOpt : public InstFortifiedLibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - this->CI = CI; - FunctionType *FT = Callee->getFunctionType(); - LLVMContext &Context = CI->getParent()->getContext(); - - // Check if this has the right signature. - if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isIntegerTy() || - FT->getParamType(2) != DL->getIntPtrType(Context) || - FT->getParamType(3) != DL->getIntPtrType(Context)) - return nullptr; + // Extract some information from the instruction + Value *Dst = CI->getArgOperand(0); + Value *Src = CI->getArgOperand(1); - if (isFoldable(3, 2, false)) { - Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), - false); - B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1); - return CI->getArgOperand(0); - } + // See if we can get the length of the input string. + uint64_t Len = GetStringLength(Src); + if (Len == 0) return nullptr; - } -}; - -struct StrCpyChkOpt : public InstFortifiedLibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - this->CI = CI; - StringRef Name = Callee->getName(); - FunctionType *FT = Callee->getFunctionType(); - LLVMContext &Context = CI->getParent()->getContext(); + --Len; // Unbias length. - // Check if this has the right signature. - if (FT->getNumParams() != 3 || - FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != Type::getInt8PtrTy(Context) || - FT->getParamType(2) != DL->getIntPtrType(Context)) - return nullptr; + // Handle the simple, do-nothing case: strcat(x, "") -> x + if (Len == 0) + return Dst; - Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); - if (Dst == Src) // __strcpy_chk(x,x) -> x - return Src; - - // If a) we don't have any length information, or b) we know this will - // fit then just lower to a plain strcpy. Otherwise we'll keep our - // strcpy_chk call which may fail at runtime if the size is too long. - // TODO: It might be nice to get a maximum length out of the possible - // string lengths for varying. - if (isFoldable(2, 1, true)) { - Value *Ret = EmitStrCpy(Dst, Src, B, DL, TLI, Name.substr(2, 6)); - return Ret; - } else { - // Maybe we can stil fold __strcpy_chk to __memcpy_chk. 
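// The semantics behind the _chk folds above, sketched with hypothetical
// names (demo only, not part of the patch; __builtin_trap stands in for the
// real __chk_fail path). Once the source length is a known constant that
// provably fits, the size check disappears and the copy becomes a
// fixed-size memcpy with the nul byte included:
#include <cstring>
inline char *checkedStrcpyDemo(char *Dst, const char *Src, size_t DstSize) {
  size_t Len = std::strlen(Src); // a compile-time constant for literals
  if (Len + 1 > DstSize)         // the check the fold tries to eliminate
    __builtin_trap();
  std::memcpy(Dst, Src, Len + 1);
  return Dst;
}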
- uint64_t Len = GetStringLength(Src); - if (Len == 0) return nullptr; - - // This optimization require DataLayout. - if (!DL) return nullptr; - - Value *Ret = - EmitMemCpyChk(Dst, Src, - ConstantInt::get(DL->getIntPtrType(Context), Len), - CI->getArgOperand(2), B, DL, TLI); - return Ret; - } + // These optimizations require DataLayout. + if (!DL) return nullptr; - } -}; -struct StpCpyChkOpt : public InstFortifiedLibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - this->CI = CI; - StringRef Name = Callee->getName(); - FunctionType *FT = Callee->getFunctionType(); - LLVMContext &Context = CI->getParent()->getContext(); + return emitStrLenMemCpy(Src, Dst, Len, B); +} - // Check if this has the right signature. - if (FT->getNumParams() != 3 || - FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != Type::getInt8PtrTy(Context) || - FT->getParamType(2) != DL->getIntPtrType(FT->getParamType(0))) - return nullptr; +Value *LibCallSimplifier::emitStrLenMemCpy(Value *Src, Value *Dst, uint64_t Len, + IRBuilder<> &B) { + // We need to find the end of the destination string. That's where the + // memory is to be moved to. We just generate a call to strlen. + Value *DstLen = EmitStrLen(Dst, B, DL, TLI); + if (!DstLen) + return nullptr; - Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); - if (Dst == Src) { // stpcpy(x,x) -> x+strlen(x) - Value *StrLen = EmitStrLen(Src, B, DL, TLI); - return StrLen ? B.CreateInBoundsGEP(Dst, StrLen) : nullptr; - } + // Now that we have the destination's length, we must index into the + // destination's pointer to get the actual memcpy destination (end of + // the string .. we're concatenating). + Value *CpyDst = B.CreateGEP(Dst, DstLen, "endptr"); + + // We have enough information to now generate the memcpy call to do the + // concatenation for us. Make a memcpy to copy the nul byte with align = 1. + B.CreateMemCpy( + CpyDst, Src, + ConstantInt::get(DL->getIntPtrType(Src->getContext()), Len + 1), 1); + return Dst; +} - // If a) we don't have any length information, or b) we know this will - // fit then just lower to a plain stpcpy. Otherwise we'll keep our - // stpcpy_chk call which may fail at runtime if the size is too long. - // TODO: It might be nice to get a maximum length out of the possible - // string lengths for varying. - if (isFoldable(2, 1, true)) { - Value *Ret = EmitStrCpy(Dst, Src, B, DL, TLI, Name.substr(2, 6)); - return Ret; - } else { - // Maybe we can stil fold __stpcpy_chk to __memcpy_chk. - uint64_t Len = GetStringLength(Src); - if (Len == 0) return nullptr; - - // This optimization require DataLayout. - if (!DL) return nullptr; - - Type *PT = FT->getParamType(0); - Value *LenV = ConstantInt::get(DL->getIntPtrType(PT), Len); - Value *DstEnd = B.CreateGEP(Dst, - ConstantInt::get(DL->getIntPtrType(PT), - Len - 1)); - if (!EmitMemCpyChk(Dst, Src, LenV, CI->getArgOperand(2), B, DL, TLI)) - return nullptr; - return DstEnd; - } +Value *LibCallSimplifier::optimizeStrNCat(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Verify the "strncat" function prototype. 
+ FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 3 || FT->getReturnType() != B.getInt8PtrTy() || + FT->getParamType(0) != FT->getReturnType() || + FT->getParamType(1) != FT->getReturnType() || + !FT->getParamType(2)->isIntegerTy()) return nullptr; - } -}; - -struct StrNCpyChkOpt : public InstFortifiedLibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - this->CI = CI; - StringRef Name = Callee->getName(); - FunctionType *FT = Callee->getFunctionType(); - LLVMContext &Context = CI->getParent()->getContext(); - // Check if this has the right signature. - if (FT->getNumParams() != 4 || FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != Type::getInt8PtrTy(Context) || - !FT->getParamType(2)->isIntegerTy() || - FT->getParamType(3) != DL->getIntPtrType(Context)) - return nullptr; + // Extract some information from the instruction + Value *Dst = CI->getArgOperand(0); + Value *Src = CI->getArgOperand(1); + uint64_t Len; - if (isFoldable(3, 2, false)) { - Value *Ret = EmitStrNCpy(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), B, DL, TLI, - Name.substr(2, 7)); - return Ret; - } + // We don't do anything if length is not constant + if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(CI->getArgOperand(2))) + Len = LengthArg->getZExtValue(); + else return nullptr; - } -}; -//===----------------------------------------------------------------------===// -// String and Memory Library Call Optimizations -//===----------------------------------------------------------------------===// + // See if we can get the length of the input string. + uint64_t SrcLen = GetStringLength(Src); + if (SrcLen == 0) + return nullptr; + --SrcLen; // Unbias length. -struct StrCatOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "strcat" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getReturnType() != B.getInt8PtrTy() || - FT->getParamType(0) != FT->getReturnType() || - FT->getParamType(1) != FT->getReturnType()) - return nullptr; + // Handle the simple, do-nothing cases: + // strncat(x, "", c) -> x + // strncat(x, c, 0) -> x + if (SrcLen == 0 || Len == 0) + return Dst; - // Extract some information from the instruction - Value *Dst = CI->getArgOperand(0); - Value *Src = CI->getArgOperand(1); + // These optimizations require DataLayout. + if (!DL) + return nullptr; - // See if we can get the length of the input string. - uint64_t Len = GetStringLength(Src); - if (Len == 0) return nullptr; - --Len; // Unbias length. + // We don't optimize this case + if (Len < SrcLen) + return nullptr; - // Handle the simple, do-nothing case: strcat(x, "") -> x - if (Len == 0) - return Dst; + // strncat(x, s, c) -> strcat(x, s) + // s is constant so the strcat can be optimized further + return emitStrLenMemCpy(Src, Dst, SrcLen, B); +} - // These optimizations require DataLayout. - if (!DL) return nullptr; +Value *LibCallSimplifier::optimizeStrChr(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Verify the "strchr" function prototype. 
+ FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || FT->getReturnType() != B.getInt8PtrTy() || + FT->getParamType(0) != FT->getReturnType() || + !FT->getParamType(1)->isIntegerTy(32)) + return nullptr; - return emitStrLenMemCpy(Src, Dst, Len, B); - } + Value *SrcStr = CI->getArgOperand(0); - Value *emitStrLenMemCpy(Value *Src, Value *Dst, uint64_t Len, - IRBuilder<> &B) { - // We need to find the end of the destination string. That's where the - // memory is to be moved to. We just generate a call to strlen. - Value *DstLen = EmitStrLen(Dst, B, DL, TLI); - if (!DstLen) + // If the second operand is non-constant, see if we can compute the length + // of the input string and turn this into memchr. + ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); + if (!CharC) { + // These optimizations require DataLayout. + if (!DL) return nullptr; - // Now that we have the destination's length, we must index into the - // destination's pointer to get the actual memcpy destination (end of - // the string .. we're concatenating). - Value *CpyDst = B.CreateGEP(Dst, DstLen, "endptr"); + uint64_t Len = GetStringLength(SrcStr); + if (Len == 0 || !FT->getParamType(1)->isIntegerTy(32)) // memchr needs i32. + return nullptr; - // We have enough information to now generate the memcpy call to do the - // concatenation for us. Make a memcpy to copy the nul byte with align = 1. - B.CreateMemCpy(CpyDst, Src, - ConstantInt::get(DL->getIntPtrType(*Context), Len + 1), 1); - return Dst; + return EmitMemChr( + SrcStr, CI->getArgOperand(1), // include nul. + ConstantInt::get(DL->getIntPtrType(CI->getContext()), Len), B, DL, TLI); } -}; - -struct StrNCatOpt : public StrCatOpt { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "strncat" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || - FT->getReturnType() != B.getInt8PtrTy() || - FT->getParamType(0) != FT->getReturnType() || - FT->getParamType(1) != FT->getReturnType() || - !FT->getParamType(2)->isIntegerTy()) - return nullptr; - // Extract some information from the instruction - Value *Dst = CI->getArgOperand(0); - Value *Src = CI->getArgOperand(1); - uint64_t Len; + // Otherwise, the character is a constant, see if the first argument is + // a string literal. If so, we can constant fold. + StringRef Str; + if (!getConstantStringInfo(SrcStr, Str)) { + if (DL && CharC->isZero()) // strchr(p, 0) -> p + strlen(p) + return B.CreateGEP(SrcStr, EmitStrLen(SrcStr, B, DL, TLI), "strchr"); + return nullptr; + } - // We don't do anything if length is not constant - if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(CI->getArgOperand(2))) - Len = LengthArg->getZExtValue(); - else - return nullptr; + // Compute the offset, make sure to handle the case when we're searching for + // zero (a weird way to spell strlen). + size_t I = (0xFF & CharC->getSExtValue()) == 0 + ? Str.size() + : Str.find(CharC->getSExtValue()); + if (I == StringRef::npos) // Didn't find the char. strchr returns null. + return Constant::getNullValue(CI->getType()); - // See if we can get the length of the input string. - uint64_t SrcLen = GetStringLength(Src); - if (SrcLen == 0) return nullptr; - --SrcLen; // Unbias length. 
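// What emitStrLenMemCpy above produces, in source form (demo helper, not
// part of the patch): find the end of Dst with strlen, then copy SrcLen + 1
// bytes so the nul terminator moves too.
#include <cstring>
inline char *strcatKnownLenDemo(char *Dst, const char *Src, size_t SrcLen) {
  char *End = Dst + std::strlen(Dst); // EmitStrLen: the insertion point
  std::memcpy(End, Src, SrcLen + 1);  // fixed-size copy, terminator included
  return Dst;                         // strcat returns its first argument
}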
+ // strchr(s+n,c) -> gep(s+n+i,c) + return B.CreateGEP(SrcStr, B.getInt64(I), "strchr"); +} - // Handle the simple, do-nothing cases: - // strncat(x, "", c) -> x - // strncat(x, c, 0) -> x - if (SrcLen == 0 || Len == 0) return Dst; +Value *LibCallSimplifier::optimizeStrRChr(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Verify the "strrchr" function prototype. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || FT->getReturnType() != B.getInt8PtrTy() || + FT->getParamType(0) != FT->getReturnType() || + !FT->getParamType(1)->isIntegerTy(32)) + return nullptr; - // These optimizations require DataLayout. - if (!DL) return nullptr; - - // We don't optimize this case - if (Len < SrcLen) return nullptr; - - // strncat(x, s, c) -> strcat(x, s) - // s is constant so the strcat can be optimized further - return emitStrLenMemCpy(Src, Dst, SrcLen, B); - } -}; - -struct StrChrOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "strchr" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getReturnType() != B.getInt8PtrTy() || - FT->getParamType(0) != FT->getReturnType() || - !FT->getParamType(1)->isIntegerTy(32)) - return nullptr; + Value *SrcStr = CI->getArgOperand(0); + ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); - Value *SrcStr = CI->getArgOperand(0); + // Cannot fold anything if we're not looking for a constant. + if (!CharC) + return nullptr; - // If the second operand is non-constant, see if we can compute the length - // of the input string and turn this into memchr. - ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); - if (!CharC) { - // These optimizations require DataLayout. - if (!DL) return nullptr; + StringRef Str; + if (!getConstantStringInfo(SrcStr, Str)) { + // strrchr(s, 0) -> strchr(s, 0) + if (DL && CharC->isZero()) + return EmitStrChr(SrcStr, '\0', B, DL, TLI); + return nullptr; + } - uint64_t Len = GetStringLength(SrcStr); - if (Len == 0 || !FT->getParamType(1)->isIntegerTy(32))// memchr needs i32. - return nullptr; + // Compute the offset. + size_t I = (0xFF & CharC->getSExtValue()) == 0 + ? Str.size() + : Str.rfind(CharC->getSExtValue()); + if (I == StringRef::npos) // Didn't find the char. Return null. + return Constant::getNullValue(CI->getType()); - return EmitMemChr(SrcStr, CI->getArgOperand(1), // include nul. - ConstantInt::get(DL->getIntPtrType(*Context), Len), - B, DL, TLI); - } + // strrchr(s+n,c) -> gep(s+n+i,c) + return B.CreateGEP(SrcStr, B.getInt64(I), "strrchr"); +} - // Otherwise, the character is a constant, see if the first argument is - // a string literal. If so, we can constant fold. - StringRef Str; - if (!getConstantStringInfo(SrcStr, Str)) { - if (DL && CharC->isZero()) // strchr(p, 0) -> p + strlen(p) - return B.CreateGEP(SrcStr, EmitStrLen(SrcStr, B, DL, TLI), "strchr"); - return nullptr; - } +Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Verify the "strcmp" function prototype. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || !FT->getReturnType()->isIntegerTy(32) || + FT->getParamType(0) != FT->getParamType(1) || + FT->getParamType(0) != B.getInt8PtrTy()) + return nullptr; - // Compute the offset, make sure to handle the case when we're searching for - // zero (a weird way to spell strlen). 
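// A worked instance of the folds above (runtime asserts, demo only, not
// part of the patch): with both arguments known, strchr collapses to
// pointer arithmetic, and searching for '\0' is strlen in disguise.
#include <cassert>
#include <cstring>
inline void strchrFoldDemo() {
  static const char S[] = "abcd";
  assert(std::strchr(S, 'c') == S + 2);               // -> gep(S, 2)
  assert(std::strchr(S, '\0') == S + std::strlen(S)); // -> S + strlen(S)
}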
- size_t I = (0xFF & CharC->getSExtValue()) == 0 ? - Str.size() : Str.find(CharC->getSExtValue()); - if (I == StringRef::npos) // Didn't find the char. strchr returns null. - return Constant::getNullValue(CI->getType()); + Value *Str1P = CI->getArgOperand(0), *Str2P = CI->getArgOperand(1); + if (Str1P == Str2P) // strcmp(x,x) -> 0 + return ConstantInt::get(CI->getType(), 0); - // strchr(s+n,c) -> gep(s+n+i,c) - return B.CreateGEP(SrcStr, B.getInt64(I), "strchr"); - } -}; + StringRef Str1, Str2; + bool HasStr1 = getConstantStringInfo(Str1P, Str1); + bool HasStr2 = getConstantStringInfo(Str2P, Str2); -struct StrRChrOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "strrchr" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getReturnType() != B.getInt8PtrTy() || - FT->getParamType(0) != FT->getReturnType() || - !FT->getParamType(1)->isIntegerTy(32)) - return nullptr; + // strcmp(x, y) -> cnst (if both x and y are constant strings) + if (HasStr1 && HasStr2) + return ConstantInt::get(CI->getType(), Str1.compare(Str2)); - Value *SrcStr = CI->getArgOperand(0); - ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); + if (HasStr1 && Str1.empty()) // strcmp("", x) -> -*x + return B.CreateNeg( + B.CreateZExt(B.CreateLoad(Str2P, "strcmpload"), CI->getType())); - // Cannot fold anything if we're not looking for a constant. - if (!CharC) - return nullptr; + if (HasStr2 && Str2.empty()) // strcmp(x,"") -> *x + return B.CreateZExt(B.CreateLoad(Str1P, "strcmpload"), CI->getType()); - StringRef Str; - if (!getConstantStringInfo(SrcStr, Str)) { - // strrchr(s, 0) -> strchr(s, 0) - if (DL && CharC->isZero()) - return EmitStrChr(SrcStr, '\0', B, DL, TLI); + // strcmp(P, "x") -> memcmp(P, "x", 2) + uint64_t Len1 = GetStringLength(Str1P); + uint64_t Len2 = GetStringLength(Str2P); + if (Len1 && Len2) { + // These optimizations require DataLayout. + if (!DL) return nullptr; - } - - // Compute the offset. - size_t I = (0xFF & CharC->getSExtValue()) == 0 ? - Str.size() : Str.rfind(CharC->getSExtValue()); - if (I == StringRef::npos) // Didn't find the char. Return null. - return Constant::getNullValue(CI->getType()); - // strrchr(s+n,c) -> gep(s+n+i,c) - return B.CreateGEP(SrcStr, B.getInt64(I), "strrchr"); + return EmitMemCmp(Str1P, Str2P, + ConstantInt::get(DL->getIntPtrType(CI->getContext()), + std::min(Len1, Len2)), + B, DL, TLI); } -}; - -struct StrCmpOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "strcmp" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - !FT->getReturnType()->isIntegerTy(32) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != B.getInt8PtrTy()) - return nullptr; - Value *Str1P = CI->getArgOperand(0), *Str2P = CI->getArgOperand(1); - if (Str1P == Str2P) // strcmp(x,x) -> 0 - return ConstantInt::get(CI->getType(), 0); + return nullptr; +} - StringRef Str1, Str2; - bool HasStr1 = getConstantStringInfo(Str1P, Str1); - bool HasStr2 = getConstantStringInfo(Str2P, Str2); +Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Verify the "strncmp" function prototype. 
+ FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 3 || !FT->getReturnType()->isIntegerTy(32) || + FT->getParamType(0) != FT->getParamType(1) || + FT->getParamType(0) != B.getInt8PtrTy() || + !FT->getParamType(2)->isIntegerTy()) + return nullptr; - // strcmp(x, y) -> cnst (if both x and y are constant strings) - if (HasStr1 && HasStr2) - return ConstantInt::get(CI->getType(), Str1.compare(Str2)); + Value *Str1P = CI->getArgOperand(0), *Str2P = CI->getArgOperand(1); + if (Str1P == Str2P) // strncmp(x,x,n) -> 0 + return ConstantInt::get(CI->getType(), 0); - if (HasStr1 && Str1.empty()) // strcmp("", x) -> -*x - return B.CreateNeg(B.CreateZExt(B.CreateLoad(Str2P, "strcmpload"), - CI->getType())); + // Get the length argument if it is constant. + uint64_t Length; + if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(CI->getArgOperand(2))) + Length = LengthArg->getZExtValue(); + else + return nullptr; - if (HasStr2 && Str2.empty()) // strcmp(x,"") -> *x - return B.CreateZExt(B.CreateLoad(Str1P, "strcmpload"), CI->getType()); + if (Length == 0) // strncmp(x,y,0) -> 0 + return ConstantInt::get(CI->getType(), 0); - // strcmp(P, "x") -> memcmp(P, "x", 2) - uint64_t Len1 = GetStringLength(Str1P); - uint64_t Len2 = GetStringLength(Str2P); - if (Len1 && Len2) { - // These optimizations require DataLayout. - if (!DL) return nullptr; + if (DL && Length == 1) // strncmp(x,y,1) -> memcmp(x,y,1) + return EmitMemCmp(Str1P, Str2P, CI->getArgOperand(2), B, DL, TLI); - return EmitMemCmp(Str1P, Str2P, - ConstantInt::get(DL->getIntPtrType(*Context), - std::min(Len1, Len2)), B, DL, TLI); - } + StringRef Str1, Str2; + bool HasStr1 = getConstantStringInfo(Str1P, Str1); + bool HasStr2 = getConstantStringInfo(Str2P, Str2); - return nullptr; + // strncmp(x, y) -> cnst (if both x and y are constant strings) + if (HasStr1 && HasStr2) { + StringRef SubStr1 = Str1.substr(0, Length); + StringRef SubStr2 = Str2.substr(0, Length); + return ConstantInt::get(CI->getType(), SubStr1.compare(SubStr2)); } -}; -struct StrNCmpOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "strncmp" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || - !FT->getReturnType()->isIntegerTy(32) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != B.getInt8PtrTy() || - !FT->getParamType(2)->isIntegerTy()) - return nullptr; + if (HasStr1 && Str1.empty()) // strncmp("", x, n) -> -*x + return B.CreateNeg( + B.CreateZExt(B.CreateLoad(Str2P, "strcmpload"), CI->getType())); - Value *Str1P = CI->getArgOperand(0), *Str2P = CI->getArgOperand(1); - if (Str1P == Str2P) // strncmp(x,x,n) -> 0 - return ConstantInt::get(CI->getType(), 0); - - // Get the length argument if it is constant. 
- uint64_t Length; - if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(CI->getArgOperand(2))) - Length = LengthArg->getZExtValue(); - else - return nullptr; + if (HasStr2 && Str2.empty()) // strncmp(x, "", n) -> *x + return B.CreateZExt(B.CreateLoad(Str1P, "strcmpload"), CI->getType()); - if (Length == 0) // strncmp(x,y,0) -> 0 - return ConstantInt::get(CI->getType(), 0); - - if (DL && Length == 1) // strncmp(x,y,1) -> memcmp(x,y,1) - return EmitMemCmp(Str1P, Str2P, CI->getArgOperand(2), B, DL, TLI); + return nullptr; +} - StringRef Str1, Str2; - bool HasStr1 = getConstantStringInfo(Str1P, Str1); - bool HasStr2 = getConstantStringInfo(Str2P, Str2); +Value *LibCallSimplifier::optimizeStrCpy(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); - // strncmp(x, y) -> cnst (if both x and y are constant strings) - if (HasStr1 && HasStr2) { - StringRef SubStr1 = Str1.substr(0, Length); - StringRef SubStr2 = Str2.substr(0, Length); - return ConstantInt::get(CI->getType(), SubStr1.compare(SubStr2)); - } + if (!checkStringCopyLibFuncSignature(Callee, LibFunc::strcpy, DL)) + return nullptr; - if (HasStr1 && Str1.empty()) // strncmp("", x, n) -> -*x - return B.CreateNeg(B.CreateZExt(B.CreateLoad(Str2P, "strcmpload"), - CI->getType())); + Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); + if (Dst == Src) // strcpy(x,x) -> x + return Src; - if (HasStr2 && Str2.empty()) // strncmp(x, "", n) -> *x - return B.CreateZExt(B.CreateLoad(Str1P, "strcmpload"), CI->getType()); + // These optimizations require DataLayout. + if (!DL) + return nullptr; + // See if we can get the length of the input string. + uint64_t Len = GetStringLength(Src); + if (Len == 0) return nullptr; - } -}; -struct StrCpyOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "strcpy" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != B.getInt8PtrTy()) - return nullptr; + // We have enough information to now generate the memcpy call to do the + // copy for us. Make a memcpy to copy the nul byte with align = 1. + B.CreateMemCpy(Dst, Src, + ConstantInt::get(DL->getIntPtrType(CI->getContext()), Len), 1); + return Dst; +} - Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); - if (Dst == Src) // strcpy(x,x) -> x - return Src; +Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Verify the "stpcpy" function prototype. + FunctionType *FT = Callee->getFunctionType(); - // These optimizations require DataLayout. - if (!DL) return nullptr; + if (!checkStringCopyLibFuncSignature(Callee, LibFunc::stpcpy, DL)) + return nullptr; - // See if we can get the length of the input string. - uint64_t Len = GetStringLength(Src); - if (Len == 0) return nullptr; + // These optimizations require DataLayout. + if (!DL) + return nullptr; - // We have enough information to now generate the memcpy call to do the - // copy for us. Make a memcpy to copy the nul byte with align = 1. - B.CreateMemCpy(Dst, Src, - ConstantInt::get(DL->getIntPtrType(*Context), Len), 1); - return Dst; + Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); + if (Dst == Src) { // stpcpy(x,x) -> x+strlen(x) + Value *StrLen = EmitStrLen(Src, B, DL, TLI); + return StrLen ? 
B.CreateInBoundsGEP(Dst, StrLen) : nullptr; } -}; - -struct StpCpyOpt: public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Verify the "stpcpy" function prototype. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != B.getInt8PtrTy()) - return nullptr; - - // These optimizations require DataLayout. - if (!DL) return nullptr; - - Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1); - if (Dst == Src) { // stpcpy(x,x) -> x+strlen(x) - Value *StrLen = EmitStrLen(Src, B, DL, TLI); - return StrLen ? B.CreateInBoundsGEP(Dst, StrLen) : nullptr; - } - - // See if we can get the length of the input string. - uint64_t Len = GetStringLength(Src); - if (Len == 0) return nullptr; - - Type *PT = FT->getParamType(0); - Value *LenV = ConstantInt::get(DL->getIntPtrType(PT), Len); - Value *DstEnd = B.CreateGEP(Dst, - ConstantInt::get(DL->getIntPtrType(PT), - Len - 1)); - - // We have enough information to now generate the memcpy call to do the - // copy for us. Make a memcpy to copy the nul byte with align = 1. - B.CreateMemCpy(Dst, Src, LenV, 1); - return DstEnd; - } -}; - -struct StrNCpyOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - FT->getParamType(0) != B.getInt8PtrTy() || - !FT->getParamType(2)->isIntegerTy()) - return nullptr; - - Value *Dst = CI->getArgOperand(0); - Value *Src = CI->getArgOperand(1); - Value *LenOp = CI->getArgOperand(2); - // See if we can get the length of the input string. - uint64_t SrcLen = GetStringLength(Src); - if (SrcLen == 0) return nullptr; - --SrcLen; + // See if we can get the length of the input string. + uint64_t Len = GetStringLength(Src); + if (Len == 0) + return nullptr; - if (SrcLen == 0) { - // strncpy(x, "", y) -> memset(x, '\0', y, 1) - B.CreateMemSet(Dst, B.getInt8('\0'), LenOp, 1); - return Dst; - } + Type *PT = FT->getParamType(0); + Value *LenV = ConstantInt::get(DL->getIntPtrType(PT), Len); + Value *DstEnd = + B.CreateGEP(Dst, ConstantInt::get(DL->getIntPtrType(PT), Len - 1)); - uint64_t Len; - if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(LenOp)) - Len = LengthArg->getZExtValue(); - else - return nullptr; + // We have enough information to now generate the memcpy call to do the + // copy for us. Make a memcpy to copy the nul byte with align = 1. + B.CreateMemCpy(Dst, Src, LenV, 1); + return DstEnd; +} - if (Len == 0) return Dst; // strncpy(x, y, 0) -> x +Value *LibCallSimplifier::optimizeStrNCpy(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); - // These optimizations require DataLayout. 
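// The stpcpy lowering above in source form (demo helper, not part of the
// patch): with Len = strlen(Src) known, copy Len + 1 bytes and return the
// address of the copied terminator rather than Dst itself.
#include <cstring>
inline char *stpcpyKnownLenDemo(char *Dst, const char *Src, size_t Len) {
  std::memcpy(Dst, Src, Len + 1); // Len excludes the nul; copy it as well
  return Dst + Len;               // stpcpy returns the end of the string
}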
- if (!DL) return nullptr; + if (!checkStringCopyLibFuncSignature(Callee, LibFunc::strncpy, DL)) + return nullptr; - // Let strncpy handle the zero padding - if (Len > SrcLen+1) return nullptr; + Value *Dst = CI->getArgOperand(0); + Value *Src = CI->getArgOperand(1); + Value *LenOp = CI->getArgOperand(2); - Type *PT = FT->getParamType(0); - // strncpy(x, s, c) -> memcpy(x, s, c, 1) [s and c are constant] - B.CreateMemCpy(Dst, Src, - ConstantInt::get(DL->getIntPtrType(PT), Len), 1); + // See if we can get the length of the input string. + uint64_t SrcLen = GetStringLength(Src); + if (SrcLen == 0) + return nullptr; + --SrcLen; + if (SrcLen == 0) { + // strncpy(x, "", y) -> memset(x, '\0', y, 1) + B.CreateMemSet(Dst, B.getInt8('\0'), LenOp, 1); return Dst; } -}; - -struct StrLenOpt : public LibCallOptimization { - bool ignoreCallingConv() override { return true; } - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 1 || - FT->getParamType(0) != B.getInt8PtrTy() || - !FT->getReturnType()->isIntegerTy()) - return nullptr; - Value *Src = CI->getArgOperand(0); - - // Constant folding: strlen("xyz") -> 3 - if (uint64_t Len = GetStringLength(Src)) - return ConstantInt::get(CI->getType(), Len-1); - - // strlen(x?"foo":"bars") --> x ? 3 : 4 - if (SelectInst *SI = dyn_cast<SelectInst>(Src)) { - uint64_t LenTrue = GetStringLength(SI->getTrueValue()); - uint64_t LenFalse = GetStringLength(SI->getFalseValue()); - if (LenTrue && LenFalse) { - emitOptimizationRemark(*Context, "simplify-libcalls", *Caller, - SI->getDebugLoc(), - "folded strlen(select) to select of constants"); - return B.CreateSelect(SI->getCondition(), - ConstantInt::get(CI->getType(), LenTrue-1), - ConstantInt::get(CI->getType(), LenFalse-1)); - } - } + uint64_t Len; + if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(LenOp)) + Len = LengthArg->getZExtValue(); + else + return nullptr; - // strlen(x) != 0 --> *x != 0 - // strlen(x) == 0 --> *x == 0 - if (isOnlyUsedInZeroEqualityComparison(CI)) - return B.CreateZExt(B.CreateLoad(Src, "strlenfirst"), CI->getType()); + if (Len == 0) + return Dst; // strncpy(x, y, 0) -> x + // These optimizations require DataLayout. + if (!DL) return nullptr; - } -}; -struct StrPBrkOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getParamType(0) != B.getInt8PtrTy() || - FT->getParamType(1) != FT->getParamType(0) || - FT->getReturnType() != FT->getParamType(0)) - return nullptr; + // Let strncpy handle the zero padding + if (Len > SrcLen + 1) + return nullptr; - StringRef S1, S2; - bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1); - bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2); + Type *PT = FT->getParamType(0); + // strncpy(x, s, c) -> memcpy(x, s, c, 1) [s and c are constant] + B.CreateMemCpy(Dst, Src, ConstantInt::get(DL->getIntPtrType(PT), Len), 1); - // strpbrk(s, "") -> NULL - // strpbrk("", s) -> NULL - if ((HasS1 && S1.empty()) || (HasS2 && S2.empty())) - return Constant::getNullValue(CI->getType()); + return Dst; +} - // Constant folding. - if (HasS1 && HasS2) { - size_t I = S1.find_first_of(S2); - if (I == StringRef::npos) // No match. 
- return Constant::getNullValue(CI->getType()); +Value *LibCallSimplifier::optimizeStrLen(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 1 || FT->getParamType(0) != B.getInt8PtrTy() || + !FT->getReturnType()->isIntegerTy()) + return nullptr; - return B.CreateGEP(CI->getArgOperand(0), B.getInt64(I), "strpbrk"); + Value *Src = CI->getArgOperand(0); + + // Constant folding: strlen("xyz") -> 3 + if (uint64_t Len = GetStringLength(Src)) + return ConstantInt::get(CI->getType(), Len - 1); + + // strlen(x?"foo":"bars") --> x ? 3 : 4 + if (SelectInst *SI = dyn_cast<SelectInst>(Src)) { + uint64_t LenTrue = GetStringLength(SI->getTrueValue()); + uint64_t LenFalse = GetStringLength(SI->getFalseValue()); + if (LenTrue && LenFalse) { + Function *Caller = CI->getParent()->getParent(); + emitOptimizationRemark(CI->getContext(), "simplify-libcalls", *Caller, + SI->getDebugLoc(), + "folded strlen(select) to select of constants"); + return B.CreateSelect(SI->getCondition(), + ConstantInt::get(CI->getType(), LenTrue - 1), + ConstantInt::get(CI->getType(), LenFalse - 1)); } - - // strpbrk(s, "a") -> strchr(s, 'a') - if (DL && HasS2 && S2.size() == 1) - return EmitStrChr(CI->getArgOperand(0), S2[0], B, DL, TLI); - - return nullptr; } -}; -struct StrToOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if ((FT->getNumParams() != 2 && FT->getNumParams() != 3) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy()) - return nullptr; + // strlen(x) != 0 --> *x != 0 + // strlen(x) == 0 --> *x == 0 + if (isOnlyUsedInZeroEqualityComparison(CI)) + return B.CreateZExt(B.CreateLoad(Src, "strlenfirst"), CI->getType()); - Value *EndPtr = CI->getArgOperand(1); - if (isa<ConstantPointerNull>(EndPtr)) { - // With a null EndPtr, this function won't capture the main argument. - // It would be readonly too, except that it still may write to errno. - CI->addAttribute(1, Attribute::NoCapture); - } + return nullptr; +} +Value *LibCallSimplifier::optimizeStrPBrk(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || FT->getParamType(0) != B.getInt8PtrTy() || + FT->getParamType(1) != FT->getParamType(0) || + FT->getReturnType() != FT->getParamType(0)) return nullptr; - } -}; -struct StrSpnOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getParamType(0) != B.getInt8PtrTy() || - FT->getParamType(1) != FT->getParamType(0) || - !FT->getReturnType()->isIntegerTy()) - return nullptr; + StringRef S1, S2; + bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1); + bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2); - StringRef S1, S2; - bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1); - bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2); + // strpbrk(s, "") -> nullptr + // strpbrk("", s) -> nullptr + if ((HasS1 && S1.empty()) || (HasS2 && S2.empty())) + return Constant::getNullValue(CI->getType()); - // strspn(s, "") -> 0 - // strspn("", s) -> 0 - if ((HasS1 && S1.empty()) || (HasS2 && S2.empty())) + // Constant folding. 
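// Source-level shapes of the strlen folds just above (illustrative, not
// part of the patch; the comments show the rewrite the pass performs):
#include <cstring>
inline size_t strlenSelectDemo(bool X) {
  return std::strlen(X ? "foo" : "bars"); // --> X ? 3 : 4
}
inline bool strlenIsEmptyDemo(const char *P) {
  return std::strlen(P) == 0; // --> *P == '\0'; only one byte is loaded
}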
+ if (HasS1 && HasS2) { + size_t I = S1.find_first_of(S2); + if (I == StringRef::npos) // No match. return Constant::getNullValue(CI->getType()); - // Constant folding. - if (HasS1 && HasS2) { - size_t Pos = S1.find_first_not_of(S2); - if (Pos == StringRef::npos) Pos = S1.size(); - return ConstantInt::get(CI->getType(), Pos); - } - - return nullptr; + return B.CreateGEP(CI->getArgOperand(0), B.getInt64(I), "strpbrk"); } -}; -struct StrCSpnOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - FT->getParamType(0) != B.getInt8PtrTy() || - FT->getParamType(1) != FT->getParamType(0) || - !FT->getReturnType()->isIntegerTy()) - return nullptr; + // strpbrk(s, "a") -> strchr(s, 'a') + if (DL && HasS2 && S2.size() == 1) + return EmitStrChr(CI->getArgOperand(0), S2[0], B, DL, TLI); - StringRef S1, S2; - bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1); - bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2); + return nullptr; +} - // strcspn("", s) -> 0 - if (HasS1 && S1.empty()) - return Constant::getNullValue(CI->getType()); +Value *LibCallSimplifier::optimizeStrTo(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if ((FT->getNumParams() != 2 && FT->getNumParams() != 3) || + !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy()) + return nullptr; - // Constant folding. - if (HasS1 && HasS2) { - size_t Pos = S1.find_first_of(S2); - if (Pos == StringRef::npos) Pos = S1.size(); - return ConstantInt::get(CI->getType(), Pos); - } + Value *EndPtr = CI->getArgOperand(1); + if (isa<ConstantPointerNull>(EndPtr)) { + // With a null EndPtr, this function won't capture the main argument. + // It would be readonly too, except that it still may write to errno. + CI->addAttribute(1, Attribute::NoCapture); + } - // strcspn(s, "") -> strlen(s) - if (DL && HasS2 && S2.empty()) - return EmitStrLen(CI->getArgOperand(0), B, DL, TLI); + return nullptr; +} +Value *LibCallSimplifier::optimizeStrSpn(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || FT->getParamType(0) != B.getInt8PtrTy() || + FT->getParamType(1) != FT->getParamType(0) || + !FT->getReturnType()->isIntegerTy()) return nullptr; + + StringRef S1, S2; + bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1); + bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2); + + // strspn(s, "") -> 0 + // strspn("", s) -> 0 + if ((HasS1 && S1.empty()) || (HasS2 && S2.empty())) + return Constant::getNullValue(CI->getType()); + + // Constant folding. + if (HasS1 && HasS2) { + size_t Pos = S1.find_first_not_of(S2); + if (Pos == StringRef::npos) + Pos = S1.size(); + return ConstantInt::get(CI->getType(), Pos); } -}; -struct StrStrOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - !FT->getReturnType()->isPointerTy()) - return nullptr; + return nullptr; +} - // fold strstr(x, x) -> x. 
- if (CI->getArgOperand(0) == CI->getArgOperand(1)) - return B.CreateBitCast(CI->getArgOperand(0), CI->getType()); +Value *LibCallSimplifier::optimizeStrCSpn(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || FT->getParamType(0) != B.getInt8PtrTy() || + FT->getParamType(1) != FT->getParamType(0) || + !FT->getReturnType()->isIntegerTy()) + return nullptr; - // fold strstr(a, b) == a -> strncmp(a, b, strlen(b)) == 0 - if (DL && isOnlyUsedInEqualityComparison(CI, CI->getArgOperand(0))) { - Value *StrLen = EmitStrLen(CI->getArgOperand(1), B, DL, TLI); - if (!StrLen) - return nullptr; - Value *StrNCmp = EmitStrNCmp(CI->getArgOperand(0), CI->getArgOperand(1), - StrLen, B, DL, TLI); - if (!StrNCmp) - return nullptr; - for (auto UI = CI->user_begin(), UE = CI->user_end(); UI != UE;) { - ICmpInst *Old = cast<ICmpInst>(*UI++); - Value *Cmp = B.CreateICmp(Old->getPredicate(), StrNCmp, - ConstantInt::getNullValue(StrNCmp->getType()), - "cmp"); - LCS->replaceAllUsesWith(Old, Cmp); - } - return CI; - } + StringRef S1, S2; + bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1); + bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2); - // See if either input string is a constant string. - StringRef SearchStr, ToFindStr; - bool HasStr1 = getConstantStringInfo(CI->getArgOperand(0), SearchStr); - bool HasStr2 = getConstantStringInfo(CI->getArgOperand(1), ToFindStr); + // strcspn("", s) -> 0 + if (HasS1 && S1.empty()) + return Constant::getNullValue(CI->getType()); - // fold strstr(x, "") -> x. - if (HasStr2 && ToFindStr.empty()) - return B.CreateBitCast(CI->getArgOperand(0), CI->getType()); + // Constant folding. + if (HasS1 && HasS2) { + size_t Pos = S1.find_first_of(S2); + if (Pos == StringRef::npos) + Pos = S1.size(); + return ConstantInt::get(CI->getType(), Pos); + } - // If both strings are known, constant fold it. - if (HasStr1 && HasStr2) { - size_t Offset = SearchStr.find(ToFindStr); + // strcspn(s, "") -> strlen(s) + if (DL && HasS2 && S2.empty()) + return EmitStrLen(CI->getArgOperand(0), B, DL, TLI); - if (Offset == StringRef::npos) // strstr("foo", "bar") -> null - return Constant::getNullValue(CI->getType()); + return nullptr; +} - // strstr("abcd", "bc") -> gep((char*)"abcd", 1) - Value *Result = CastToCStr(CI->getArgOperand(0), B); - Result = B.CreateConstInBoundsGEP1_64(Result, Offset, "strstr"); - return B.CreateBitCast(Result, CI->getType()); - } +Value *LibCallSimplifier::optimizeStrStr(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy() || + !FT->getReturnType()->isPointerTy()) + return nullptr; + + // fold strstr(x, x) -> x. + if (CI->getArgOperand(0) == CI->getArgOperand(1)) + return B.CreateBitCast(CI->getArgOperand(0), CI->getType()); - // fold strstr(x, "y") -> strchr(x, 'y'). - if (HasStr2 && ToFindStr.size() == 1) { - Value *StrChr= EmitStrChr(CI->getArgOperand(0), ToFindStr[0], B, DL, TLI); - return StrChr ? 
B.CreateBitCast(StrChr, CI->getType()) : nullptr; + // fold strstr(a, b) == a -> strncmp(a, b, strlen(b)) == 0 + if (DL && isOnlyUsedInEqualityComparison(CI, CI->getArgOperand(0))) { + Value *StrLen = EmitStrLen(CI->getArgOperand(1), B, DL, TLI); + if (!StrLen) + return nullptr; + Value *StrNCmp = EmitStrNCmp(CI->getArgOperand(0), CI->getArgOperand(1), + StrLen, B, DL, TLI); + if (!StrNCmp) + return nullptr; + for (auto UI = CI->user_begin(), UE = CI->user_end(); UI != UE;) { + ICmpInst *Old = cast<ICmpInst>(*UI++); + Value *Cmp = + B.CreateICmp(Old->getPredicate(), StrNCmp, + ConstantInt::getNullValue(StrNCmp->getType()), "cmp"); + replaceAllUsesWith(Old, Cmp); } - return nullptr; + return CI; } -}; -struct MemCmpOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - !FT->getReturnType()->isIntegerTy(32)) - return nullptr; + // See if either input string is a constant string. + StringRef SearchStr, ToFindStr; + bool HasStr1 = getConstantStringInfo(CI->getArgOperand(0), SearchStr); + bool HasStr2 = getConstantStringInfo(CI->getArgOperand(1), ToFindStr); + + // fold strstr(x, "") -> x. + if (HasStr2 && ToFindStr.empty()) + return B.CreateBitCast(CI->getArgOperand(0), CI->getType()); - Value *LHS = CI->getArgOperand(0), *RHS = CI->getArgOperand(1); + // If both strings are known, constant fold it. + if (HasStr1 && HasStr2) { + size_t Offset = SearchStr.find(ToFindStr); - if (LHS == RHS) // memcmp(s,s,x) -> 0 + if (Offset == StringRef::npos) // strstr("foo", "bar") -> null return Constant::getNullValue(CI->getType()); - // Make sure we have a constant length. - ConstantInt *LenC = dyn_cast<ConstantInt>(CI->getArgOperand(2)); - if (!LenC) return nullptr; - uint64_t Len = LenC->getZExtValue(); + // strstr("abcd", "bc") -> gep((char*)"abcd", 1) + Value *Result = CastToCStr(CI->getArgOperand(0), B); + Result = B.CreateConstInBoundsGEP1_64(Result, Offset, "strstr"); + return B.CreateBitCast(Result, CI->getType()); + } - if (Len == 0) // memcmp(s1,s2,0) -> 0 - return Constant::getNullValue(CI->getType()); + // fold strstr(x, "y") -> strchr(x, 'y'). + if (HasStr2 && ToFindStr.size() == 1) { + Value *StrChr = EmitStrChr(CI->getArgOperand(0), ToFindStr[0], B, DL, TLI); + return StrChr ? B.CreateBitCast(StrChr, CI->getType()) : nullptr; + } + return nullptr; +} - // memcmp(S1,S2,1) -> *(unsigned char*)LHS - *(unsigned char*)RHS - if (Len == 1) { - Value *LHSV = B.CreateZExt(B.CreateLoad(CastToCStr(LHS, B), "lhsc"), - CI->getType(), "lhsv"); - Value *RHSV = B.CreateZExt(B.CreateLoad(CastToCStr(RHS, B), "rhsc"), - CI->getType(), "rhsv"); - return B.CreateSub(LHSV, RHSV, "chardiff"); - } +Value *LibCallSimplifier::optimizeMemCmp(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 3 || !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy() || + !FT->getReturnType()->isIntegerTy(32)) + return nullptr; - // Constant folding: memcmp(x, y, l) -> cnst (all arguments are constant) - StringRef LHSStr, RHSStr; - if (getConstantStringInfo(LHS, LHSStr) && - getConstantStringInfo(RHS, RHSStr)) { - // Make sure we're not reading out-of-bounds memory. - if (Len > LHSStr.size() || Len > RHSStr.size()) - return nullptr; - // Fold the memcmp and normalize the result. 
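
A standalone sketch, in the same illustrative spirit, of the optimizeStrStr folds just completed above: the constant fold is a StringRef::find offset turned into a GEP, and a one-character needle degenerates to strchr.

    #include <cassert>
    #include <cstring>
    #include <string_view>

    int main() {
      const char Hay[] = "abcd";
      // strstr(x, "") -> x.
      assert(std::strstr(Hay, "") == Hay);
      // strstr("abcd", "bc") -> gep(hay, 1): base pointer plus find offset.
      size_t Off = std::string_view(Hay).find("bc");
      assert(std::strstr(Hay, "bc") == Hay + Off);  // Off == 1
      // strstr("foo", "bar") -> null when find returns npos.
      assert(std::strstr("foo", "bar") == nullptr);
      // strstr(x, "y") behaves exactly like strchr(x, 'y').
      assert(std::strstr(Hay, "c") == std::strchr(Hay, 'c'));
      return 0;
    }
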
This way we get consistent - // results across multiple platforms. - uint64_t Ret = 0; - int Cmp = memcmp(LHSStr.data(), RHSStr.data(), Len); - if (Cmp < 0) - Ret = -1; - else if (Cmp > 0) - Ret = 1; - return ConstantInt::get(CI->getType(), Ret); - } + Value *LHS = CI->getArgOperand(0), *RHS = CI->getArgOperand(1); + if (LHS == RHS) // memcmp(s,s,x) -> 0 + return Constant::getNullValue(CI->getType()); + + // Make sure we have a constant length. + ConstantInt *LenC = dyn_cast<ConstantInt>(CI->getArgOperand(2)); + if (!LenC) return nullptr; + uint64_t Len = LenC->getZExtValue(); + + if (Len == 0) // memcmp(s1,s2,0) -> 0 + return Constant::getNullValue(CI->getType()); + + // memcmp(S1,S2,1) -> *(unsigned char*)LHS - *(unsigned char*)RHS + if (Len == 1) { + Value *LHSV = B.CreateZExt(B.CreateLoad(CastToCStr(LHS, B), "lhsc"), + CI->getType(), "lhsv"); + Value *RHSV = B.CreateZExt(B.CreateLoad(CastToCStr(RHS, B), "rhsc"), + CI->getType(), "rhsv"); + return B.CreateSub(LHSV, RHSV, "chardiff"); + } + + // Constant folding: memcmp(x, y, l) -> cnst (all arguments are constant) + StringRef LHSStr, RHSStr; + if (getConstantStringInfo(LHS, LHSStr) && + getConstantStringInfo(RHS, RHSStr)) { + // Make sure we're not reading out-of-bounds memory. + if (Len > LHSStr.size() || Len > RHSStr.size()) + return nullptr; + // Fold the memcmp and normalize the result. This way we get consistent + // results across multiple platforms. + uint64_t Ret = 0; + int Cmp = memcmp(LHSStr.data(), RHSStr.data(), Len); + if (Cmp < 0) + Ret = -1; + else if (Cmp > 0) + Ret = 1; + return ConstantInt::get(CI->getType(), Ret); } -}; -struct MemCpyOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // These optimizations require DataLayout. - if (!DL) return nullptr; + return nullptr; +} - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - FT->getParamType(2) != DL->getIntPtrType(*Context)) - return nullptr; +Value *LibCallSimplifier::optimizeMemCpy(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // These optimizations require DataLayout. + if (!DL) + return nullptr; - // memcpy(x, y, n) -> llvm.memcpy(x, y, n, 1) - B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), 1); - return CI->getArgOperand(0); - } -}; + if (!checkStringCopyLibFuncSignature(Callee, LibFunc::memcpy, DL)) + return nullptr; -struct MemMoveOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // These optimizations require DataLayout. - if (!DL) return nullptr; + // memcpy(x, y, n) -> llvm.memcpy(x, y, n, 1) + B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1), + CI->getArgOperand(2), 1); + return CI->getArgOperand(0); +} - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - FT->getParamType(2) != DL->getIntPtrType(*Context)) - return nullptr; +Value *LibCallSimplifier::optimizeMemMove(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // These optimizations require DataLayout. 
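
The memcmp constant fold above deliberately normalizes the host memcmp result, which is only specified up to sign, so the folded constant is identical on every build platform. A minimal sketch (foldedMemCmp is a hypothetical stand-in for the folding logic in optimizeMemCmp):

    #include <cassert>
    #include <cstring>

    // Hypothetical mirror of the fold: normalize the host result to -1/0/1.
    static int foldedMemCmp(const char *L, const char *R, size_t Len) {
      int Cmp = std::memcmp(L, R, Len);
      return Cmp < 0 ? -1 : Cmp > 0 ? 1 : 0;
    }

    int main() {
      const char A[] = "abc", B[] = "abd";
      // memcmp(s, s, n) -> 0 and memcmp(a, b, 0) -> 0.
      assert(foldedMemCmp(A, A, 3) == 0 && foldedMemCmp(A, B, 0) == 0);
      // memcmp(a, b, 1) -> *(unsigned char*)a - *(unsigned char*)b.
      assert((unsigned char)A[0] - (unsigned char)B[0] == 0);
      // Full compare: 'c' < 'd', normalized to exactly -1.
      assert(foldedMemCmp(A, B, 3) == -1);
      return 0;
    }
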
+ if (!DL) + return nullptr; - // memmove(x, y, n) -> llvm.memmove(x, y, n, 1) - B.CreateMemMove(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), 1); - return CI->getArgOperand(0); - } -}; + if (!checkStringCopyLibFuncSignature(Callee, LibFunc::memmove, DL)) + return nullptr; -struct MemSetOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // These optimizations require DataLayout. - if (!DL) return nullptr; + // memmove(x, y, n) -> llvm.memmove(x, y, n, 1) + B.CreateMemMove(CI->getArgOperand(0), CI->getArgOperand(1), + CI->getArgOperand(2), 1); + return CI->getArgOperand(0); +} - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 3 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isIntegerTy() || - FT->getParamType(2) != DL->getIntPtrType(FT->getParamType(0))) - return nullptr; +Value *LibCallSimplifier::optimizeMemSet(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // These optimizations require DataLayout. + if (!DL) + return nullptr; - // memset(p, v, n) -> llvm.memset(p, v, n, 1) - Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false); - B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1); - return CI->getArgOperand(0); - } -}; + if (!checkStringCopyLibFuncSignature(Callee, LibFunc::memset, DL)) + return nullptr; + + // memset(p, v, n) -> llvm.memset(p, v, n, 1) + Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false); + B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1); + return CI->getArgOperand(0); +} //===----------------------------------------------------------------------===// // Math Library Optimizations //===----------------------------------------------------------------------===// +/// Return a variant of Val with float type. +/// Currently this works in two cases: If Val is an FPExtension of a float +/// value to something bigger, simply return the operand. +/// If Val is a ConstantFP but can be converted to a float ConstantFP without +/// loss of precision do so. 
+static Value *valueHasFloatPrecision(Value *Val) { + if (FPExtInst *Cast = dyn_cast<FPExtInst>(Val)) { + Value *Op = Cast->getOperand(0); + if (Op->getType()->isFloatTy()) + return Op; + } + if (ConstantFP *Const = dyn_cast<ConstantFP>(Val)) { + APFloat F = Const->getValueAPF(); + bool losesInfo; + (void)F.convert(APFloat::IEEEsingle, APFloat::rmNearestTiesToEven, + &losesInfo); + if (!losesInfo) + return ConstantFP::get(Const->getContext(), F); + } + return nullptr; +} + //===----------------------------------------------------------------------===// // Double -> Float Shrinking Optimizations for Unary Functions like 'floor' -struct UnaryDoubleFPOpt : public LibCallOptimization { - bool CheckRetType; - UnaryDoubleFPOpt(bool CheckReturnType): CheckRetType(CheckReturnType) {} - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 1 || !FT->getReturnType()->isDoubleTy() || - !FT->getParamType(0)->isDoubleTy()) - return nullptr; +Value *LibCallSimplifier::optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B, + bool CheckRetType) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 1 || !FT->getReturnType()->isDoubleTy() || + !FT->getParamType(0)->isDoubleTy()) + return nullptr; - if (CheckRetType) { - // Check if all the uses for function like 'sin' are converted to float. - for (User *U : CI->users()) { - FPTruncInst *Cast = dyn_cast<FPTruncInst>(U); - if (!Cast || !Cast->getType()->isFloatTy()) - return nullptr; - } + if (CheckRetType) { + // Check if all the uses for function like 'sin' are converted to float. + for (User *U : CI->users()) { + FPTruncInst *Cast = dyn_cast<FPTruncInst>(U); + if (!Cast || !Cast->getType()->isFloatTy()) + return nullptr; } + } - // If this is something like 'floor((double)floatval)', convert to floorf. - FPExtInst *Cast = dyn_cast<FPExtInst>(CI->getArgOperand(0)); - if (!Cast || !Cast->getOperand(0)->getType()->isFloatTy()) - return nullptr; + // If this is something like 'floor((double)floatval)', convert to floorf. + Value *V = valueHasFloatPrecision(CI->getArgOperand(0)); + if (V == nullptr) + return nullptr; - // floor((double)floatval) -> (double)floorf(floatval) - Value *V = Cast->getOperand(0); + // floor((double)floatval) -> (double)floorf(floatval) + if (Callee->isIntrinsic()) { + Module *M = CI->getParent()->getParent()->getParent(); + Intrinsic::ID IID = (Intrinsic::ID) Callee->getIntrinsicID(); + Function *F = Intrinsic::getDeclaration(M, IID, B.getFloatTy()); + V = B.CreateCall(F, V); + } else { + // The call is a library call rather than an intrinsic. V = EmitUnaryFloatFnCall(V, Callee->getName(), B, Callee->getAttributes()); - return B.CreateFPExt(V, B.getDoubleTy()); } -}; + + return B.CreateFPExt(V, B.getDoubleTy()); +} // Double -> Float Shrinking Optimizations for Binary Functions like 'fmin/fmax' -struct BinaryDoubleFPOpt : public LibCallOptimization { - bool CheckRetType; - BinaryDoubleFPOpt(bool CheckReturnType): CheckRetType(CheckReturnType) {} - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - // Just make sure this has 2 arguments of the same FP type, which match the - // result type. 
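
valueHasFloatPrecision accepts a double ConstantFP only when narrowing to float loses no information, i.e. when APFloat::convert reports losesInfo == false. A standalone sketch of that test (fitsInFloat is a hypothetical mirror of the ConstantFP case):

    #include <cassert>

    // A double can be narrowed iff it survives a round-trip through float.
    static bool fitsInFloat(double D) {
      float F = static_cast<float>(D);
      return static_cast<double>(F) == D;  // exact round-trip <=> no info lost
    }

    int main() {
      assert(fitsInFloat(0.5) && fitsInFloat(1.0) && fitsInFloat(-2.25));
      assert(!fitsInFloat(0.1));           // nearest float != nearest double
      assert(!fitsInFloat(1.0 + 1e-10));   // below float's 2^-24 resolution at 1.0
      return 0;
    }
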
- if (FT->getNumParams() != 2 || FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - !FT->getParamType(0)->isFloatingPointTy()) - return nullptr; +Value *LibCallSimplifier::optimizeBinaryDoubleFP(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + // Just make sure this has 2 arguments of the same FP type, which match the + // result type. + if (FT->getNumParams() != 2 || FT->getReturnType() != FT->getParamType(0) || + FT->getParamType(0) != FT->getParamType(1) || + !FT->getParamType(0)->isFloatingPointTy()) + return nullptr; - if (CheckRetType) { - // Check if all the uses for function like 'fmin/fmax' are converted to - // float. - for (User *U : CI->users()) { - FPTruncInst *Cast = dyn_cast<FPTruncInst>(U); - if (!Cast || !Cast->getType()->isFloatTy()) - return nullptr; - } - } + // If this is something like 'fmin((double)floatval1, (double)floatval2)', + // or fmin(1.0, (double)floatval), then we convert it to fminf. + Value *V1 = valueHasFloatPrecision(CI->getArgOperand(0)); + if (V1 == nullptr) + return nullptr; + Value *V2 = valueHasFloatPrecision(CI->getArgOperand(1)); + if (V2 == nullptr) + return nullptr; - // If this is something like 'fmin((double)floatval1, (double)floatval2)', - // we convert it to fminf. - FPExtInst *Cast1 = dyn_cast<FPExtInst>(CI->getArgOperand(0)); - FPExtInst *Cast2 = dyn_cast<FPExtInst>(CI->getArgOperand(1)); - if (!Cast1 || !Cast1->getOperand(0)->getType()->isFloatTy() || - !Cast2 || !Cast2->getOperand(0)->getType()->isFloatTy()) - return nullptr; + // fmin((double)floatval1, (double)floatval2) + // -> (double)fminf(floatval1, floatval2) + // TODO: Handle intrinsics in the same way as in optimizeUnaryDoubleFP(). + Value *V = EmitBinaryFloatFnCall(V1, V2, Callee->getName(), B, + Callee->getAttributes()); + return B.CreateFPExt(V, B.getDoubleTy()); +} - // fmin((double)floatval1, (double)floatval2) - // -> (double)fmin(floatval1, floatval2) - Value *V = nullptr; - Value *V1 = Cast1->getOperand(0); - Value *V2 = Cast2->getOperand(0); - V = EmitBinaryFloatFnCall(V1, V2, Callee->getName(), B, - Callee->getAttributes()); - return B.CreateFPExt(V, B.getDoubleTy()); - } -}; - -struct UnsafeFPLibCallOptimization : public LibCallOptimization { - bool UnsafeFPShrink; - UnsafeFPLibCallOptimization(bool UnsafeFPShrink) { - this->UnsafeFPShrink = UnsafeFPShrink; - } -}; - -struct CosOpt : public UnsafeFPLibCallOptimization { - CosOpt(bool UnsafeFPShrink) : UnsafeFPLibCallOptimization(UnsafeFPShrink) {} - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - Value *Ret = nullptr; - if (UnsafeFPShrink && Callee->getName() == "cos" && - TLI->has(LibFunc::cosf)) { - UnaryDoubleFPOpt UnsafeUnaryDoubleFP(true); - Ret = UnsafeUnaryDoubleFP.callOptimizer(Callee, CI, B); - } +Value *LibCallSimplifier::optimizeCos(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + Value *Ret = nullptr; + if (UnsafeFPShrink && Callee->getName() == "cos" && TLI->has(LibFunc::cosf)) { + Ret = optimizeUnaryDoubleFP(CI, B, true); + } - FunctionType *FT = Callee->getFunctionType(); - // Just make sure this has 1 argument of FP type, which matches the - // result type. 
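
Why the binary shrink in optimizeBinaryDoubleFP above is sound: when both operands carry only float precision, the double-precision fmin selects the same operand fminf would, and widening the float result back to double is exact. A sketch (NaN payload subtleties aside):

    #include <cassert>
    #include <cmath>

    int main() {
      float A = 1.5f, B = -2.25f;
      // fmin((double)a, (double)b) -> (double)fminf(a, b): both inputs are
      // exact in double, so both sides pick the same value.
      assert(std::fmin(static_cast<double>(A), static_cast<double>(B)) ==
             static_cast<double>(std::fminf(A, B)));
      return 0;
    }
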
- if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isFloatingPointTy()) - return Ret; + FunctionType *FT = Callee->getFunctionType(); + // Just make sure this has 1 argument of FP type, which matches the + // result type. + if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || + !FT->getParamType(0)->isFloatingPointTy()) + return Ret; - // cos(-x) -> cos(x) - Value *Op1 = CI->getArgOperand(0); - if (BinaryOperator::isFNeg(Op1)) { - BinaryOperator *BinExpr = cast<BinaryOperator>(Op1); - return B.CreateCall(Callee, BinExpr->getOperand(1), "cos"); - } + // cos(-x) -> cos(x) + Value *Op1 = CI->getArgOperand(0); + if (BinaryOperator::isFNeg(Op1)) { + BinaryOperator *BinExpr = cast<BinaryOperator>(Op1); + return B.CreateCall(Callee, BinExpr->getOperand(1), "cos"); + } + return Ret; +} + +Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + + Value *Ret = nullptr; + if (UnsafeFPShrink && Callee->getName() == "pow" && TLI->has(LibFunc::powf)) { + Ret = optimizeUnaryDoubleFP(CI, B, true); + } + + FunctionType *FT = Callee->getFunctionType(); + // Just make sure this has 2 arguments of the same FP type, which match the + // result type. + if (FT->getNumParams() != 2 || FT->getReturnType() != FT->getParamType(0) || + FT->getParamType(0) != FT->getParamType(1) || + !FT->getParamType(0)->isFloatingPointTy()) return Ret; + + Value *Op1 = CI->getArgOperand(0), *Op2 = CI->getArgOperand(1); + if (ConstantFP *Op1C = dyn_cast<ConstantFP>(Op1)) { + // pow(1.0, x) -> 1.0 + if (Op1C->isExactlyValue(1.0)) + return Op1C; + // pow(2.0, x) -> exp2(x) + if (Op1C->isExactlyValue(2.0) && + hasUnaryFloatFn(TLI, Op1->getType(), LibFunc::exp2, LibFunc::exp2f, + LibFunc::exp2l)) + return EmitUnaryFloatFnCall(Op2, "exp2", B, Callee->getAttributes()); + // pow(10.0, x) -> exp10(x) + if (Op1C->isExactlyValue(10.0) && + hasUnaryFloatFn(TLI, Op1->getType(), LibFunc::exp10, LibFunc::exp10f, + LibFunc::exp10l)) + return EmitUnaryFloatFnCall(Op2, TLI->getName(LibFunc::exp10), B, + Callee->getAttributes()); + } + + ConstantFP *Op2C = dyn_cast<ConstantFP>(Op2); + if (!Op2C) + return Ret; + + if (Op2C->getValueAPF().isZero()) // pow(x, 0.0) -> 1.0 + return ConstantFP::get(CI->getType(), 1.0); + + if (Op2C->isExactlyValue(0.5) && + hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::sqrt, LibFunc::sqrtf, + LibFunc::sqrtl) && + hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::fabs, LibFunc::fabsf, + LibFunc::fabsl)) { + // Expand pow(x, 0.5) to (x == -infinity ? +infinity : fabs(sqrt(x))). + // This is faster than calling pow, and still handles negative zero + // and negative infinity correctly. + // TODO: In fast-math mode, this could be just sqrt(x). + // TODO: In finite-only mode, this could be just fabs(sqrt(x)). 
+ Value *Inf = ConstantFP::getInfinity(CI->getType()); + Value *NegInf = ConstantFP::getInfinity(CI->getType(), true); + Value *Sqrt = EmitUnaryFloatFnCall(Op1, "sqrt", B, Callee->getAttributes()); + Value *FAbs = + EmitUnaryFloatFnCall(Sqrt, "fabs", B, Callee->getAttributes()); + Value *FCmp = B.CreateFCmpOEQ(Op1, NegInf); + Value *Sel = B.CreateSelect(FCmp, Inf, FAbs); + return Sel; + } + + if (Op2C->isExactlyValue(1.0)) // pow(x, 1.0) -> x + return Op1; + if (Op2C->isExactlyValue(2.0)) // pow(x, 2.0) -> x*x + return B.CreateFMul(Op1, Op1, "pow2"); + if (Op2C->isExactlyValue(-1.0)) // pow(x, -1.0) -> 1.0/x + return B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), Op1, "powrecip"); + return nullptr; +} + +Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + Function *Caller = CI->getParent()->getParent(); + + Value *Ret = nullptr; + if (UnsafeFPShrink && Callee->getName() == "exp2" && + TLI->has(LibFunc::exp2f)) { + Ret = optimizeUnaryDoubleFP(CI, B, true); } -}; - -struct PowOpt : public UnsafeFPLibCallOptimization { - PowOpt(bool UnsafeFPShrink) : UnsafeFPLibCallOptimization(UnsafeFPShrink) {} - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - Value *Ret = nullptr; - if (UnsafeFPShrink && Callee->getName() == "pow" && - TLI->has(LibFunc::powf)) { - UnaryDoubleFPOpt UnsafeUnaryDoubleFP(true); - Ret = UnsafeUnaryDoubleFP.callOptimizer(Callee, CI, B); - } - FunctionType *FT = Callee->getFunctionType(); - // Just make sure this has 2 arguments of the same FP type, which match the - // result type. - if (FT->getNumParams() != 2 || FT->getReturnType() != FT->getParamType(0) || - FT->getParamType(0) != FT->getParamType(1) || - !FT->getParamType(0)->isFloatingPointTy()) - return Ret; + FunctionType *FT = Callee->getFunctionType(); + // Just make sure this has 1 argument of FP type, which matches the + // result type. 
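
A standalone sketch of why the pow(x, 0.5) expansion above needs both the fabs and the select on -infinity; expandedPowHalf is a hypothetical mirror of the emitted IR, and the asserted special-case values follow C99 Annex F.

    #include <cassert>
    #include <cmath>

    // pow(x, 0.5) -> (x == -inf ? +inf : fabs(sqrt(x))).
    static double expandedPowHalf(double X) {
      return X == -INFINITY ? INFINITY : std::fabs(std::sqrt(X));
    }

    int main() {
      // pow(-0.0, 0.5) is +0.0, but sqrt(-0.0) is -0.0; fabs fixes the sign.
      assert(!std::signbit(expandedPowHalf(-0.0)));
      // pow(-inf, 0.5) is +inf, but sqrt(-inf) is NaN; the select handles it.
      assert(std::isinf(expandedPowHalf(-INFINITY)));
      assert(expandedPowHalf(4.0) == 2.0);
      return 0;
    }
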
+ if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || + !FT->getParamType(0)->isFloatingPointTy()) + return Ret; - Value *Op1 = CI->getArgOperand(0), *Op2 = CI->getArgOperand(1); - if (ConstantFP *Op1C = dyn_cast<ConstantFP>(Op1)) { - // pow(1.0, x) -> 1.0 - if (Op1C->isExactlyValue(1.0)) - return Op1C; - // pow(2.0, x) -> exp2(x) - if (Op1C->isExactlyValue(2.0) && - hasUnaryFloatFn(TLI, Op1->getType(), LibFunc::exp2, LibFunc::exp2f, - LibFunc::exp2l)) - return EmitUnaryFloatFnCall(Op2, "exp2", B, Callee->getAttributes()); - // pow(10.0, x) -> exp10(x) - if (Op1C->isExactlyValue(10.0) && - hasUnaryFloatFn(TLI, Op1->getType(), LibFunc::exp10, LibFunc::exp10f, - LibFunc::exp10l)) - return EmitUnaryFloatFnCall(Op2, TLI->getName(LibFunc::exp10), B, - Callee->getAttributes()); + Value *Op = CI->getArgOperand(0); + // Turn exp2(sitofp(x)) -> ldexp(1.0, sext(x)) if sizeof(x) <= 32 + // Turn exp2(uitofp(x)) -> ldexp(1.0, zext(x)) if sizeof(x) < 32 + LibFunc::Func LdExp = LibFunc::ldexpl; + if (Op->getType()->isFloatTy()) + LdExp = LibFunc::ldexpf; + else if (Op->getType()->isDoubleTy()) + LdExp = LibFunc::ldexp; + + if (TLI->has(LdExp)) { + Value *LdExpArg = nullptr; + if (SIToFPInst *OpC = dyn_cast<SIToFPInst>(Op)) { + if (OpC->getOperand(0)->getType()->getPrimitiveSizeInBits() <= 32) + LdExpArg = B.CreateSExt(OpC->getOperand(0), B.getInt32Ty()); + } else if (UIToFPInst *OpC = dyn_cast<UIToFPInst>(Op)) { + if (OpC->getOperand(0)->getType()->getPrimitiveSizeInBits() < 32) + LdExpArg = B.CreateZExt(OpC->getOperand(0), B.getInt32Ty()); } - ConstantFP *Op2C = dyn_cast<ConstantFP>(Op2); - if (!Op2C) return Ret; - - if (Op2C->getValueAPF().isZero()) // pow(x, 0.0) -> 1.0 - return ConstantFP::get(CI->getType(), 1.0); - - if (Op2C->isExactlyValue(0.5) && - hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::sqrt, LibFunc::sqrtf, - LibFunc::sqrtl) && - hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::fabs, LibFunc::fabsf, - LibFunc::fabsl)) { - // Expand pow(x, 0.5) to (x == -infinity ? +infinity : fabs(sqrt(x))). - // This is faster than calling pow, and still handles negative zero - // and negative infinity correctly. - // TODO: In fast-math mode, this could be just sqrt(x). - // TODO: In finite-only mode, this could be just fabs(sqrt(x)). 
- Value *Inf = ConstantFP::getInfinity(CI->getType()); - Value *NegInf = ConstantFP::getInfinity(CI->getType(), true); - Value *Sqrt = EmitUnaryFloatFnCall(Op1, "sqrt", B, - Callee->getAttributes()); - Value *FAbs = EmitUnaryFloatFnCall(Sqrt, "fabs", B, - Callee->getAttributes()); - Value *FCmp = B.CreateFCmpOEQ(Op1, NegInf); - Value *Sel = B.CreateSelect(FCmp, Inf, FAbs); - return Sel; - } + if (LdExpArg) { + Constant *One = ConstantFP::get(CI->getContext(), APFloat(1.0f)); + if (!Op->getType()->isFloatTy()) + One = ConstantExpr::getFPExtend(One, Op->getType()); + + Module *M = Caller->getParent(); + Value *Callee = + M->getOrInsertFunction(TLI->getName(LdExp), Op->getType(), + Op->getType(), B.getInt32Ty(), nullptr); + CallInst *CI = B.CreateCall2(Callee, One, LdExpArg); + if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) + CI->setCallingConv(F->getCallingConv()); - if (Op2C->isExactlyValue(1.0)) // pow(x, 1.0) -> x - return Op1; - if (Op2C->isExactlyValue(2.0)) // pow(x, 2.0) -> x*x - return B.CreateFMul(Op1, Op1, "pow2"); - if (Op2C->isExactlyValue(-1.0)) // pow(x, -1.0) -> 1.0/x - return B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), - Op1, "powrecip"); - return nullptr; - } -}; - -struct Exp2Opt : public UnsafeFPLibCallOptimization { - Exp2Opt(bool UnsafeFPShrink) : UnsafeFPLibCallOptimization(UnsafeFPShrink) {} - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - Value *Ret = nullptr; - if (UnsafeFPShrink && Callee->getName() == "exp2" && - TLI->has(LibFunc::exp2f)) { - UnaryDoubleFPOpt UnsafeUnaryDoubleFP(true); - Ret = UnsafeUnaryDoubleFP.callOptimizer(Callee, CI, B); + return CI; } + } + return Ret; +} - FunctionType *FT = Callee->getFunctionType(); - // Just make sure this has 1 argument of FP type, which matches the - // result type. - if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isFloatingPointTy()) - return Ret; +Value *LibCallSimplifier::optimizeFabs(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); - Value *Op = CI->getArgOperand(0); - // Turn exp2(sitofp(x)) -> ldexp(1.0, sext(x)) if sizeof(x) <= 32 - // Turn exp2(uitofp(x)) -> ldexp(1.0, zext(x)) if sizeof(x) < 32 - LibFunc::Func LdExp = LibFunc::ldexpl; - if (Op->getType()->isFloatTy()) - LdExp = LibFunc::ldexpf; - else if (Op->getType()->isDoubleTy()) - LdExp = LibFunc::ldexp; - - if (TLI->has(LdExp)) { - Value *LdExpArg = nullptr; - if (SIToFPInst *OpC = dyn_cast<SIToFPInst>(Op)) { - if (OpC->getOperand(0)->getType()->getPrimitiveSizeInBits() <= 32) - LdExpArg = B.CreateSExt(OpC->getOperand(0), B.getInt32Ty()); - } else if (UIToFPInst *OpC = dyn_cast<UIToFPInst>(Op)) { - if (OpC->getOperand(0)->getType()->getPrimitiveSizeInBits() < 32) - LdExpArg = B.CreateZExt(OpC->getOperand(0), B.getInt32Ty()); - } + Value *Ret = nullptr; + if (Callee->getName() == "fabs" && TLI->has(LibFunc::fabsf)) { + Ret = optimizeUnaryDoubleFP(CI, B, false); + } - if (LdExpArg) { - Constant *One = ConstantFP::get(*Context, APFloat(1.0f)); - if (!Op->getType()->isFloatTy()) - One = ConstantExpr::getFPExtend(One, Op->getType()); + FunctionType *FT = Callee->getFunctionType(); + // Make sure this has 1 argument of FP type which matches the result type. 
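
The exp2(sitofp x) -> ldexp(1.0, x) rewrite completed above is exact because 2^x for an integral x is just a rescaling of 1.0. A sketch, assuming the host exp2 returns exact results for in-range integral inputs (common libm implementations do):

    #include <cassert>
    #include <cmath>

    int main() {
      // For integral exponents, exp2 and ldexp(1.0, x) coincide exactly.
      for (int X = -10; X <= 10; ++X)
        assert(std::exp2(static_cast<double>(X)) == std::ldexp(1.0, X));
      return 0;
    }
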
+ if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || + !FT->getParamType(0)->isFloatingPointTy()) + return Ret; - Module *M = Caller->getParent(); - Value *Callee = - M->getOrInsertFunction(TLI->getName(LdExp), Op->getType(), - Op->getType(), B.getInt32Ty(), NULL); - CallInst *CI = B.CreateCall2(Callee, One, LdExpArg); - if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) - CI->setCallingConv(F->getCallingConv()); + Value *Op = CI->getArgOperand(0); + if (Instruction *I = dyn_cast<Instruction>(Op)) { + // Fold fabs(x * x) -> x * x; any squared FP value must already be positive. + if (I->getOpcode() == Instruction::FMul) + if (I->getOperand(0) == I->getOperand(1)) + return Op; + } + return Ret; +} - return CI; +Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + + Value *Ret = nullptr; + if (TLI->has(LibFunc::sqrtf) && (Callee->getName() == "sqrt" || + Callee->getIntrinsicID() == Intrinsic::sqrt)) + Ret = optimizeUnaryDoubleFP(CI, B, true); + + // FIXME: For finer-grain optimization, we need intrinsics to have the same + // fast-math flag decorations that are applied to FP instructions. For now, + // we have to rely on the function-level unsafe-fp-math attribute to do this + // optimization because there's no other way to express that the sqrt can be + // reassociated. + Function *F = CI->getParent()->getParent(); + if (F->hasFnAttribute("unsafe-fp-math")) { + // Check for unsafe-fp-math = true. + Attribute Attr = F->getFnAttribute("unsafe-fp-math"); + if (Attr.getValueAsString() != "true") + return Ret; + } + Value *Op = CI->getArgOperand(0); + if (Instruction *I = dyn_cast<Instruction>(Op)) { + if (I->getOpcode() == Instruction::FMul && I->hasUnsafeAlgebra()) { + // We're looking for a repeated factor in a multiplication tree, + // so we can do this fold: sqrt(x * x) -> fabs(x); + // or this fold: sqrt(x * x * y) -> fabs(x) * sqrt(y). + Value *Op0 = I->getOperand(0); + Value *Op1 = I->getOperand(1); + Value *RepeatOp = nullptr; + Value *OtherOp = nullptr; + if (Op0 == Op1) { + // Simple match: the operands of the multiply are identical. + RepeatOp = Op0; + } else { + // Look for a more complicated pattern: one of the operands is itself + // a multiply, so search for a common factor in that multiply. + // Note: We don't bother looking any deeper than this first level or for + // variations of this pattern because instcombine's visitFMUL and/or the + // reassociation pass should give us this form. + Value *OtherMul0, *OtherMul1; + if (match(Op0, m_FMul(m_Value(OtherMul0), m_Value(OtherMul1)))) { + // Pattern: sqrt((x * y) * z) + if (OtherMul0 == OtherMul1) { + // Matched: sqrt((x * x) * z) + RepeatOp = OtherMul0; + OtherOp = Op1; + } + } + } + if (RepeatOp) { + // Fast math flags for any created instructions should match the sqrt + // and multiply. + // FIXME: We're not checking the sqrt because it doesn't have + // fast-math-flags (see earlier comment). + IRBuilder<true, ConstantFolder, + IRBuilderDefaultInserter<true> >::FastMathFlagGuard Guard(B); + B.SetFastMathFlags(I->getFastMathFlags()); + // If we found a repeated factor, hoist it out of the square root and + // replace it with the fabs of that factor. 
+ Module *M = Callee->getParent(); + Type *ArgType = Op->getType(); + Value *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, ArgType); + Value *FabsCall = B.CreateCall(Fabs, RepeatOp, "fabs"); + if (OtherOp) { + // If we found a non-repeated factor, we still need to get its square + // root. We then multiply that by the value that was simplified out + // of the square root calculation. + Value *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, ArgType); + Value *SqrtCall = B.CreateCall(Sqrt, OtherOp, "sqrt"); + return B.CreateFMul(FabsCall, SqrtCall); + } + return FabsCall; } } - return Ret; } -}; + return Ret; +} -struct SinCosPiOpt : public LibCallOptimization { - SinCosPiOpt() {} +static bool isTrigLibCall(CallInst *CI); +static void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg, + bool UseFloat, Value *&Sin, Value *&Cos, + Value *&SinCos); - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Make sure the prototype is as expected, otherwise the rest of the - // function is probably invalid and likely to abort. - if (!isTrigLibCall(CI)) - return nullptr; +Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, IRBuilder<> &B) { - Value *Arg = CI->getArgOperand(0); - SmallVector<CallInst *, 1> SinCalls; - SmallVector<CallInst *, 1> CosCalls; - SmallVector<CallInst *, 1> SinCosCalls; + // Make sure the prototype is as expected, otherwise the rest of the + // function is probably invalid and likely to abort. + if (!isTrigLibCall(CI)) + return nullptr; - bool IsFloat = Arg->getType()->isFloatTy(); + Value *Arg = CI->getArgOperand(0); + SmallVector<CallInst *, 1> SinCalls; + SmallVector<CallInst *, 1> CosCalls; + SmallVector<CallInst *, 1> SinCosCalls; - // Look for all compatible sinpi, cospi and sincospi calls with the same - // argument. If there are enough (in some sense) we can make the - // substitution. - for (User *U : Arg->users()) - classifyArgUse(U, CI->getParent(), IsFloat, SinCalls, CosCalls, - SinCosCalls); + bool IsFloat = Arg->getType()->isFloatTy(); - // It's only worthwhile if both sinpi and cospi are actually used. - if (SinCosCalls.empty() && (SinCalls.empty() || CosCalls.empty())) - return nullptr; + // Look for all compatible sinpi, cospi and sincospi calls with the same + // argument. If there are enough (in some sense) we can make the + // substitution. + for (User *U : Arg->users()) + classifyArgUse(U, CI->getParent(), IsFloat, SinCalls, CosCalls, + SinCosCalls); - Value *Sin, *Cos, *SinCos; - insertSinCosCall(B, CI->getCalledFunction(), Arg, IsFloat, Sin, Cos, - SinCos); - - replaceTrigInsts(SinCalls, Sin); - replaceTrigInsts(CosCalls, Cos); - replaceTrigInsts(SinCosCalls, SinCos); - - return nullptr; - } - - bool isTrigLibCall(CallInst *CI) { - Function *Callee = CI->getCalledFunction(); - FunctionType *FT = Callee->getFunctionType(); - - // We can only hope to do anything useful if we can ignore things like errno - // and floating-point exceptions. 
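
A standalone sketch of the two optimizeSqrt rewrites above, plus the fabs(x*x) fold from optimizeFabs; the repeated-factor identity can differ in the last ulp in general, which is why it is gated on the unsafe-algebra fast-math flags.

    #include <cassert>
    #include <cmath>

    int main() {
      double X = -3.0;
      // fabs(x * x) -> x * x: a squared value is already non-negative.
      assert(std::fabs(X * X) == X * X);
      // sqrt(x * x) -> fabs(x): fabs is what keeps negative x correct.
      assert(std::sqrt(X * X) == std::fabs(X));
      // sqrt(x * x * y) -> fabs(x) * sqrt(y). Exact here because every step
      // is exact; in general the two sides may differ in the final bit.
      double Y = 4.0;
      assert(std::sqrt(X * X * Y) == std::fabs(X) * std::sqrt(Y));
      return 0;
    }
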
- bool AttributesSafe = CI->hasFnAttr(Attribute::NoUnwind) && - CI->hasFnAttr(Attribute::ReadNone); - - // Other than that we need float(float) or double(double) - return AttributesSafe && FT->getNumParams() == 1 && - FT->getReturnType() == FT->getParamType(0) && - (FT->getParamType(0)->isFloatTy() || - FT->getParamType(0)->isDoubleTy()); - } - - void classifyArgUse(Value *Val, BasicBlock *BB, bool IsFloat, - SmallVectorImpl<CallInst *> &SinCalls, - SmallVectorImpl<CallInst *> &CosCalls, - SmallVectorImpl<CallInst *> &SinCosCalls) { - CallInst *CI = dyn_cast<CallInst>(Val); - - if (!CI) - return; - - Function *Callee = CI->getCalledFunction(); - StringRef FuncName = Callee->getName(); - LibFunc::Func Func; - if (!TLI->getLibFunc(FuncName, Func) || !TLI->has(Func) || - !isTrigLibCall(CI)) - return; - - if (IsFloat) { - if (Func == LibFunc::sinpif) - SinCalls.push_back(CI); - else if (Func == LibFunc::cospif) - CosCalls.push_back(CI); - else if (Func == LibFunc::sincospif_stret) - SinCosCalls.push_back(CI); - } else { - if (Func == LibFunc::sinpi) - SinCalls.push_back(CI); - else if (Func == LibFunc::cospi) - CosCalls.push_back(CI); - else if (Func == LibFunc::sincospi_stret) - SinCosCalls.push_back(CI); - } - } + // It's only worthwhile if both sinpi and cospi are actually used. + if (SinCosCalls.empty() && (SinCalls.empty() || CosCalls.empty())) + return nullptr; - void replaceTrigInsts(SmallVectorImpl<CallInst*> &Calls, Value *Res) { - for (SmallVectorImpl<CallInst*>::iterator I = Calls.begin(), - E = Calls.end(); - I != E; ++I) { - LCS->replaceAllUsesWith(*I, Res); - } - } + Value *Sin, *Cos, *SinCos; + insertSinCosCall(B, CI->getCalledFunction(), Arg, IsFloat, Sin, Cos, SinCos); - void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg, - bool UseFloat, Value *&Sin, Value *&Cos, - Value *&SinCos) { - Type *ArgTy = Arg->getType(); - Type *ResTy; - StringRef Name; - - Triple T(OrigCallee->getParent()->getTargetTriple()); - if (UseFloat) { - Name = "__sincospif_stret"; - - assert(T.getArch() != Triple::x86 && "x86 messy and unsupported for now"); - // x86_64 can't use {float, float} since that would be returned in both - // xmm0 and xmm1, which isn't what a real struct would do. - ResTy = T.getArch() == Triple::x86_64 - ? static_cast<Type *>(VectorType::get(ArgTy, 2)) - : static_cast<Type *>(StructType::get(ArgTy, ArgTy, NULL)); - } else { - Name = "__sincospi_stret"; - ResTy = StructType::get(ArgTy, ArgTy, NULL); - } + replaceTrigInsts(SinCalls, Sin); + replaceTrigInsts(CosCalls, Cos); + replaceTrigInsts(SinCosCalls, SinCos); - Module *M = OrigCallee->getParent(); - Value *Callee = M->getOrInsertFunction(Name, OrigCallee->getAttributes(), - ResTy, ArgTy, NULL); - - if (Instruction *ArgInst = dyn_cast<Instruction>(Arg)) { - // If the argument is an instruction, it must dominate all uses so put our - // sincos call there. - BasicBlock::iterator Loc = ArgInst; - B.SetInsertPoint(ArgInst->getParent(), ++Loc); - } else { - // Otherwise (e.g. for a constant) the beginning of the function is as - // good a place as any. - BasicBlock &EntryBB = B.GetInsertBlock()->getParent()->getEntryBlock(); - B.SetInsertPoint(&EntryBB, EntryBB.begin()); - } + return nullptr; +} - SinCos = B.CreateCall(Callee, Arg, "sincospi"); +static bool isTrigLibCall(CallInst *CI) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + + // We can only hope to do anything useful if we can ignore things like errno + // and floating-point exceptions. 
+ bool AttributesSafe = + CI->hasFnAttr(Attribute::NoUnwind) && CI->hasFnAttr(Attribute::ReadNone); + + // Other than that we need float(float) or double(double) + return AttributesSafe && FT->getNumParams() == 1 && + FT->getReturnType() == FT->getParamType(0) && + (FT->getParamType(0)->isFloatTy() || + FT->getParamType(0)->isDoubleTy()); +} - if (SinCos->getType()->isStructTy()) { - Sin = B.CreateExtractValue(SinCos, 0, "sinpi"); - Cos = B.CreateExtractValue(SinCos, 1, "cospi"); - } else { - Sin = B.CreateExtractElement(SinCos, ConstantInt::get(B.getInt32Ty(), 0), - "sinpi"); - Cos = B.CreateExtractElement(SinCos, ConstantInt::get(B.getInt32Ty(), 1), - "cospi"); - } +void +LibCallSimplifier::classifyArgUse(Value *Val, BasicBlock *BB, bool IsFloat, + SmallVectorImpl<CallInst *> &SinCalls, + SmallVectorImpl<CallInst *> &CosCalls, + SmallVectorImpl<CallInst *> &SinCosCalls) { + CallInst *CI = dyn_cast<CallInst>(Val); + + if (!CI) + return; + + Function *Callee = CI->getCalledFunction(); + StringRef FuncName = Callee->getName(); + LibFunc::Func Func; + if (!TLI->getLibFunc(FuncName, Func) || !TLI->has(Func) || !isTrigLibCall(CI)) + return; + + if (IsFloat) { + if (Func == LibFunc::sinpif) + SinCalls.push_back(CI); + else if (Func == LibFunc::cospif) + CosCalls.push_back(CI); + else if (Func == LibFunc::sincospif_stret) + SinCosCalls.push_back(CI); + } else { + if (Func == LibFunc::sinpi) + SinCalls.push_back(CI); + else if (Func == LibFunc::cospi) + CosCalls.push_back(CI); + else if (Func == LibFunc::sincospi_stret) + SinCosCalls.push_back(CI); + } +} + +void LibCallSimplifier::replaceTrigInsts(SmallVectorImpl<CallInst *> &Calls, + Value *Res) { + for (SmallVectorImpl<CallInst *>::iterator I = Calls.begin(), E = Calls.end(); + I != E; ++I) { + replaceAllUsesWith(*I, Res); } +} -}; +void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg, + bool UseFloat, Value *&Sin, Value *&Cos, Value *&SinCos) { + Type *ArgTy = Arg->getType(); + Type *ResTy; + StringRef Name; + + Triple T(OrigCallee->getParent()->getTargetTriple()); + if (UseFloat) { + Name = "__sincospif_stret"; + + assert(T.getArch() != Triple::x86 && "x86 messy and unsupported for now"); + // x86_64 can't use {float, float} since that would be returned in both + // xmm0 and xmm1, which isn't what a real struct would do. + ResTy = T.getArch() == Triple::x86_64 + ? static_cast<Type *>(VectorType::get(ArgTy, 2)) + : static_cast<Type *>(StructType::get(ArgTy, ArgTy, nullptr)); + } else { + Name = "__sincospi_stret"; + ResTy = StructType::get(ArgTy, ArgTy, nullptr); + } + + Module *M = OrigCallee->getParent(); + Value *Callee = M->getOrInsertFunction(Name, OrigCallee->getAttributes(), + ResTy, ArgTy, nullptr); + + if (Instruction *ArgInst = dyn_cast<Instruction>(Arg)) { + // If the argument is an instruction, it must dominate all uses so put our + // sincos call there. + BasicBlock::iterator Loc = ArgInst; + B.SetInsertPoint(ArgInst->getParent(), ++Loc); + } else { + // Otherwise (e.g. for a constant) the beginning of the function is as + // good a place as any. 
+ BasicBlock &EntryBB = B.GetInsertBlock()->getParent()->getEntryBlock(); + B.SetInsertPoint(&EntryBB, EntryBB.begin()); + } + + SinCos = B.CreateCall(Callee, Arg, "sincospi"); + + if (SinCos->getType()->isStructTy()) { + Sin = B.CreateExtractValue(SinCos, 0, "sinpi"); + Cos = B.CreateExtractValue(SinCos, 1, "cospi"); + } else { + Sin = B.CreateExtractElement(SinCos, ConstantInt::get(B.getInt32Ty(), 0), + "sinpi"); + Cos = B.CreateExtractElement(SinCos, ConstantInt::get(B.getInt32Ty(), 1), + "cospi"); + } +} //===----------------------------------------------------------------------===// // Integer Library Call Optimizations //===----------------------------------------------------------------------===// -struct FFSOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - // Just make sure this has 2 arguments of the same FP type, which match the - // result type. - if (FT->getNumParams() != 1 || - !FT->getReturnType()->isIntegerTy(32) || - !FT->getParamType(0)->isIntegerTy()) - return nullptr; +Value *LibCallSimplifier::optimizeFFS(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + // Just make sure this has 2 arguments of the same FP type, which match the + // result type. + if (FT->getNumParams() != 1 || !FT->getReturnType()->isIntegerTy(32) || + !FT->getParamType(0)->isIntegerTy()) + return nullptr; - Value *Op = CI->getArgOperand(0); + Value *Op = CI->getArgOperand(0); - // Constant fold. - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op)) { - if (CI->isZero()) // ffs(0) -> 0. - return B.getInt32(0); - // ffs(c) -> cttz(c)+1 - return B.getInt32(CI->getValue().countTrailingZeros() + 1); - } + // Constant fold. + if (ConstantInt *CI = dyn_cast<ConstantInt>(Op)) { + if (CI->isZero()) // ffs(0) -> 0. + return B.getInt32(0); + // ffs(c) -> cttz(c)+1 + return B.getInt32(CI->getValue().countTrailingZeros() + 1); + } - // ffs(x) -> x != 0 ? (i32)llvm.cttz(x)+1 : 0 - Type *ArgType = Op->getType(); - Value *F = Intrinsic::getDeclaration(Callee->getParent(), - Intrinsic::cttz, ArgType); - Value *V = B.CreateCall2(F, Op, B.getFalse(), "cttz"); - V = B.CreateAdd(V, ConstantInt::get(V->getType(), 1)); - V = B.CreateIntCast(V, B.getInt32Ty(), false); - - Value *Cond = B.CreateICmpNE(Op, Constant::getNullValue(ArgType)); - return B.CreateSelect(Cond, V, B.getInt32(0)); - } -}; - -struct AbsOpt : public LibCallOptimization { - bool ignoreCallingConv() override { return true; } - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - // We require integer(integer) where the types agree. - if (FT->getNumParams() != 1 || !FT->getReturnType()->isIntegerTy() || - FT->getParamType(0) != FT->getReturnType()) - return nullptr; + // ffs(x) -> x != 0 ? (i32)llvm.cttz(x)+1 : 0 + Type *ArgType = Op->getType(); + Value *F = + Intrinsic::getDeclaration(Callee->getParent(), Intrinsic::cttz, ArgType); + Value *V = B.CreateCall2(F, Op, B.getFalse(), "cttz"); + V = B.CreateAdd(V, ConstantInt::get(V->getType(), 1)); + V = B.CreateIntCast(V, B.getInt32Ty(), false); - // abs(x) -> x >s -1 ? 
x : -x - Value *Op = CI->getArgOperand(0); - Value *Pos = B.CreateICmpSGT(Op, Constant::getAllOnesValue(Op->getType()), - "ispos"); - Value *Neg = B.CreateNeg(Op, "neg"); - return B.CreateSelect(Pos, Op, Neg); - } -}; - -struct IsDigitOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - // We require integer(i32) - if (FT->getNumParams() != 1 || !FT->getReturnType()->isIntegerTy() || - !FT->getParamType(0)->isIntegerTy(32)) - return nullptr; + Value *Cond = B.CreateICmpNE(Op, Constant::getNullValue(ArgType)); + return B.CreateSelect(Cond, V, B.getInt32(0)); +} - // isdigit(c) -> (c-'0') <u 10 - Value *Op = CI->getArgOperand(0); - Op = B.CreateSub(Op, B.getInt32('0'), "isdigittmp"); - Op = B.CreateICmpULT(Op, B.getInt32(10), "isdigit"); - return B.CreateZExt(Op, CI->getType()); - } -}; - -struct IsAsciiOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - // We require integer(i32) - if (FT->getNumParams() != 1 || !FT->getReturnType()->isIntegerTy() || - !FT->getParamType(0)->isIntegerTy(32)) - return nullptr; +Value *LibCallSimplifier::optimizeAbs(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + // We require integer(integer) where the types agree. + if (FT->getNumParams() != 1 || !FT->getReturnType()->isIntegerTy() || + FT->getParamType(0) != FT->getReturnType()) + return nullptr; - // isascii(c) -> c <u 128 - Value *Op = CI->getArgOperand(0); - Op = B.CreateICmpULT(Op, B.getInt32(128), "isascii"); - return B.CreateZExt(Op, CI->getType()); - } -}; + // abs(x) -> x >s -1 ? 
x : -x + Value *Op = CI->getArgOperand(0); + Value *Pos = + B.CreateICmpSGT(Op, Constant::getAllOnesValue(Op->getType()), "ispos"); + Value *Neg = B.CreateNeg(Op, "neg"); + return B.CreateSelect(Pos, Op, Neg); +} -struct ToAsciiOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - FunctionType *FT = Callee->getFunctionType(); - // We require i32(i32) - if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || - !FT->getParamType(0)->isIntegerTy(32)) - return nullptr; +Value *LibCallSimplifier::optimizeIsDigit(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + // We require integer(i32) + if (FT->getNumParams() != 1 || !FT->getReturnType()->isIntegerTy() || + !FT->getParamType(0)->isIntegerTy(32)) + return nullptr; - // toascii(c) -> c & 0x7f - return B.CreateAnd(CI->getArgOperand(0), - ConstantInt::get(CI->getType(),0x7F)); - } -}; + // isdigit(c) -> (c-'0') <u 10 + Value *Op = CI->getArgOperand(0); + Op = B.CreateSub(Op, B.getInt32('0'), "isdigittmp"); + Op = B.CreateICmpULT(Op, B.getInt32(10), "isdigit"); + return B.CreateZExt(Op, CI->getType()); +} + +Value *LibCallSimplifier::optimizeIsAscii(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + // We require integer(i32) + if (FT->getNumParams() != 1 || !FT->getReturnType()->isIntegerTy() || + !FT->getParamType(0)->isIntegerTy(32)) + return nullptr; + + // isascii(c) -> c <u 128 + Value *Op = CI->getArgOperand(0); + Op = B.CreateICmpULT(Op, B.getInt32(128), "isascii"); + return B.CreateZExt(Op, CI->getType()); +} + +Value *LibCallSimplifier::optimizeToAscii(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); + // We require i32(i32) + if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || + !FT->getParamType(0)->isIntegerTy(32)) + return nullptr; + + // toascii(c) -> c & 0x7f + return B.CreateAnd(CI->getArgOperand(0), + ConstantInt::get(CI->getType(), 0x7F)); +} //===----------------------------------------------------------------------===// // Formatting and IO Library Call Optimizations //===----------------------------------------------------------------------===// -struct ErrorReportingOpt : public LibCallOptimization { - ErrorReportingOpt(int S = -1) : StreamArg(S) {} +static bool isReportingError(Function *Callee, CallInst *CI, int StreamArg); - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &) override { - // Error reporting calls should be cold, mark them as such. - // This applies even to non-builtin calls: it is only a hint and applies to - // functions that the frontend might not understand as builtins. +Value *LibCallSimplifier::optimizeErrorReporting(CallInst *CI, IRBuilder<> &B, + int StreamArg) { + // Error reporting calls should be cold, mark them as such. + // This applies even to non-builtin calls: it is only a hint and applies to + // functions that the frontend might not understand as builtins. - // This heuristic was suggested in: - // Improving Static Branch Prediction in a Compiler - // Brian L. Deitrich, Ben-Chung Cheng, Wen-mei W. Hwu - // Proceedings of PACT'98, Oct. 
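
The integer folds above (ffs, abs, isdigit, isascii, toascii) are all branch-light bit tricks; a standalone sketch checking them against C library semantics (__builtin_ffs and __builtin_ctz are the GCC/Clang builtins corresponding to the ffs call and the cttz intrinsic):

    #include <cassert>

    int main() {
      // ffs(c) -> cttz(c) + 1, with ffs(0) -> 0 handled by a select.
      assert(__builtin_ffs(0) == 0);
      assert(__builtin_ffs(8) == __builtin_ctz(8) + 1);  // 4 == 3 + 1
      // abs(x) -> x >s -1 ? x : -x.
      int X = -7;
      assert((X > -1 ? X : -X) == 7);
      // isdigit(c) -> (c - '0') <u 10: one unsigned compare covers both ends,
      // because c < '0' wraps around to a huge unsigned value.
      for (int C = 0; C < 256; ++C)
        assert(((unsigned)(C - '0') < 10u) == (C >= '0' && C <= '9'));
      // isascii(c) -> c <u 128; toascii(c) -> c & 0x7f.
      assert((unsigned)200 >= 128u && (300 & 0x7f) == 44);
      return 0;
    }
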
1998, IEEE - - if (!CI->hasFnAttr(Attribute::Cold) && isReportingError(Callee, CI)) { - CI->addAttribute(AttributeSet::FunctionIndex, Attribute::Cold); - } + // This heuristic was suggested in: + // Improving Static Branch Prediction in a Compiler + // Brian L. Deitrich, Ben-Chung Cheng, Wen-mei W. Hwu + // Proceedings of PACT'98, Oct. 1998, IEEE + Function *Callee = CI->getCalledFunction(); - return nullptr; + if (!CI->hasFnAttr(Attribute::Cold) && + isReportingError(Callee, CI, StreamArg)) { + CI->addAttribute(AttributeSet::FunctionIndex, Attribute::Cold); } -protected: - bool isReportingError(Function *Callee, CallInst *CI) { - if (!ColdErrorCalls) - return false; - - if (!Callee || !Callee->isDeclaration()) - return false; + return nullptr; +} - if (StreamArg < 0) - return true; +static bool isReportingError(Function *Callee, CallInst *CI, int StreamArg) { + if (!ColdErrorCalls) + return false; - // These functions might be considered cold, but only if their stream - // argument is stderr. + if (!Callee || !Callee->isDeclaration()) + return false; - if (StreamArg >= (int) CI->getNumArgOperands()) - return false; - LoadInst *LI = dyn_cast<LoadInst>(CI->getArgOperand(StreamArg)); - if (!LI) - return false; - GlobalVariable *GV = dyn_cast<GlobalVariable>(LI->getPointerOperand()); - if (!GV || !GV->isDeclaration()) - return false; - return GV->getName() == "stderr"; - } + if (StreamArg < 0) + return true; - int StreamArg; -}; + // These functions might be considered cold, but only if their stream + // argument is stderr. -struct PrintFOpt : public LibCallOptimization { - Value *optimizeFixedFormatString(Function *Callee, CallInst *CI, - IRBuilder<> &B) { - // Check for a fixed format string. - StringRef FormatStr; - if (!getConstantStringInfo(CI->getArgOperand(0), FormatStr)) - return nullptr; + if (StreamArg >= (int)CI->getNumArgOperands()) + return false; + LoadInst *LI = dyn_cast<LoadInst>(CI->getArgOperand(StreamArg)); + if (!LI) + return false; + GlobalVariable *GV = dyn_cast<GlobalVariable>(LI->getPointerOperand()); + if (!GV || !GV->isDeclaration()) + return false; + return GV->getName() == "stderr"; +} - // Empty format string -> noop. - if (FormatStr.empty()) // Tolerate printf's declared void. - return CI->use_empty() ? (Value*)CI : - ConstantInt::get(CI->getType(), 0); +Value *LibCallSimplifier::optimizePrintFString(CallInst *CI, IRBuilder<> &B) { + // Check for a fixed format string. + StringRef FormatStr; + if (!getConstantStringInfo(CI->getArgOperand(0), FormatStr)) + return nullptr; - // Do not do any of the following transformations if the printf return value - // is used, in general the printf return value is not compatible with either - // putchar() or puts(). - if (!CI->use_empty()) - return nullptr; + // Empty format string -> noop. + if (FormatStr.empty()) // Tolerate printf's declared void. + return CI->use_empty() ? (Value *)CI : ConstantInt::get(CI->getType(), 0); - // printf("x") -> putchar('x'), even for '%'. - if (FormatStr.size() == 1) { - Value *Res = EmitPutChar(B.getInt32(FormatStr[0]), B, DL, TLI); - if (CI->use_empty() || !Res) return Res; - return B.CreateIntCast(Res, CI->getType(), true); - } + // Do not do any of the following transformations if the printf return value + // is used, in general the printf return value is not compatible with either + // putchar() or puts(). 
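
optimizeErrorReporting above only attaches a hint; the source-level analogue is GCC/Clang's cold attribute, sketched below with a hypothetical reportError helper so the effect on branch layout is easy to see:

    #include <cstdio>

    // Hypothetical error helper; the attribute is the GCC/Clang source-level
    // analogue of the Attribute::Cold the pass adds in IR. It biases branch
    // probabilities and block placement without changing semantics.
    __attribute__((cold)) static void reportError(const char *Msg) {
      std::fprintf(stderr, "error: %s\n", Msg);
    }

    int process(int V) {
      if (V < 0) {  // now laid out as the unlikely path
        reportError("negative input");
        return -1;
      }
      return V * 2;
    }

    int main() { return process(3) == 6 ? 0 : 1; }
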
+ if (!CI->use_empty()) + return nullptr; - // printf("foo\n") --> puts("foo") - if (FormatStr[FormatStr.size()-1] == '\n' && - FormatStr.find('%') == StringRef::npos) { // No format characters. - // Create a string literal with no \n on it. We expect the constant merge - // pass to be run after this pass, to merge duplicate strings. - FormatStr = FormatStr.drop_back(); - Value *GV = B.CreateGlobalString(FormatStr, "str"); - Value *NewCI = EmitPutS(GV, B, DL, TLI); - return (CI->use_empty() || !NewCI) ? - NewCI : - ConstantInt::get(CI->getType(), FormatStr.size()+1); - } + // printf("x") -> putchar('x'), even for '%'. + if (FormatStr.size() == 1) { + Value *Res = EmitPutChar(B.getInt32(FormatStr[0]), B, DL, TLI); + if (CI->use_empty() || !Res) + return Res; + return B.CreateIntCast(Res, CI->getType(), true); + } - // Optimize specific format strings. - // printf("%c", chr) --> putchar(chr) - if (FormatStr == "%c" && CI->getNumArgOperands() > 1 && - CI->getArgOperand(1)->getType()->isIntegerTy()) { - Value *Res = EmitPutChar(CI->getArgOperand(1), B, DL, TLI); + // printf("foo\n") --> puts("foo") + if (FormatStr[FormatStr.size() - 1] == '\n' && + FormatStr.find('%') == StringRef::npos) { // No format characters. + // Create a string literal with no \n on it. We expect the constant merge + // pass to be run after this pass, to merge duplicate strings. + FormatStr = FormatStr.drop_back(); + Value *GV = B.CreateGlobalString(FormatStr, "str"); + Value *NewCI = EmitPutS(GV, B, DL, TLI); + return (CI->use_empty() || !NewCI) + ? NewCI + : ConstantInt::get(CI->getType(), FormatStr.size() + 1); + } - if (CI->use_empty() || !Res) return Res; - return B.CreateIntCast(Res, CI->getType(), true); - } + // Optimize specific format strings. + // printf("%c", chr) --> putchar(chr) + if (FormatStr == "%c" && CI->getNumArgOperands() > 1 && + CI->getArgOperand(1)->getType()->isIntegerTy()) { + Value *Res = EmitPutChar(CI->getArgOperand(1), B, DL, TLI); - // printf("%s\n", str) --> puts(str) - if (FormatStr == "%s\n" && CI->getNumArgOperands() > 1 && - CI->getArgOperand(1)->getType()->isPointerTy()) { - return EmitPutS(CI->getArgOperand(1), B, DL, TLI); - } - return nullptr; + if (CI->use_empty() || !Res) + return Res; + return B.CreateIntCast(Res, CI->getType(), true); } - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Require one fixed pointer argument and an integer/void result. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() < 1 || !FT->getParamType(0)->isPointerTy() || - !(FT->getReturnType()->isIntegerTy() || - FT->getReturnType()->isVoidTy())) - return nullptr; + // printf("%s\n", str) --> puts(str) + if (FormatStr == "%s\n" && CI->getNumArgOperands() > 1 && + CI->getArgOperand(1)->getType()->isPointerTy()) { + return EmitPutS(CI->getArgOperand(1), B, DL, TLI); + } + return nullptr; +} - if (Value *V = optimizeFixedFormatString(Callee, CI, B)) { - return V; - } +Value *LibCallSimplifier::optimizePrintF(CallInst *CI, IRBuilder<> &B) { - // printf(format, ...) -> iprintf(format, ...) if no floating point - // arguments. 
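
A standalone sketch of the printf fold return values above: puts supplies the newline itself, so the folded result FormatStr.size() + 1 (computed after dropping the '\n') matches the byte count printf reports.

    #include <cassert>
    #include <cstdio>
    #include <cstring>

    int main() {
      // printf("foo\n") --> puts("foo"): same output, and the folded return
      // value strlen("foo") + 1 matches printf's byte count.
      int N = std::printf("foo\n");
      assert(N == (int)std::strlen("foo") + 1);
      // printf("x") --> putchar('x'); printf("%c", c) and printf("%s\n", s)
      // follow the same pattern when the call's result is unused.
      std::putchar('x');
      return 0;
    }
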
- if (TLI->has(LibFunc::iprintf) && !callHasFloatingPointArgument(CI)) { - Module *M = B.GetInsertBlock()->getParent()->getParent(); - Constant *IPrintFFn = - M->getOrInsertFunction("iprintf", FT, Callee->getAttributes()); - CallInst *New = cast<CallInst>(CI->clone()); - New->setCalledFunction(IPrintFFn); - B.Insert(New); - return New; - } + Function *Callee = CI->getCalledFunction(); + // Require one fixed pointer argument and an integer/void result. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() < 1 || !FT->getParamType(0)->isPointerTy() || + !(FT->getReturnType()->isIntegerTy() || FT->getReturnType()->isVoidTy())) return nullptr; + + if (Value *V = optimizePrintFString(CI, B)) { + return V; } -}; -struct SPrintFOpt : public LibCallOptimization { - Value *OptimizeFixedFormatString(Function *Callee, CallInst *CI, - IRBuilder<> &B) { - // Check for a fixed format string. - StringRef FormatStr; - if (!getConstantStringInfo(CI->getArgOperand(1), FormatStr)) - return nullptr; + // printf(format, ...) -> iprintf(format, ...) if no floating point + // arguments. + if (TLI->has(LibFunc::iprintf) && !callHasFloatingPointArgument(CI)) { + Module *M = B.GetInsertBlock()->getParent()->getParent(); + Constant *IPrintFFn = + M->getOrInsertFunction("iprintf", FT, Callee->getAttributes()); + CallInst *New = cast<CallInst>(CI->clone()); + New->setCalledFunction(IPrintFFn); + B.Insert(New); + return New; + } + return nullptr; +} - // If we just have a format string (nothing else crazy) transform it. - if (CI->getNumArgOperands() == 2) { - // Make sure there's no % in the constant array. We could try to handle - // %% -> % in the future if we cared. - for (unsigned i = 0, e = FormatStr.size(); i != e; ++i) - if (FormatStr[i] == '%') - return nullptr; // we found a format specifier, bail out. - - // These optimizations require DataLayout. - if (!DL) return nullptr; - - // sprintf(str, fmt) -> llvm.memcpy(str, fmt, strlen(fmt)+1, 1) - B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1), - ConstantInt::get(DL->getIntPtrType(*Context), // Copy the - FormatStr.size() + 1), 1); // nul byte. - return ConstantInt::get(CI->getType(), FormatStr.size()); - } +Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI, IRBuilder<> &B) { + // Check for a fixed format string. + StringRef FormatStr; + if (!getConstantStringInfo(CI->getArgOperand(1), FormatStr)) + return nullptr; - // The remaining optimizations require the format string to be "%s" or "%c" - // and have an extra operand. - if (FormatStr.size() != 2 || FormatStr[0] != '%' || - CI->getNumArgOperands() < 3) - return nullptr; + // If we just have a format string (nothing else crazy) transform it. + if (CI->getNumArgOperands() == 2) { + // Make sure there's no % in the constant array. We could try to handle + // %% -> % in the future if we cared. + for (unsigned i = 0, e = FormatStr.size(); i != e; ++i) + if (FormatStr[i] == '%') + return nullptr; // we found a format specifier, bail out. - // Decode the second character of the format string. 
- if (FormatStr[1] == 'c') { - // sprintf(dst, "%c", chr) --> *(i8*)dst = chr; *((i8*)dst+1) = 0 - if (!CI->getArgOperand(2)->getType()->isIntegerTy()) return nullptr; - Value *V = B.CreateTrunc(CI->getArgOperand(2), B.getInt8Ty(), "char"); - Value *Ptr = CastToCStr(CI->getArgOperand(0), B); - B.CreateStore(V, Ptr); - Ptr = B.CreateGEP(Ptr, B.getInt32(1), "nul"); - B.CreateStore(B.getInt8(0), Ptr); - - return ConstantInt::get(CI->getType(), 1); - } + // These optimizations require DataLayout. + if (!DL) + return nullptr; - if (FormatStr[1] == 's') { - // These optimizations require DataLayout. - if (!DL) return nullptr; + // sprintf(str, fmt) -> llvm.memcpy(str, fmt, strlen(fmt)+1, 1) + B.CreateMemCpy( + CI->getArgOperand(0), CI->getArgOperand(1), + ConstantInt::get(DL->getIntPtrType(CI->getContext()), + FormatStr.size() + 1), + 1); // Copy the null byte. + return ConstantInt::get(CI->getType(), FormatStr.size()); + } - // sprintf(dest, "%s", str) -> llvm.memcpy(dest, str, strlen(str)+1, 1) - if (!CI->getArgOperand(2)->getType()->isPointerTy()) return nullptr; + // The remaining optimizations require the format string to be "%s" or "%c" + // and have an extra operand. + if (FormatStr.size() != 2 || FormatStr[0] != '%' || + CI->getNumArgOperands() < 3) + return nullptr; - Value *Len = EmitStrLen(CI->getArgOperand(2), B, DL, TLI); - if (!Len) - return nullptr; - Value *IncLen = B.CreateAdd(Len, - ConstantInt::get(Len->getType(), 1), - "leninc"); - B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(2), IncLen, 1); + // Decode the second character of the format string. + if (FormatStr[1] == 'c') { + // sprintf(dst, "%c", chr) --> *(i8*)dst = chr; *((i8*)dst+1) = 0 + if (!CI->getArgOperand(2)->getType()->isIntegerTy()) + return nullptr; + Value *V = B.CreateTrunc(CI->getArgOperand(2), B.getInt8Ty(), "char"); + Value *Ptr = CastToCStr(CI->getArgOperand(0), B); + B.CreateStore(V, Ptr); + Ptr = B.CreateGEP(Ptr, B.getInt32(1), "nul"); + B.CreateStore(B.getInt8(0), Ptr); - // The sprintf result is the unincremented number of bytes in the string. - return B.CreateIntCast(Len, CI->getType(), false); - } - return nullptr; + return ConstantInt::get(CI->getType(), 1); } - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Require two fixed pointer arguments and an integer result. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - !FT->getReturnType()->isIntegerTy()) + if (FormatStr[1] == 's') { + // These optimizations require DataLayout. + if (!DL) return nullptr; - if (Value *V = OptimizeFixedFormatString(Callee, CI, B)) { - return V; - } + // sprintf(dest, "%s", str) -> llvm.memcpy(dest, str, strlen(str)+1, 1) + if (!CI->getArgOperand(2)->getType()->isPointerTy()) + return nullptr; - // sprintf(str, format, ...) -> siprintf(str, format, ...) if no floating - // point arguments. 
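[The corresponding sprintf(3) cases, sketched at the C level -- illustrative only; the memcpy-based rewrites additionally require DataLayout so the copy length can be materialized as an intptr-sized constant:

    sprintf(dst, "abc");     // no '%' -> memcpy(dst, "abc", 4), result 3
    sprintf(dst, "%c", chr); // -> dst[0] = chr; dst[1] = '\0'; result 1
    sprintf(dst, "%s", str); // -> memcpy(dst, str, strlen(str)+1), result strlen(str)
]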
- if (TLI->has(LibFunc::siprintf) && !callHasFloatingPointArgument(CI)) { - Module *M = B.GetInsertBlock()->getParent()->getParent(); - Constant *SIPrintFFn = - M->getOrInsertFunction("siprintf", FT, Callee->getAttributes()); - CallInst *New = cast<CallInst>(CI->clone()); - New->setCalledFunction(SIPrintFFn); - B.Insert(New); - return New; - } - return nullptr; + Value *Len = EmitStrLen(CI->getArgOperand(2), B, DL, TLI); + if (!Len) + return nullptr; + Value *IncLen = + B.CreateAdd(Len, ConstantInt::get(Len->getType(), 1), "leninc"); + B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(2), IncLen, 1); + + // The sprintf result is the unincremented number of bytes in the string. + return B.CreateIntCast(Len, CI->getType(), false); } -}; + return nullptr; +} -struct FPrintFOpt : public LibCallOptimization { - Value *optimizeFixedFormatString(Function *Callee, CallInst *CI, - IRBuilder<> &B) { - ErrorReportingOpt ER(/* StreamArg = */ 0); - (void) ER.callOptimizer(Callee, CI, B); +Value *LibCallSimplifier::optimizeSPrintF(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Require two fixed pointer arguments and an integer result. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy() || + !FT->getReturnType()->isIntegerTy()) + return nullptr; - // All the optimizations depend on the format string. - StringRef FormatStr; - if (!getConstantStringInfo(CI->getArgOperand(1), FormatStr)) - return nullptr; + if (Value *V = optimizeSPrintFString(CI, B)) { + return V; + } - // Do not do any of the following transformations if the fprintf return - // value is used, in general the fprintf return value is not compatible - // with fwrite(), fputc() or fputs(). - if (!CI->use_empty()) - return nullptr; + // sprintf(str, format, ...) -> siprintf(str, format, ...) if no floating + // point arguments. + if (TLI->has(LibFunc::siprintf) && !callHasFloatingPointArgument(CI)) { + Module *M = B.GetInsertBlock()->getParent()->getParent(); + Constant *SIPrintFFn = + M->getOrInsertFunction("siprintf", FT, Callee->getAttributes()); + CallInst *New = cast<CallInst>(CI->clone()); + New->setCalledFunction(SIPrintFFn); + B.Insert(New); + return New; + } + return nullptr; +} + +Value *LibCallSimplifier::optimizeFPrintFString(CallInst *CI, IRBuilder<> &B) { + optimizeErrorReporting(CI, B, 0); - // fprintf(F, "foo") --> fwrite("foo", 3, 1, F) - if (CI->getNumArgOperands() == 2) { - for (unsigned i = 0, e = FormatStr.size(); i != e; ++i) - if (FormatStr[i] == '%') // Could handle %% -> % if we cared. - return nullptr; // We found a format specifier. + // All the optimizations depend on the format string. + StringRef FormatStr; + if (!getConstantStringInfo(CI->getArgOperand(1), FormatStr)) + return nullptr; - // These optimizations require DataLayout. - if (!DL) return nullptr; + // Do not do any of the following transformations if the fprintf return + // value is used, in general the fprintf return value is not compatible + // with fwrite(), fputc() or fputs(). + if (!CI->use_empty()) + return nullptr; - return EmitFWrite(CI->getArgOperand(1), - ConstantInt::get(DL->getIntPtrType(*Context), - FormatStr.size()), - CI->getArgOperand(0), B, DL, TLI); - } + // fprintf(F, "foo") --> fwrite("foo", 3, 1, F) + if (CI->getNumArgOperands() == 2) { + for (unsigned i = 0, e = FormatStr.size(); i != e; ++i) + if (FormatStr[i] == '%') // Could handle %% -> % if we cared. 
+      return nullptr; // We found a format specifier.
 
-    // The remaining optimizations require the format string to be "%s" or "%c"
-    // and have an extra operand.
-    if (FormatStr.size() != 2 || FormatStr[0] != '%' ||
-        CI->getNumArgOperands() < 3)
+    // These optimizations require DataLayout.
+    if (!DL)
       return nullptr;
 
-    // Decode the second character of the format string.
-    if (FormatStr[1] == 'c') {
-      // fprintf(F, "%c", chr) --> fputc(chr, F)
-      if (!CI->getArgOperand(2)->getType()->isIntegerTy()) return nullptr;
-      return EmitFPutC(CI->getArgOperand(2), CI->getArgOperand(0), B, DL, TLI);
-    }
+    return EmitFWrite(
+        CI->getArgOperand(1),
+        ConstantInt::get(DL->getIntPtrType(CI->getContext()), FormatStr.size()),
+        CI->getArgOperand(0), B, DL, TLI);
+  }
 
-    if (FormatStr[1] == 's') {
-      // fprintf(F, "%s", str) --> fputs(str, F)
-      if (!CI->getArgOperand(2)->getType()->isPointerTy())
-        return nullptr;
-      return EmitFPutS(CI->getArgOperand(2), CI->getArgOperand(0), B, DL, TLI);
-    }
     return nullptr;
-  }
 
-  Value *callOptimizer(Function *Callee, CallInst *CI,
-                       IRBuilder<> &B) override {
-    // Require two fixed paramters as pointers and integer result.
-    FunctionType *FT = Callee->getFunctionType();
-    if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() ||
-        !FT->getParamType(1)->isPointerTy() ||
-        !FT->getReturnType()->isIntegerTy())
+  // The remaining optimizations require the format string to be "%s" or "%c"
+  // and have an extra operand.
+  if (FormatStr.size() != 2 || FormatStr[0] != '%' ||
+      CI->getNumArgOperands() < 3)
     return nullptr;
 
-    // Decode the second character of the format string.
+  if (FormatStr[1] == 'c') {
+    // fprintf(F, "%c", chr) --> fputc(chr, F)
+    if (!CI->getArgOperand(2)->getType()->isIntegerTy()) return nullptr;
+    return EmitFPutC(CI->getArgOperand(2), CI->getArgOperand(0), B, DL, TLI);
+  }
 
-    if (Value *V = optimizeFixedFormatString(Callee, CI, B)) {
-      return V;
-    }
+  if (FormatStr[1] == 's') {
+    // fprintf(F, "%s", str) --> fputs(str, F)
+    if (!CI->getArgOperand(2)->getType()->isPointerTy())
+      return nullptr;
+    return EmitFPutS(CI->getArgOperand(2), CI->getArgOperand(0), B, DL, TLI);
+  }
+  return nullptr;
+}
 
-  // fprintf(stream, format, ...) -> fiprintf(stream, format, ...) if no
-  // floating point arguments.
+Value *LibCallSimplifier::optimizeFPrintF(CallInst *CI, IRBuilder<> &B) {
+  Function *Callee = CI->getCalledFunction();
+  // Require two fixed parameters as pointers and integer result.
+  FunctionType *FT = Callee->getFunctionType();
+  if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() ||
+      !FT->getParamType(1)->isPointerTy() ||
+      !FT->getReturnType()->isIntegerTy())
+    return nullptr;
+
+  if (Value *V = optimizeFPrintFString(CI, B)) {
+    return V;
  }
 
-};
-
-struct FWriteOpt : public LibCallOptimization {
-  Value *callOptimizer(Function *Callee, CallInst *CI,
-                       IRBuilder<> &B) override {
-    ErrorReportingOpt ER(/* StreamArg = */ 3);
-    (void) ER.callOptimizer(Callee, CI, B);
+  // fprintf(stream, format, ...) -> fiprintf(stream, format, ...) if no
+  // floating point arguments.
+  if (TLI->has(LibFunc::fiprintf) && !callHasFloatingPointArgument(CI)) {
+    Module *M = B.GetInsertBlock()->getParent()->getParent();
+    Constant *FIPrintFFn =
+        M->getOrInsertFunction("fiprintf", FT, Callee->getAttributes());
+    CallInst *New = cast<CallInst>(CI->clone());
+    New->setCalledFunction(FIPrintFFn);
+    B.Insert(New);
+    return New;
+  }
+  return nullptr;
+}
 
-    // Require a pointer, an integer, an integer, a pointer, returning integer.
-    FunctionType *FT = Callee->getFunctionType();
-    if (FT->getNumParams() != 4 || !FT->getParamType(0)->isPointerTy() ||
-        !FT->getParamType(1)->isIntegerTy() ||
-        !FT->getParamType(2)->isIntegerTy() ||
-        !FT->getParamType(3)->isPointerTy() ||
-        !FT->getReturnType()->isIntegerTy())
-      return nullptr;
+Value *LibCallSimplifier::optimizeFWrite(CallInst *CI, IRBuilder<> &B) {
+  optimizeErrorReporting(CI, B, 3);
 
-    // Get the element size and count.
-    ConstantInt *SizeC = dyn_cast<ConstantInt>(CI->getArgOperand(1));
-    ConstantInt *CountC = dyn_cast<ConstantInt>(CI->getArgOperand(2));
-    if (!SizeC || !CountC) return nullptr;
-    uint64_t Bytes = SizeC->getZExtValue()*CountC->getZExtValue();
-
-    // If this is writing zero records, remove the call (it's a noop).
-    if (Bytes == 0)
-      return ConstantInt::get(CI->getType(), 0);
-
-    // If this is writing one byte, turn it into fputc.
-    // This optimisation is only valid, if the return value is unused.
-    if (Bytes == 1 && CI->use_empty()) { // fwrite(S,1,1,F) -> fputc(S[0],F)
-      Value *Char = B.CreateLoad(CastToCStr(CI->getArgOperand(0), B), "char");
-      Value *NewCI = EmitFPutC(Char, CI->getArgOperand(3), B, DL, TLI);
-      return NewCI ? ConstantInt::get(CI->getType(), 1) : nullptr;
-    }
+  Function *Callee = CI->getCalledFunction();
+  // Require a pointer, an integer, an integer, a pointer, returning integer.
+  FunctionType *FT = Callee->getFunctionType();
+  if (FT->getNumParams() != 4 || !FT->getParamType(0)->isPointerTy() ||
+      !FT->getParamType(1)->isIntegerTy() ||
+      !FT->getParamType(2)->isIntegerTy() ||
+      !FT->getParamType(3)->isPointerTy() ||
+      !FT->getReturnType()->isIntegerTy())
+    return nullptr;
+  // Get the element size and count.
+  ConstantInt *SizeC = dyn_cast<ConstantInt>(CI->getArgOperand(1));
+  ConstantInt *CountC = dyn_cast<ConstantInt>(CI->getArgOperand(2));
+  if (!SizeC || !CountC)
     return nullptr;
-  }
-};
+  uint64_t Bytes = SizeC->getZExtValue() * CountC->getZExtValue();
 
-struct FPutsOpt : public LibCallOptimization {
-  Value *callOptimizer(Function *Callee, CallInst *CI,
-                       IRBuilder<> &B) override {
-    ErrorReportingOpt ER(/* StreamArg = */ 1);
-    (void) ER.callOptimizer(Callee, CI, B);
+  // If this is writing zero records, remove the call (it's a noop).
+  if (Bytes == 0)
+    return ConstantInt::get(CI->getType(), 0);
 
-    // These optimizations require DataLayout.
-    if (!DL) return nullptr;
+  // If this is writing one byte, turn it into fputc.
+  // This optimisation is only valid if the return value is unused.
+  if (Bytes == 1 && CI->use_empty()) { // fwrite(S,1,1,F) -> fputc(S[0],F)
+    Value *Char = B.CreateLoad(CastToCStr(CI->getArgOperand(0), B), "char");
+    Value *NewCI = EmitFPutC(Char, CI->getArgOperand(3), B, DL, TLI);
+    return NewCI ? ConstantInt::get(CI->getType(), 1) : nullptr;
+  }
 
-    // Require two pointers. Also, we can't optimize if return value is used.
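[For reference, the fwrite(3) rewrites implemented above, as a C-level sketch (illustrative only, not part of the patch):

    fwrite(p, 0, n, f); // size * count == 0: writes nothing -> constant 0
    fwrite(p, 1, 1, f); // exactly one byte, result unused -> fputc(*(char *)p, f)
]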
- FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || - !FT->getParamType(1)->isPointerTy() || - !CI->use_empty()) - return nullptr; + return nullptr; +} - // fputs(s,F) --> fwrite(s,1,strlen(s),F) - uint64_t Len = GetStringLength(CI->getArgOperand(0)); - if (!Len) return nullptr; - // Known to have no uses (see above). - return EmitFWrite(CI->getArgOperand(0), - ConstantInt::get(DL->getIntPtrType(*Context), Len-1), - CI->getArgOperand(1), B, DL, TLI); - } -}; - -struct PutsOpt : public LibCallOptimization { - Value *callOptimizer(Function *Callee, CallInst *CI, - IRBuilder<> &B) override { - // Require one fixed pointer argument and an integer/void result. - FunctionType *FT = Callee->getFunctionType(); - if (FT->getNumParams() < 1 || !FT->getParamType(0)->isPointerTy() || - !(FT->getReturnType()->isIntegerTy() || - FT->getReturnType()->isVoidTy())) - return nullptr; +Value *LibCallSimplifier::optimizeFPuts(CallInst *CI, IRBuilder<> &B) { + optimizeErrorReporting(CI, B, 1); - // Check for a constant string. - StringRef Str; - if (!getConstantStringInfo(CI->getArgOperand(0), Str)) - return nullptr; + Function *Callee = CI->getCalledFunction(); - if (Str.empty() && CI->use_empty()) { - // puts("") -> putchar('\n') - Value *Res = EmitPutChar(B.getInt32('\n'), B, DL, TLI); - if (CI->use_empty() || !Res) return Res; - return B.CreateIntCast(Res, CI->getType(), true); - } + // These optimizations require DataLayout. + if (!DL) + return nullptr; + // Require two pointers. Also, we can't optimize if return value is used. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() != 2 || !FT->getParamType(0)->isPointerTy() || + !FT->getParamType(1)->isPointerTy() || !CI->use_empty()) + return nullptr; + + // fputs(s,F) --> fwrite(s,1,strlen(s),F) + uint64_t Len = GetStringLength(CI->getArgOperand(0)); + if (!Len) return nullptr; - } -}; -} // End anonymous namespace. + // Known to have no uses (see above). + return EmitFWrite( + CI->getArgOperand(0), + ConstantInt::get(DL->getIntPtrType(CI->getContext()), Len - 1), + CI->getArgOperand(1), B, DL, TLI); +} -namespace llvm { +Value *LibCallSimplifier::optimizePuts(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + // Require one fixed pointer argument and an integer/void result. + FunctionType *FT = Callee->getFunctionType(); + if (FT->getNumParams() < 1 || !FT->getParamType(0)->isPointerTy() || + !(FT->getReturnType()->isIntegerTy() || FT->getReturnType()->isVoidTy())) + return nullptr; -class LibCallSimplifierImpl { - const DataLayout *DL; - const TargetLibraryInfo *TLI; - const LibCallSimplifier *LCS; - bool UnsafeFPShrink; + // Check for a constant string. + StringRef Str; + if (!getConstantStringInfo(CI->getArgOperand(0), Str)) + return nullptr; - // Math library call optimizations. 
- CosOpt Cos; - PowOpt Pow; - Exp2Opt Exp2; -public: - LibCallSimplifierImpl(const DataLayout *DL, const TargetLibraryInfo *TLI, - const LibCallSimplifier *LCS, - bool UnsafeFPShrink = false) - : Cos(UnsafeFPShrink), Pow(UnsafeFPShrink), Exp2(UnsafeFPShrink) { - this->DL = DL; - this->TLI = TLI; - this->LCS = LCS; - this->UnsafeFPShrink = UnsafeFPShrink; + if (Str.empty() && CI->use_empty()) { + // puts("") -> putchar('\n') + Value *Res = EmitPutChar(B.getInt32('\n'), B, DL, TLI); + if (CI->use_empty() || !Res) + return Res; + return B.CreateIntCast(Res, CI->getType(), true); } - Value *optimizeCall(CallInst *CI); - LibCallOptimization *lookupOptimization(CallInst *CI); - bool hasFloatVersion(StringRef FuncName); -}; + return nullptr; +} -bool LibCallSimplifierImpl::hasFloatVersion(StringRef FuncName) { +bool LibCallSimplifier::hasFloatVersion(StringRef FuncName) { LibFunc::Func Func; SmallString<20> FloatFuncName = FuncName; FloatFuncName += 'f'; @@ -2048,263 +1857,239 @@ bool LibCallSimplifierImpl::hasFloatVersion(StringRef FuncName) { return false; } -// Fortified library call optimizations. -static MemCpyChkOpt MemCpyChk; -static MemMoveChkOpt MemMoveChk; -static MemSetChkOpt MemSetChk; -static StrCpyChkOpt StrCpyChk; -static StpCpyChkOpt StpCpyChk; -static StrNCpyChkOpt StrNCpyChk; - -// String library call optimizations. -static StrCatOpt StrCat; -static StrNCatOpt StrNCat; -static StrChrOpt StrChr; -static StrRChrOpt StrRChr; -static StrCmpOpt StrCmp; -static StrNCmpOpt StrNCmp; -static StrCpyOpt StrCpy; -static StpCpyOpt StpCpy; -static StrNCpyOpt StrNCpy; -static StrLenOpt StrLen; -static StrPBrkOpt StrPBrk; -static StrToOpt StrTo; -static StrSpnOpt StrSpn; -static StrCSpnOpt StrCSpn; -static StrStrOpt StrStr; - -// Memory library call optimizations. -static MemCmpOpt MemCmp; -static MemCpyOpt MemCpy; -static MemMoveOpt MemMove; -static MemSetOpt MemSet; - -// Math library call optimizations. -static UnaryDoubleFPOpt UnaryDoubleFP(false); -static BinaryDoubleFPOpt BinaryDoubleFP(false); -static UnaryDoubleFPOpt UnsafeUnaryDoubleFP(true); -static SinCosPiOpt SinCosPi; - - // Integer library call optimizations. -static FFSOpt FFS; -static AbsOpt Abs; -static IsDigitOpt IsDigit; -static IsAsciiOpt IsAscii; -static ToAsciiOpt ToAscii; - -// Formatting and IO library call optimizations. -static ErrorReportingOpt ErrorReporting; -static ErrorReportingOpt ErrorReporting0(0); -static ErrorReportingOpt ErrorReporting1(1); -static PrintFOpt PrintF; -static SPrintFOpt SPrintF; -static FPrintFOpt FPrintF; -static FWriteOpt FWrite; -static FPutsOpt FPuts; -static PutsOpt Puts; - -LibCallOptimization *LibCallSimplifierImpl::lookupOptimization(CallInst *CI) { +Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI, + IRBuilder<> &Builder) { + LibFunc::Func Func; + Function *Callee = CI->getCalledFunction(); + StringRef FuncName = Callee->getName(); + + // Check for string/memory library functions. + if (TLI->getLibFunc(FuncName, Func) && TLI->has(Func)) { + // Make sure we never change the calling convention. 
+ assert((ignoreCallingConv(Func) || + CI->getCallingConv() == llvm::CallingConv::C) && + "Optimizing string/memory libcall would change the calling convention"); + switch (Func) { + case LibFunc::strcat: + return optimizeStrCat(CI, Builder); + case LibFunc::strncat: + return optimizeStrNCat(CI, Builder); + case LibFunc::strchr: + return optimizeStrChr(CI, Builder); + case LibFunc::strrchr: + return optimizeStrRChr(CI, Builder); + case LibFunc::strcmp: + return optimizeStrCmp(CI, Builder); + case LibFunc::strncmp: + return optimizeStrNCmp(CI, Builder); + case LibFunc::strcpy: + return optimizeStrCpy(CI, Builder); + case LibFunc::stpcpy: + return optimizeStpCpy(CI, Builder); + case LibFunc::strncpy: + return optimizeStrNCpy(CI, Builder); + case LibFunc::strlen: + return optimizeStrLen(CI, Builder); + case LibFunc::strpbrk: + return optimizeStrPBrk(CI, Builder); + case LibFunc::strtol: + case LibFunc::strtod: + case LibFunc::strtof: + case LibFunc::strtoul: + case LibFunc::strtoll: + case LibFunc::strtold: + case LibFunc::strtoull: + return optimizeStrTo(CI, Builder); + case LibFunc::strspn: + return optimizeStrSpn(CI, Builder); + case LibFunc::strcspn: + return optimizeStrCSpn(CI, Builder); + case LibFunc::strstr: + return optimizeStrStr(CI, Builder); + case LibFunc::memcmp: + return optimizeMemCmp(CI, Builder); + case LibFunc::memcpy: + return optimizeMemCpy(CI, Builder); + case LibFunc::memmove: + return optimizeMemMove(CI, Builder); + case LibFunc::memset: + return optimizeMemSet(CI, Builder); + default: + break; + } + } + return nullptr; +} + +Value *LibCallSimplifier::optimizeCall(CallInst *CI) { + if (CI->isNoBuiltin()) + return nullptr; + LibFunc::Func Func; Function *Callee = CI->getCalledFunction(); StringRef FuncName = Callee->getName(); + IRBuilder<> Builder(CI); + bool isCallingConvC = CI->getCallingConv() == llvm::CallingConv::C; + + // Command-line parameter overrides function attribute. + if (EnableUnsafeFPShrink.getNumOccurrences() > 0) + UnsafeFPShrink = EnableUnsafeFPShrink; + else if (Callee->hasFnAttribute("unsafe-fp-math")) { + // FIXME: This is the same problem as described in optimizeSqrt(). + // If calls gain access to IR-level FMF, then use that instead of a + // function attribute. + + // Check for unsafe-fp-math = true. + Attribute Attr = Callee->getFnAttribute("unsafe-fp-math"); + if (Attr.getValueAsString() == "true") + UnsafeFPShrink = true; + } - // Next check for intrinsics. + // First, check for intrinsics. if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) { + if (!isCallingConvC) + return nullptr; switch (II->getIntrinsicID()) { case Intrinsic::pow: - return &Pow; + return optimizePow(CI, Builder); case Intrinsic::exp2: - return &Exp2; + return optimizeExp2(CI, Builder); + case Intrinsic::fabs: + return optimizeFabs(CI, Builder); + case Intrinsic::sqrt: + return optimizeSqrt(CI, Builder); default: - return nullptr; + return nullptr; } } + // Also try to simplify calls to fortified library functions. + if (Value *SimplifiedFortifiedCI = FortifiedSimplifier.optimizeCall(CI)) { + // Try to further simplify the result. + CallInst *SimplifiedCI = dyn_cast<CallInst>(SimplifiedFortifiedCI); + if (SimplifiedCI && SimplifiedCI->getCalledFunction()) + if (Value *V = optimizeStringMemoryLibCall(SimplifiedCI, Builder)) + return V; + return SimplifiedFortifiedCI; + } + // Then check for known library functions. if (TLI->getLibFunc(FuncName, Func) && TLI->has(Func)) { + // We never change the calling convention. 
+ if (!ignoreCallingConv(Func) && !isCallingConvC) + return nullptr; + if (Value *V = optimizeStringMemoryLibCall(CI, Builder)) + return V; switch (Func) { - case LibFunc::strcat: - return &StrCat; - case LibFunc::strncat: - return &StrNCat; - case LibFunc::strchr: - return &StrChr; - case LibFunc::strrchr: - return &StrRChr; - case LibFunc::strcmp: - return &StrCmp; - case LibFunc::strncmp: - return &StrNCmp; - case LibFunc::strcpy: - return &StrCpy; - case LibFunc::stpcpy: - return &StpCpy; - case LibFunc::strncpy: - return &StrNCpy; - case LibFunc::strlen: - return &StrLen; - case LibFunc::strpbrk: - return &StrPBrk; - case LibFunc::strtol: - case LibFunc::strtod: - case LibFunc::strtof: - case LibFunc::strtoul: - case LibFunc::strtoll: - case LibFunc::strtold: - case LibFunc::strtoull: - return &StrTo; - case LibFunc::strspn: - return &StrSpn; - case LibFunc::strcspn: - return &StrCSpn; - case LibFunc::strstr: - return &StrStr; - case LibFunc::memcmp: - return &MemCmp; - case LibFunc::memcpy: - return &MemCpy; - case LibFunc::memmove: - return &MemMove; - case LibFunc::memset: - return &MemSet; - case LibFunc::cosf: - case LibFunc::cos: - case LibFunc::cosl: - return &Cos; - case LibFunc::sinpif: - case LibFunc::sinpi: - case LibFunc::cospif: - case LibFunc::cospi: - return &SinCosPi; - case LibFunc::powf: - case LibFunc::pow: - case LibFunc::powl: - return &Pow; - case LibFunc::exp2l: - case LibFunc::exp2: - case LibFunc::exp2f: - return &Exp2; - case LibFunc::ffs: - case LibFunc::ffsl: - case LibFunc::ffsll: - return &FFS; - case LibFunc::abs: - case LibFunc::labs: - case LibFunc::llabs: - return &Abs; - case LibFunc::isdigit: - return &IsDigit; - case LibFunc::isascii: - return &IsAscii; - case LibFunc::toascii: - return &ToAscii; - case LibFunc::printf: - return &PrintF; - case LibFunc::sprintf: - return &SPrintF; - case LibFunc::fprintf: - return &FPrintF; - case LibFunc::fwrite: - return &FWrite; - case LibFunc::fputs: - return &FPuts; - case LibFunc::puts: - return &Puts; - case LibFunc::perror: - return &ErrorReporting; - case LibFunc::vfprintf: - case LibFunc::fiprintf: - return &ErrorReporting0; - case LibFunc::fputc: - return &ErrorReporting1; - case LibFunc::ceil: - case LibFunc::fabs: - case LibFunc::floor: - case LibFunc::rint: - case LibFunc::round: - case LibFunc::nearbyint: - case LibFunc::trunc: - if (hasFloatVersion(FuncName)) - return &UnaryDoubleFP; - return nullptr; - case LibFunc::acos: - case LibFunc::acosh: - case LibFunc::asin: - case LibFunc::asinh: - case LibFunc::atan: - case LibFunc::atanh: - case LibFunc::cbrt: - case LibFunc::cosh: - case LibFunc::exp: - case LibFunc::exp10: - case LibFunc::expm1: - case LibFunc::log: - case LibFunc::log10: - case LibFunc::log1p: - case LibFunc::log2: - case LibFunc::logb: - case LibFunc::sin: - case LibFunc::sinh: - case LibFunc::sqrt: - case LibFunc::tan: - case LibFunc::tanh: - if (UnsafeFPShrink && hasFloatVersion(FuncName)) - return &UnsafeUnaryDoubleFP; - return nullptr; - case LibFunc::fmin: - case LibFunc::fmax: - if (hasFloatVersion(FuncName)) - return &BinaryDoubleFP; - return nullptr; - case LibFunc::memcpy_chk: - return &MemCpyChk; - default: - return nullptr; - } - } - - // Finally check for fortified library calls. 
- if (FuncName.endswith("_chk")) { - if (FuncName == "__memmove_chk") - return &MemMoveChk; - else if (FuncName == "__memset_chk") - return &MemSetChk; - else if (FuncName == "__strcpy_chk") - return &StrCpyChk; - else if (FuncName == "__stpcpy_chk") - return &StpCpyChk; - else if (FuncName == "__strncpy_chk") - return &StrNCpyChk; - else if (FuncName == "__stpncpy_chk") - return &StrNCpyChk; - } - - return nullptr; - -} - -Value *LibCallSimplifierImpl::optimizeCall(CallInst *CI) { - LibCallOptimization *LCO = lookupOptimization(CI); - if (LCO) { - IRBuilder<> Builder(CI); - return LCO->optimizeCall(CI, DL, TLI, LCS, Builder); + case LibFunc::cosf: + case LibFunc::cos: + case LibFunc::cosl: + return optimizeCos(CI, Builder); + case LibFunc::sinpif: + case LibFunc::sinpi: + case LibFunc::cospif: + case LibFunc::cospi: + return optimizeSinCosPi(CI, Builder); + case LibFunc::powf: + case LibFunc::pow: + case LibFunc::powl: + return optimizePow(CI, Builder); + case LibFunc::exp2l: + case LibFunc::exp2: + case LibFunc::exp2f: + return optimizeExp2(CI, Builder); + case LibFunc::fabsf: + case LibFunc::fabs: + case LibFunc::fabsl: + return optimizeFabs(CI, Builder); + case LibFunc::sqrtf: + case LibFunc::sqrt: + case LibFunc::sqrtl: + return optimizeSqrt(CI, Builder); + case LibFunc::ffs: + case LibFunc::ffsl: + case LibFunc::ffsll: + return optimizeFFS(CI, Builder); + case LibFunc::abs: + case LibFunc::labs: + case LibFunc::llabs: + return optimizeAbs(CI, Builder); + case LibFunc::isdigit: + return optimizeIsDigit(CI, Builder); + case LibFunc::isascii: + return optimizeIsAscii(CI, Builder); + case LibFunc::toascii: + return optimizeToAscii(CI, Builder); + case LibFunc::printf: + return optimizePrintF(CI, Builder); + case LibFunc::sprintf: + return optimizeSPrintF(CI, Builder); + case LibFunc::fprintf: + return optimizeFPrintF(CI, Builder); + case LibFunc::fwrite: + return optimizeFWrite(CI, Builder); + case LibFunc::fputs: + return optimizeFPuts(CI, Builder); + case LibFunc::puts: + return optimizePuts(CI, Builder); + case LibFunc::perror: + return optimizeErrorReporting(CI, Builder); + case LibFunc::vfprintf: + case LibFunc::fiprintf: + return optimizeErrorReporting(CI, Builder, 0); + case LibFunc::fputc: + return optimizeErrorReporting(CI, Builder, 1); + case LibFunc::ceil: + case LibFunc::floor: + case LibFunc::rint: + case LibFunc::round: + case LibFunc::nearbyint: + case LibFunc::trunc: + if (hasFloatVersion(FuncName)) + return optimizeUnaryDoubleFP(CI, Builder, false); + return nullptr; + case LibFunc::acos: + case LibFunc::acosh: + case LibFunc::asin: + case LibFunc::asinh: + case LibFunc::atan: + case LibFunc::atanh: + case LibFunc::cbrt: + case LibFunc::cosh: + case LibFunc::exp: + case LibFunc::exp10: + case LibFunc::expm1: + case LibFunc::log: + case LibFunc::log10: + case LibFunc::log1p: + case LibFunc::log2: + case LibFunc::logb: + case LibFunc::sin: + case LibFunc::sinh: + case LibFunc::tan: + case LibFunc::tanh: + if (UnsafeFPShrink && hasFloatVersion(FuncName)) + return optimizeUnaryDoubleFP(CI, Builder, true); + return nullptr; + case LibFunc::copysign: + case LibFunc::fmin: + case LibFunc::fmax: + if (hasFloatVersion(FuncName)) + return optimizeBinaryDoubleFP(CI, Builder); + return nullptr; + default: + return nullptr; + } } return nullptr; } LibCallSimplifier::LibCallSimplifier(const DataLayout *DL, - const TargetLibraryInfo *TLI, - bool UnsafeFPShrink) { - Impl = new LibCallSimplifierImpl(DL, TLI, this, UnsafeFPShrink); -} - -LibCallSimplifier::~LibCallSimplifier() { - delete 
Impl; -} - -Value *LibCallSimplifier::optimizeCall(CallInst *CI) { - if (CI->isNoBuiltin()) return nullptr; - return Impl->optimizeCall(CI); + const TargetLibraryInfo *TLI) : + FortifiedSimplifier(DL, TLI), + DL(DL), + TLI(TLI), + UnsafeFPShrink(false) { } void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) const { @@ -2312,8 +2097,6 @@ void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) const { I->eraseFromParent(); } -} - // TODO: // Additional cases that we need to add to this file: // @@ -2361,3 +2144,184 @@ void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) const { // * trunc(cnst) -> cnst' // // + +//===----------------------------------------------------------------------===// +// Fortified Library Call Optimizations +//===----------------------------------------------------------------------===// + +bool FortifiedLibCallSimplifier::isFortifiedCallFoldable(CallInst *CI, + unsigned ObjSizeOp, + unsigned SizeOp, + bool isString) { + if (CI->getArgOperand(ObjSizeOp) == CI->getArgOperand(SizeOp)) + return true; + if (ConstantInt *ObjSizeCI = + dyn_cast<ConstantInt>(CI->getArgOperand(ObjSizeOp))) { + if (ObjSizeCI->isAllOnesValue()) + return true; + // If the object size wasn't -1 (unknown), bail out if we were asked to. + if (OnlyLowerUnknownSize) + return false; + if (isString) { + uint64_t Len = GetStringLength(CI->getArgOperand(SizeOp)); + // If the length is 0 we don't know how long it is and so we can't + // remove the check. + if (Len == 0) + return false; + return ObjSizeCI->getZExtValue() >= Len; + } + if (ConstantInt *SizeCI = dyn_cast<ConstantInt>(CI->getArgOperand(SizeOp))) + return ObjSizeCI->getZExtValue() >= SizeCI->getZExtValue(); + } + return false; +} + +Value *FortifiedLibCallSimplifier::optimizeMemCpyChk(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + + if (!checkStringCopyLibFuncSignature(Callee, LibFunc::memcpy_chk, DL)) + return nullptr; + + if (isFortifiedCallFoldable(CI, 3, 2, false)) { + B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1), + CI->getArgOperand(2), 1); + return CI->getArgOperand(0); + } + return nullptr; +} + +Value *FortifiedLibCallSimplifier::optimizeMemMoveChk(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + + if (!checkStringCopyLibFuncSignature(Callee, LibFunc::memmove_chk, DL)) + return nullptr; + + if (isFortifiedCallFoldable(CI, 3, 2, false)) { + B.CreateMemMove(CI->getArgOperand(0), CI->getArgOperand(1), + CI->getArgOperand(2), 1); + return CI->getArgOperand(0); + } + return nullptr; +} + +Value *FortifiedLibCallSimplifier::optimizeMemSetChk(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + + if (!checkStringCopyLibFuncSignature(Callee, LibFunc::memset_chk, DL)) + return nullptr; + + if (isFortifiedCallFoldable(CI, 3, 2, false)) { + Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false); + B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1); + return CI->getArgOperand(0); + } + return nullptr; +} + +Value *FortifiedLibCallSimplifier::optimizeStrCpyChk(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + StringRef Name = Callee->getName(); + LibFunc::Func Func = + Name.startswith("str") ? 
LibFunc::strcpy_chk : LibFunc::stpcpy_chk;
+
+  if (!checkStringCopyLibFuncSignature(Callee, Func, DL))
+    return nullptr;
+
+  Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1),
+        *ObjSize = CI->getArgOperand(2);
+
+  // __stpcpy_chk(x,x,...) -> x+strlen(x)
+  if (!OnlyLowerUnknownSize && Dst == Src) {
+    Value *StrLen = EmitStrLen(Src, B, DL, TLI);
+    return StrLen ? B.CreateInBoundsGEP(Dst, StrLen) : nullptr;
+  }
+
+  // If a) we don't have any length information, or b) we know this will
+  // fit then just lower to a plain st[rp]cpy. Otherwise we'll keep our
+  // st[rp]cpy_chk call which may fail at runtime if the size is too long.
+  // TODO: It might be nice to get a maximum length out of the possible
+  // string lengths for varying strings.
+  if (isFortifiedCallFoldable(CI, 2, 1, true)) {
+    Value *Ret = EmitStrCpy(Dst, Src, B, DL, TLI, Name.substr(2, 6));
+    return Ret;
+  } else if (!OnlyLowerUnknownSize) {
+    // Maybe we can still fold __st[rp]cpy_chk to __memcpy_chk.
+    uint64_t Len = GetStringLength(Src);
+    if (Len == 0)
+      return nullptr;
+
+    // This optimization requires DataLayout.
+    if (!DL)
+      return nullptr;
+
+    Type *SizeTTy = DL->getIntPtrType(CI->getContext());
+    Value *LenV = ConstantInt::get(SizeTTy, Len);
+    Value *Ret = EmitMemCpyChk(Dst, Src, LenV, ObjSize, B, DL, TLI);
+    // If the function was an __stpcpy_chk, and we were able to fold it into
+    // a __memcpy_chk, we still need to return the correct end pointer.
+    if (Ret && Func == LibFunc::stpcpy_chk)
+      return B.CreateGEP(Dst, ConstantInt::get(SizeTTy, Len - 1));
+    return Ret;
+  }
+  return nullptr;
+}
+
+Value *FortifiedLibCallSimplifier::optimizeStrNCpyChk(CallInst *CI, IRBuilder<> &B) {
+  Function *Callee = CI->getCalledFunction();
+  StringRef Name = Callee->getName();
+  LibFunc::Func Func =
+      Name.startswith("str") ? LibFunc::strncpy_chk : LibFunc::stpncpy_chk;
+
+  if (!checkStringCopyLibFuncSignature(Callee, Func, DL))
+    return nullptr;
+  if (isFortifiedCallFoldable(CI, 3, 2, false)) {
+    Value *Ret =
+        EmitStrNCpy(CI->getArgOperand(0), CI->getArgOperand(1),
+                    CI->getArgOperand(2), B, DL, TLI, Name.substr(2, 7));
+    return Ret;
+  }
+  return nullptr;
+}
+
+Value *FortifiedLibCallSimplifier::optimizeCall(CallInst *CI) {
+  if (CI->isNoBuiltin())
+    return nullptr;
+
+  LibFunc::Func Func;
+  Function *Callee = CI->getCalledFunction();
+  StringRef FuncName = Callee->getName();
+  IRBuilder<> Builder(CI);
+  bool isCallingConvC = CI->getCallingConv() == llvm::CallingConv::C;
+
+  // First, check that this is a known library function.
+  if (!TLI->getLibFunc(FuncName, Func) || !TLI->has(Func))
+    return nullptr;
+
+  // We never change the calling convention.
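[A sketch of when isFortifiedCallFoldable() above permits lowering a _chk call to its plain counterpart (illustrative, not part of the patch); the object-size check must be provably dead:

    __memcpy_chk(dst, src, 16, 32); // 32 >= 16, check cannot fire -> memcpy
    __memcpy_chk(dst, src, n, -1);  // object size unknown (-1)    -> memcpy
    __strcpy_chk(dst, "abc", 8);    // 8 >= strlen("abc") + 1      -> strcpy
    __memcpy_chk(dst, src, 32, 16); // could overflow: kept as a checked call
]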
+  if (!ignoreCallingConv(Func) && !isCallingConvC)
+    return nullptr;
+
+  switch (Func) {
+  case LibFunc::memcpy_chk:
+    return optimizeMemCpyChk(CI, Builder);
+  case LibFunc::memmove_chk:
+    return optimizeMemMoveChk(CI, Builder);
+  case LibFunc::memset_chk:
+    return optimizeMemSetChk(CI, Builder);
+  case LibFunc::stpcpy_chk:
+  case LibFunc::strcpy_chk:
+    return optimizeStrCpyChk(CI, Builder);
+  case LibFunc::stpncpy_chk:
+  case LibFunc::strncpy_chk:
+    return optimizeStrNCpyChk(CI, Builder);
+  default:
+    break;
+  }
+  return nullptr;
+}
+
+FortifiedLibCallSimplifier::
+FortifiedLibCallSimplifier(const DataLayout *DL, const TargetLibraryInfo *TLI,
+                           bool OnlyLowerUnknownSize)
+    : DL(DL), TLI(TLI), OnlyLowerUnknownSize(OnlyLowerUnknownSize) {
+}
diff --git a/lib/Transforms/Utils/SymbolRewriter.cpp b/lib/Transforms/Utils/SymbolRewriter.cpp
new file mode 100644
index 000000000000..b35a662f17b5
--- /dev/null
+++ b/lib/Transforms/Utils/SymbolRewriter.cpp
@@ -0,0 +1,527 @@
+//===- SymbolRewriter.cpp - Symbol Rewriter ---------------------*- C++ -*-===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// SymbolRewriter is an LLVM pass which can rewrite symbols transparently within
+// existing code. It is implemented as a compiler pass and is configured via a
+// YAML configuration file.
+//
+// The YAML configuration file format is as follows:
+//
+// RewriteMapFile := RewriteDescriptors
+// RewriteDescriptors := RewriteDescriptor | RewriteDescriptors
+// RewriteDescriptor := RewriteDescriptorType ':' '{' RewriteDescriptorFields '}'
+// RewriteDescriptorFields := RewriteDescriptorField | RewriteDescriptorFields
+// RewriteDescriptorField := FieldIdentifier ':' FieldValue ','
+// RewriteDescriptorType := Identifier
+// FieldIdentifier := Identifier
+// FieldValue := Identifier
+// Identifier := [0-9a-zA-Z]+
+//
+// Currently, the following descriptor types are supported:
+//
+// - function: (function rewriting)
+//      + Source (original name of the function)
+//      + Target (explicit transformation)
+//      + Transform (pattern transformation)
+//      + Naked (boolean, whether the function is undecorated)
+// - global variable: (external linkage global variable rewriting)
+//      + Source (original name of externally visible variable)
+//      + Target (explicit transformation)
+//      + Transform (pattern transformation)
+// - global alias: (global alias rewriting)
+//      + Source (original name of the aliased name)
+//      + Target (explicit transformation)
+//      + Transform (pattern transformation)
+//
+// Note that Source and exactly one of [Target, Transform] must be provided.
+//
+// New rewrite descriptors can be created. Adding a new rewrite descriptor
+// involves:
+//
+//   a) extending the rewrite descriptor kind enumeration
+//      (<anonymous>::RewriteDescriptor::RewriteDescriptorType)
+//   b) implementing the new descriptor
+//      (c.f. <anonymous>::ExplicitRewriteFunctionDescriptor)
+//   c) extending the rewrite map parser
+//      (<anonymous>::RewriteMapParser::parseEntry)
+//
+// Enable symbol rewriting with the `-rewrite-symbols` option, and specify the
+// map file to use for the rewriting via the `-rewrite-map-file` option.
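[To make the grammar concrete, a small example map file (hypothetical symbol names; flow-mapping style as the parser expects):

    function: { source: foo, target: bar }
    function: { source: foo_(.*), transform: bar_\1 }
    global variable: { source: counter, target: counter_renamed }

which would be applied with something along the lines of:

    opt -rewrite-symbols -rewrite-map-file=sample.map -o out.bc in.bc
]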
+//
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "symbol-rewriter"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/Pass.h"
+#include "llvm/PassManager.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/YAMLParser.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/IPO/PassManagerBuilder.h"
+#include "llvm/Transforms/Utils/SymbolRewriter.h"
+
+using namespace llvm;
+
+static cl::list<std::string> RewriteMapFiles("rewrite-map-file",
+                                             cl::desc("Symbol Rewrite Map"),
+                                             cl::value_desc("filename"));
+
+namespace llvm {
+namespace SymbolRewriter {
+template <RewriteDescriptor::Type DT, typename ValueType,
+          ValueType *(llvm::Module::*Get)(StringRef) const>
+class ExplicitRewriteDescriptor : public RewriteDescriptor {
+public:
+  const std::string Source;
+  const std::string Target;
+
+  ExplicitRewriteDescriptor(StringRef S, StringRef T, const bool Naked)
+      : RewriteDescriptor(DT), Source(Naked ? StringRef("\01" + S.str()) : S),
+        Target(T) {}
+
+  bool performOnModule(Module &M) override;
+
+  static bool classof(const RewriteDescriptor *RD) {
+    return RD->getType() == DT;
+  }
+};
+
+template <RewriteDescriptor::Type DT, typename ValueType,
+          ValueType *(llvm::Module::*Get)(StringRef) const>
+bool ExplicitRewriteDescriptor<DT, ValueType, Get>::performOnModule(Module &M) {
+  bool Changed = false;
+  if (ValueType *S = (M.*Get)(Source)) {
+    if (Value *T = (M.*Get)(Target))
+      S->setValueName(T->getValueName());
+    else
+      S->setName(Target);
+    Changed = true;
+  }
+  return Changed;
+}
+
+template <RewriteDescriptor::Type DT, typename ValueType,
+          ValueType *(llvm::Module::*Get)(StringRef) const,
+          iterator_range<typename iplist<ValueType>::iterator>
+          (llvm::Module::*Iterator)()>
+class PatternRewriteDescriptor : public RewriteDescriptor {
+public:
+  const std::string Pattern;
+  const std::string Transform;
+
+  PatternRewriteDescriptor(StringRef P, StringRef T)
+    : RewriteDescriptor(DT), Pattern(P), Transform(T) { }
+
+  bool performOnModule(Module &M) override;
+
+  static bool classof(const RewriteDescriptor *RD) {
+    return RD->getType() == DT;
+  }
+};
+
+template <RewriteDescriptor::Type DT, typename ValueType,
+          ValueType *(llvm::Module::*Get)(StringRef) const,
+          iterator_range<typename iplist<ValueType>::iterator>
+          (llvm::Module::*Iterator)()>
+bool PatternRewriteDescriptor<DT, ValueType, Get, Iterator>::
+performOnModule(Module &M) {
+  bool Changed = false;
+  for (auto &C : (M.*Iterator)()) {
+    std::string Error;
+
+    std::string Name = Regex(Pattern).sub(Transform, C.getName(), &Error);
+    if (!Error.empty())
+      report_fatal_error("unable to transform " + C.getName() + " in " +
+                         M.getModuleIdentifier() + ": " + Error);
+
+    if (Value *V = (M.*Get)(Name))
+      C.setValueName(V->getValueName());
+    else
+      C.setName(Name);
+
+    Changed = true;
+  }
+  return Changed;
+}
+
+/// Represents a rewrite for an explicitly named (function) symbol. Both the
+/// source function name and target function name of the transformation are
+/// explicitly spelt out.
+typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::Function,
+                                  llvm::Function, &llvm::Module::getFunction>
+    ExplicitRewriteFunctionDescriptor;
+
+/// Represents a rewrite for an explicitly named (global variable) symbol. Both
+/// the source variable name and target variable name are spelt out.
+/// This applies only to module level variables.
+typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable,
+                                  llvm::GlobalVariable,
+                                  &llvm::Module::getGlobalVariable>
+    ExplicitRewriteGlobalVariableDescriptor;
+
+/// Represents a rewrite for an explicitly named global alias. Both the source
+/// and target name are explicitly spelt out.
+typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::NamedAlias,
+                                  llvm::GlobalAlias,
+                                  &llvm::Module::getNamedAlias>
+    ExplicitRewriteNamedAliasDescriptor;
+
+/// Represents a rewrite for a regular-expression-based pattern for functions.
+/// A pattern for the function name and a transformation of that pattern that
+/// determines the target function name together create the rewrite rule.
+typedef PatternRewriteDescriptor<RewriteDescriptor::Type::Function,
+                                 llvm::Function, &llvm::Module::getFunction,
+                                 &llvm::Module::functions>
+    PatternRewriteFunctionDescriptor;
+
+/// Represents a rewrite for a global variable based upon a matching pattern.
+/// Each global variable matching the provided pattern will be transformed as
+/// described in the transformation pattern for the target. Applies only to
+/// module level variables.
+typedef PatternRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable,
+                                 llvm::GlobalVariable,
+                                 &llvm::Module::getGlobalVariable,
+                                 &llvm::Module::globals>
+    PatternRewriteGlobalVariableDescriptor;
+
+/// PatternRewriteNamedAliasDescriptor - represents a rewrite for global
+/// aliases which match a given pattern. The provided transformation will be
+/// applied to each of the matching names.
+typedef PatternRewriteDescriptor<RewriteDescriptor::Type::NamedAlias,
+                                 llvm::GlobalAlias,
+                                 &llvm::Module::getNamedAlias,
+                                 &llvm::Module::aliases>
+    PatternRewriteNamedAliasDescriptor;
+
+bool RewriteMapParser::parse(const std::string &MapFile,
+                             RewriteDescriptorList *DL) {
+  ErrorOr<std::unique_ptr<MemoryBuffer>> Mapping =
+      MemoryBuffer::getFile(MapFile);
+
+  if (!Mapping)
+    report_fatal_error("unable to read rewrite map '" + MapFile + "': " +
+                       Mapping.getError().message());
+
+  if (!parse(*Mapping, DL))
+    report_fatal_error("unable to parse rewrite map '" + MapFile + "'");
+
+  return true;
+}
+
+bool RewriteMapParser::parse(std::unique_ptr<MemoryBuffer> &MapFile,
+                             RewriteDescriptorList *DL) {
+  SourceMgr SM;
+  yaml::Stream YS(MapFile->getBuffer(), SM);
+
+  for (auto &Document : YS) {
+    yaml::MappingNode *DescriptorList;
+
+    // Ignore empty documents.
+    if (isa<yaml::NullNode>(Document.getRoot()))
+      continue;
+
+    DescriptorList = dyn_cast<yaml::MappingNode>(Document.getRoot());
+    if (!DescriptorList) {
+      YS.printError(Document.getRoot(), "DescriptorList node must be a map");
+      return false;
+    }
+
+    for (auto &Descriptor : *DescriptorList)
+      if (!parseEntry(YS, Descriptor, DL))
+        return false;
+  }
+
+  return true;
+}
+
+bool RewriteMapParser::parseEntry(yaml::Stream &YS, yaml::KeyValueNode &Entry,
+                                  RewriteDescriptorList *DL) {
+  yaml::ScalarNode *Key;
+  yaml::MappingNode *Value;
+  SmallString<32> KeyStorage;
+  StringRef RewriteType;
+
+  Key = dyn_cast<yaml::ScalarNode>(Entry.getKey());
+  if (!Key) {
+    YS.printError(Entry.getKey(), "rewrite type must be a scalar");
+    return false;
+  }
+
+  Value = dyn_cast<yaml::MappingNode>(Entry.getValue());
+  if (!Value) {
+    YS.printError(Entry.getValue(), "rewrite descriptor must be a map");
+    return false;
+  }
+
+  RewriteType = Key->getValue(KeyStorage);
+  if (RewriteType.equals("function"))
+    return parseRewriteFunctionDescriptor(YS, Key, Value, DL);
+  else if
(RewriteType.equals("global variable")) + return parseRewriteGlobalVariableDescriptor(YS, Key, Value, DL); + else if (RewriteType.equals("global alias")) + return parseRewriteGlobalAliasDescriptor(YS, Key, Value, DL); + + YS.printError(Entry.getKey(), "unknown rewrite type"); + return false; +} + +bool RewriteMapParser:: +parseRewriteFunctionDescriptor(yaml::Stream &YS, yaml::ScalarNode *K, + yaml::MappingNode *Descriptor, + RewriteDescriptorList *DL) { + bool Naked = false; + std::string Source; + std::string Target; + std::string Transform; + + for (auto &Field : *Descriptor) { + yaml::ScalarNode *Key; + yaml::ScalarNode *Value; + SmallString<32> KeyStorage; + SmallString<32> ValueStorage; + StringRef KeyValue; + + Key = dyn_cast<yaml::ScalarNode>(Field.getKey()); + if (!Key) { + YS.printError(Field.getKey(), "descriptor key must be a scalar"); + return false; + } + + Value = dyn_cast<yaml::ScalarNode>(Field.getValue()); + if (!Value) { + YS.printError(Field.getValue(), "descriptor value must be a scalar"); + return false; + } + + KeyValue = Key->getValue(KeyStorage); + if (KeyValue.equals("source")) { + std::string Error; + + Source = Value->getValue(ValueStorage); + if (!Regex(Source).isValid(Error)) { + YS.printError(Field.getKey(), "invalid regex: " + Error); + return false; + } + } else if (KeyValue.equals("target")) { + Target = Value->getValue(ValueStorage); + } else if (KeyValue.equals("transform")) { + Transform = Value->getValue(ValueStorage); + } else if (KeyValue.equals("naked")) { + std::string Undecorated; + + Undecorated = Value->getValue(ValueStorage); + Naked = StringRef(Undecorated).lower() == "true" || Undecorated == "1"; + } else { + YS.printError(Field.getKey(), "unknown key for function"); + return false; + } + } + + if (Transform.empty() == Target.empty()) { + YS.printError(Descriptor, + "exactly one of transform or target must be specified"); + return false; + } + + // TODO see if there is a more elegant solution to selecting the rewrite + // descriptor type + if (!Target.empty()) + DL->push_back(new ExplicitRewriteFunctionDescriptor(Source, Target, Naked)); + else + DL->push_back(new PatternRewriteFunctionDescriptor(Source, Transform)); + + return true; +} + +bool RewriteMapParser:: +parseRewriteGlobalVariableDescriptor(yaml::Stream &YS, yaml::ScalarNode *K, + yaml::MappingNode *Descriptor, + RewriteDescriptorList *DL) { + std::string Source; + std::string Target; + std::string Transform; + + for (auto &Field : *Descriptor) { + yaml::ScalarNode *Key; + yaml::ScalarNode *Value; + SmallString<32> KeyStorage; + SmallString<32> ValueStorage; + StringRef KeyValue; + + Key = dyn_cast<yaml::ScalarNode>(Field.getKey()); + if (!Key) { + YS.printError(Field.getKey(), "descriptor Key must be a scalar"); + return false; + } + + Value = dyn_cast<yaml::ScalarNode>(Field.getValue()); + if (!Value) { + YS.printError(Field.getValue(), "descriptor value must be a scalar"); + return false; + } + + KeyValue = Key->getValue(KeyStorage); + if (KeyValue.equals("source")) { + std::string Error; + + Source = Value->getValue(ValueStorage); + if (!Regex(Source).isValid(Error)) { + YS.printError(Field.getKey(), "invalid regex: " + Error); + return false; + } + } else if (KeyValue.equals("target")) { + Target = Value->getValue(ValueStorage); + } else if (KeyValue.equals("transform")) { + Transform = Value->getValue(ValueStorage); + } else { + YS.printError(Field.getKey(), "unknown Key for Global Variable"); + return false; + } + } + + if (Transform.empty() == Target.empty()) { + 
YS.printError(Descriptor, + "exactly one of transform or target must be specified"); + return false; + } + + if (!Target.empty()) + DL->push_back(new ExplicitRewriteGlobalVariableDescriptor(Source, Target, + /*Naked*/false)); + else + DL->push_back(new PatternRewriteGlobalVariableDescriptor(Source, + Transform)); + + return true; +} + +bool RewriteMapParser:: +parseRewriteGlobalAliasDescriptor(yaml::Stream &YS, yaml::ScalarNode *K, + yaml::MappingNode *Descriptor, + RewriteDescriptorList *DL) { + std::string Source; + std::string Target; + std::string Transform; + + for (auto &Field : *Descriptor) { + yaml::ScalarNode *Key; + yaml::ScalarNode *Value; + SmallString<32> KeyStorage; + SmallString<32> ValueStorage; + StringRef KeyValue; + + Key = dyn_cast<yaml::ScalarNode>(Field.getKey()); + if (!Key) { + YS.printError(Field.getKey(), "descriptor key must be a scalar"); + return false; + } + + Value = dyn_cast<yaml::ScalarNode>(Field.getValue()); + if (!Value) { + YS.printError(Field.getValue(), "descriptor value must be a scalar"); + return false; + } + + KeyValue = Key->getValue(KeyStorage); + if (KeyValue.equals("source")) { + std::string Error; + + Source = Value->getValue(ValueStorage); + if (!Regex(Source).isValid(Error)) { + YS.printError(Field.getKey(), "invalid regex: " + Error); + return false; + } + } else if (KeyValue.equals("target")) { + Target = Value->getValue(ValueStorage); + } else if (KeyValue.equals("transform")) { + Transform = Value->getValue(ValueStorage); + } else { + YS.printError(Field.getKey(), "unknown key for Global Alias"); + return false; + } + } + + if (Transform.empty() == Target.empty()) { + YS.printError(Descriptor, + "exactly one of transform or target must be specified"); + return false; + } + + if (!Target.empty()) + DL->push_back(new ExplicitRewriteNamedAliasDescriptor(Source, Target, + /*Naked*/false)); + else + DL->push_back(new PatternRewriteNamedAliasDescriptor(Source, Transform)); + + return true; +} +} +} + +namespace { +class RewriteSymbols : public ModulePass { +public: + static char ID; // Pass identification, replacement for typeid + + RewriteSymbols(); + RewriteSymbols(SymbolRewriter::RewriteDescriptorList &DL); + + bool runOnModule(Module &M) override; + +private: + void loadAndParseMapFiles(); + + SymbolRewriter::RewriteDescriptorList Descriptors; +}; + +char RewriteSymbols::ID = 0; + +RewriteSymbols::RewriteSymbols() : ModulePass(ID) { + initializeRewriteSymbolsPass(*PassRegistry::getPassRegistry()); + loadAndParseMapFiles(); +} + +RewriteSymbols::RewriteSymbols(SymbolRewriter::RewriteDescriptorList &DL) + : ModulePass(ID) { + Descriptors.splice(Descriptors.begin(), DL); +} + +bool RewriteSymbols::runOnModule(Module &M) { + bool Changed; + + Changed = false; + for (auto &Descriptor : Descriptors) + Changed |= Descriptor.performOnModule(M); + + return Changed; +} + +void RewriteSymbols::loadAndParseMapFiles() { + const std::vector<std::string> MapFiles(RewriteMapFiles); + SymbolRewriter::RewriteMapParser parser; + + for (const auto &MapFile : MapFiles) + parser.parse(MapFile, &Descriptors); +} +} + +INITIALIZE_PASS(RewriteSymbols, "rewrite-symbols", "Rewrite Symbols", false, + false) + +ModulePass *llvm::createRewriteSymbolsPass() { return new RewriteSymbols(); } + +ModulePass * +llvm::createRewriteSymbolsPass(SymbolRewriter::RewriteDescriptorList &DL) { + return new RewriteSymbols(DL); +} diff --git a/lib/Transforms/Utils/ValueMapper.cpp b/lib/Transforms/Utils/ValueMapper.cpp index 0f20e6df6c96..477fba42412e 100644 --- 
a/lib/Transforms/Utils/ValueMapper.cpp +++ b/lib/Transforms/Utils/ValueMapper.cpp @@ -40,7 +40,7 @@ Value *llvm::MapValue(const Value *V, ValueToValueMapTy &VM, RemapFlags Flags, // Global values do not need to be seeded into the VM if they // are using the identity mapping. - if (isa<GlobalValue>(V) || isa<MDString>(V)) + if (isa<GlobalValue>(V)) return VM[V] = const_cast<Value*>(V); if (const InlineAsm *IA = dyn_cast<InlineAsm>(V)) { @@ -56,57 +56,24 @@ Value *llvm::MapValue(const Value *V, ValueToValueMapTy &VM, RemapFlags Flags, return VM[V] = const_cast<Value*>(V); } - - if (const MDNode *MD = dyn_cast<MDNode>(V)) { + if (const auto *MDV = dyn_cast<MetadataAsValue>(V)) { + const Metadata *MD = MDV->getMetadata(); // If this is a module-level metadata and we know that nothing at the module // level is changing, then use an identity mapping. - if (!MD->isFunctionLocal() && (Flags & RF_NoModuleLevelChanges)) - return VM[V] = const_cast<Value*>(V); - - // Create a dummy node in case we have a metadata cycle. - MDNode *Dummy = MDNode::getTemporary(V->getContext(), None); - VM[V] = Dummy; - - // Check all operands to see if any need to be remapped. - for (unsigned i = 0, e = MD->getNumOperands(); i != e; ++i) { - Value *OP = MD->getOperand(i); - if (!OP) continue; - Value *Mapped_OP = MapValue(OP, VM, Flags, TypeMapper, Materializer); - // Use identity map if Mapped_Op is null and we can ignore missing - // entries. - if (Mapped_OP == OP || - (Mapped_OP == nullptr && (Flags & RF_IgnoreMissingEntries))) - continue; - - // Ok, at least one operand needs remapping. - SmallVector<Value*, 4> Elts; - Elts.reserve(MD->getNumOperands()); - for (i = 0; i != e; ++i) { - Value *Op = MD->getOperand(i); - if (!Op) - Elts.push_back(nullptr); - else { - Value *Mapped_Op = MapValue(Op, VM, Flags, TypeMapper, Materializer); - // Use identity map if Mapped_Op is null and we can ignore missing - // entries. - if (Mapped_Op == nullptr && (Flags & RF_IgnoreMissingEntries)) - Mapped_Op = Op; - Elts.push_back(Mapped_Op); - } - } - MDNode *NewMD = MDNode::get(V->getContext(), Elts); - Dummy->replaceAllUsesWith(NewMD); - VM[V] = NewMD; - MDNode::deleteTemporary(Dummy); - return NewMD; - } + if (!isa<LocalAsMetadata>(MD) && (Flags & RF_NoModuleLevelChanges)) + return VM[V] = const_cast<Value *>(V); - VM[V] = const_cast<Value*>(V); - MDNode::deleteTemporary(Dummy); + auto *MappedMD = MapMetadata(MD, VM, Flags, TypeMapper, Materializer); + if (MD == MappedMD || (!MappedMD && (Flags & RF_IgnoreMissingEntries))) + return VM[V] = const_cast<Value *>(V); - // No operands needed remapping. Use an identity mapping. - return const_cast<Value*>(V); + // FIXME: This assert crashes during bootstrap, but I think it should be + // correct. For now, just match behaviour from before the metadata/value + // split. 
+ // + // assert(MappedMD && "Referenced metadata value not in value map"); + return VM[V] = MetadataAsValue::get(V->getContext(), MappedMD); } // Okay, this either must be a constant (which may or may not be mappable) or @@ -177,6 +144,229 @@ Value *llvm::MapValue(const Value *V, ValueToValueMapTy &VM, RemapFlags Flags, return VM[V] = ConstantPointerNull::get(cast<PointerType>(NewTy)); } +static Metadata *mapToMetadata(ValueToValueMapTy &VM, const Metadata *Key, + Metadata *Val) { + VM.MD()[Key].reset(Val); + return Val; +} + +static Metadata *mapToSelf(ValueToValueMapTy &VM, const Metadata *MD) { + return mapToMetadata(VM, MD, const_cast<Metadata *>(MD)); +} + +static Metadata *MapMetadataImpl(const Metadata *MD, ValueToValueMapTy &VM, + RemapFlags Flags, + ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer); + +static Metadata *mapMetadataOp(Metadata *Op, ValueToValueMapTy &VM, + RemapFlags Flags, + ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer) { + if (!Op) + return nullptr; + if (Metadata *MappedOp = + MapMetadataImpl(Op, VM, Flags, TypeMapper, Materializer)) + return MappedOp; + // Use identity map if MappedOp is null and we can ignore missing entries. + if (Flags & RF_IgnoreMissingEntries) + return Op; + + // FIXME: This assert crashes during bootstrap, but I think it should be + // correct. For now, just match behaviour from before the metadata/value + // split. + // + // llvm_unreachable("Referenced metadata not in value map!"); + return nullptr; +} + +static Metadata *cloneMDTuple(const MDTuple *Node, ValueToValueMapTy &VM, + RemapFlags Flags, + ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer, + bool IsDistinct) { + // Distinct MDTuples have their own code path. + assert(!IsDistinct && "Unexpected distinct tuple"); + (void)IsDistinct; + + SmallVector<Metadata *, 4> Elts; + Elts.reserve(Node->getNumOperands()); + for (unsigned I = 0, E = Node->getNumOperands(); I != E; ++I) + Elts.push_back(mapMetadataOp(Node->getOperand(I), VM, Flags, TypeMapper, + Materializer)); + + return MDTuple::get(Node->getContext(), Elts); +} + +static Metadata *cloneMDLocation(const MDLocation *Node, ValueToValueMapTy &VM, + RemapFlags Flags, + ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer, + bool IsDistinct) { + return (IsDistinct ? MDLocation::getDistinct : MDLocation::get)( + Node->getContext(), Node->getLine(), Node->getColumn(), + mapMetadataOp(Node->getScope(), VM, Flags, TypeMapper, Materializer), + mapMetadataOp(Node->getInlinedAt(), VM, Flags, TypeMapper, Materializer)); +} + +static Metadata *cloneMDNode(const UniquableMDNode *Node, ValueToValueMapTy &VM, + RemapFlags Flags, ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer, bool IsDistinct) { + switch (Node->getMetadataID()) { + default: + llvm_unreachable("Invalid UniquableMDNode subclass"); +#define HANDLE_UNIQUABLE_LEAF(CLASS) \ + case Metadata::CLASS##Kind: \ + return clone##CLASS(cast<CLASS>(Node), VM, Flags, TypeMapper, \ + Materializer, IsDistinct); +#include "llvm/IR/Metadata.def" + } +} + +/// \brief Map a distinct MDNode. +/// +/// Distinct nodes are not uniqued, so they must always be recreated. +static Metadata *mapDistinctNode(const UniquableMDNode *Node, + ValueToValueMapTy &VM, RemapFlags Flags, + ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer) { + assert(Node->isDistinct() && "Expected distinct node"); + + // Optimization for MDTuples.
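// (Illustrative aside.) The placeholder registered below matters because
// distinct nodes can participate in cycles; a self-referential node such as
//   !0 = !{!0, !1}
// would recurse through mapMetadataOp(!0) forever unless the map already
// held a stand-in for !0. The temporary node substitutes for the original
// until every operand is mapped, then is RAUW'd to the real copy.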
+ if (isa<MDTuple>(Node)) { + // Create the node first so it's available for cyclical references. + SmallVector<Metadata *, 4> EmptyOps(Node->getNumOperands()); + MDTuple *NewMD = MDTuple::getDistinct(Node->getContext(), EmptyOps); + mapToMetadata(VM, Node, NewMD); + + // Fix the operands. + for (unsigned I = 0, E = Node->getNumOperands(); I != E; ++I) + NewMD->replaceOperandWith(I, mapMetadataOp(Node->getOperand(I), VM, Flags, + TypeMapper, Materializer)); + + return NewMD; + } + + // In general we need a dummy node, since whether the operands are null can + // affect the size of the node. + std::unique_ptr<MDNodeFwdDecl> Dummy( + MDNode::getTemporary(Node->getContext(), None)); + mapToMetadata(VM, Node, Dummy.get()); + Metadata *NewMD = cloneMDNode(Node, VM, Flags, TypeMapper, Materializer, + /* IsDistinct */ true); + Dummy->replaceAllUsesWith(NewMD); + return mapToMetadata(VM, Node, NewMD); +} + +/// \brief Check whether a uniqued node needs to be remapped. +/// +/// Check whether a uniqued node needs to be remapped (due to any operands +/// changing). +static bool shouldRemapUniquedNode(const UniquableMDNode *Node, + ValueToValueMapTy &VM, RemapFlags Flags, + ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer) { + // Check all operands to see if any need to be remapped. + for (unsigned I = 0, E = Node->getNumOperands(); I != E; ++I) { + Metadata *Op = Node->getOperand(I); + if (Op != mapMetadataOp(Op, VM, Flags, TypeMapper, Materializer)) + return true; + } + return false; +} + +/// \brief Map a uniqued MDNode. +/// +/// Uniqued nodes may not need to be recreated (they may map to themselves). +static Metadata *mapUniquedNode(const UniquableMDNode *Node, + ValueToValueMapTy &VM, RemapFlags Flags, + ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer) { + assert(!Node->isDistinct() && "Expected uniqued node"); + + // Create a dummy node in case we have a metadata cycle. + MDNodeFwdDecl *Dummy = MDNode::getTemporary(Node->getContext(), None); + mapToMetadata(VM, Node, Dummy); + + // Check all operands to see if any need to be remapped. + if (!shouldRemapUniquedNode(Node, VM, Flags, TypeMapper, Materializer)) { + // Use an identity mapping. + mapToSelf(VM, Node); + MDNode::deleteTemporary(Dummy); + return const_cast<Metadata *>(static_cast<const Metadata *>(Node)); + } + + // At least one operand needs remapping. + Metadata *NewMD = cloneMDNode(Node, VM, Flags, TypeMapper, Materializer, + /* IsDistinct */ false); + Dummy->replaceAllUsesWith(NewMD); + MDNode::deleteTemporary(Dummy); + return mapToMetadata(VM, Node, NewMD); +} + +static Metadata *MapMetadataImpl(const Metadata *MD, ValueToValueMapTy &VM, + RemapFlags Flags, + ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer) { + // If the value already exists in the map, use it. + if (Metadata *NewMD = VM.MD().lookup(MD).get()) + return NewMD; + + if (isa<MDString>(MD)) + return mapToSelf(VM, MD); + + if (isa<ConstantAsMetadata>(MD)) + if ((Flags & RF_NoModuleLevelChanges)) + return mapToSelf(VM, MD); + + if (const auto *VMD = dyn_cast<ValueAsMetadata>(MD)) { + Value *MappedV = + MapValue(VMD->getValue(), VM, Flags, TypeMapper, Materializer); + if (VMD->getValue() == MappedV || + (!MappedV && (Flags & RF_IgnoreMissingEntries))) + return mapToSelf(VM, MD); + + // FIXME: This assert crashes during bootstrap, but I think it should be + // correct. For now, just match behaviour from before the metadata/value + // split. 
+ // + // assert(MappedV && "Referenced metadata not in value map!"); + if (MappedV) + return mapToMetadata(VM, MD, ValueAsMetadata::get(MappedV)); + return nullptr; + } + + const UniquableMDNode *Node = cast<UniquableMDNode>(MD); + assert(Node->isResolved() && "Unexpected unresolved node"); + + // If this is a module-level metadata and we know that nothing at the + // module level is changing, then use an identity mapping. + if (Flags & RF_NoModuleLevelChanges) + return mapToSelf(VM, MD); + + if (Node->isDistinct()) + return mapDistinctNode(Node, VM, Flags, TypeMapper, Materializer); + + return mapUniquedNode(Node, VM, Flags, TypeMapper, Materializer); +} + +Metadata *llvm::MapMetadata(const Metadata *MD, ValueToValueMapTy &VM, + RemapFlags Flags, ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer) { + Metadata *NewMD = MapMetadataImpl(MD, VM, Flags, TypeMapper, Materializer); + if (NewMD && NewMD != MD) + if (auto *N = dyn_cast<UniquableMDNode>(NewMD)) + N->resolveCycles(); + return NewMD; +} + +MDNode *llvm::MapMetadata(const MDNode *MD, ValueToValueMapTy &VM, + RemapFlags Flags, ValueMapTypeRemapper *TypeMapper, + ValueMaterializer *Materializer) { + return cast<MDNode>(MapMetadata(static_cast<const Metadata *>(MD), VM, Flags, + TypeMapper, Materializer)); +} + /// RemapInstruction - Convert the instruction operands from referencing the /// current values into those specified by VMap. /// @@ -210,10 +400,12 @@ void llvm::RemapInstruction(Instruction *I, ValueToValueMapTy &VMap, // Remap attached metadata. SmallVector<std::pair<unsigned, MDNode *>, 4> MDs; I->getAllMetadata(MDs); - for (SmallVectorImpl<std::pair<unsigned, MDNode *> >::iterator - MI = MDs.begin(), ME = MDs.end(); MI != ME; ++MI) { + for (SmallVectorImpl<std::pair<unsigned, MDNode *>>::iterator + MI = MDs.begin(), + ME = MDs.end(); + MI != ME; ++MI) { MDNode *Old = MI->second; - MDNode *New = MapValue(Old, VMap, Flags, TypeMapper, Materializer); + MDNode *New = MapMetadata(Old, VMap, Flags, TypeMapper, Materializer); if (New != Old) I->setMetadata(MI->first, New); } diff --git a/lib/Transforms/Vectorize/BBVectorize.cpp b/lib/Transforms/Vectorize/BBVectorize.cpp index 28ec83bf8683..a0ccf9d7b8cd 100644 --- a/lib/Transforms/Vectorize/BBVectorize.cpp +++ b/lib/Transforms/Vectorize/BBVectorize.cpp @@ -391,8 +391,6 @@ namespace { Instruction *&InsertionPt, Instruction *I, Instruction *J); - void combineMetadata(Instruction *K, const Instruction *J); - bool vectorizeBB(BasicBlock &BB) { if (skipOptnoneFunction(BB)) return false; @@ -687,6 +685,8 @@ namespace { case Intrinsic::trunc: case Intrinsic::floor: case Intrinsic::fabs: + case Intrinsic::minnum: + case Intrinsic::maxnum: return Config.VectorizeMath; case Intrinsic::bswap: case Intrinsic::ctpop: @@ -2609,7 +2609,6 @@ namespace { true, o, 1)); NewI1->insertBefore(IBeforeJ ? J : I); I1 = NewI1; - I1T = I2T; I1Elem = I2Elem; } else if (I1Elem > I2Elem) { std::vector<Constant *> Mask(I1Elem); @@ -2626,8 +2625,6 @@ namespace { true, o, 1)); NewI2->insertBefore(IBeforeJ ? J : I); I2 = NewI2; - I2T = I1T; - I2Elem = I1Elem; } // Now that both I1 and I2 are the same length we can shuffle them @@ -2964,31 +2961,6 @@ namespace { } } - // When the first instruction in each pair is cloned, it will inherit its - // parent's metadata. This metadata must be combined with that of the other - // instruction in a safe way. 
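A condensed sketch of the new MapMetadata entry point in use, mirroring the RemapInstruction change above (I, VMap, and the flag choice are placeholders for illustration; default arguments are assumed to exist as they do for MapValue):

    // Remap every piece of metadata attached to a cloned instruction.
    // Attached MDNodes now go through MapMetadata rather than MapValue.
    SmallVector<std::pair<unsigned, MDNode *>, 4> MDs;
    I->getAllMetadata(MDs);
    for (const auto &KindAndMD : MDs) {
      MDNode *New = MapMetadata(KindAndMD.second, VMap, RF_None);
      if (New != KindAndMD.second)
        I->setMetadata(KindAndMD.first, New);
    }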
- void BBVectorize::combineMetadata(Instruction *K, const Instruction *J) { - SmallVector<std::pair<unsigned, MDNode*>, 4> Metadata; - K->getAllMetadataOtherThanDebugLoc(Metadata); - for (unsigned i = 0, n = Metadata.size(); i < n; ++i) { - unsigned Kind = Metadata[i].first; - MDNode *JMD = J->getMetadata(Kind); - MDNode *KMD = Metadata[i].second; - - switch (Kind) { - default: - K->setMetadata(Kind, nullptr); // Remove unknown metadata - break; - case LLVMContext::MD_tbaa: - K->setMetadata(Kind, MDNode::getMostGenericTBAA(JMD, KMD)); - break; - case LLVMContext::MD_fpmath: - K->setMetadata(Kind, MDNode::getMostGenericFPMath(JMD, KMD)); - break; - } - } - } - // This function fuses the chosen instruction pairs into vector instructions, // taking care preserve any needed scalar outputs and, then, it reorders the // remaining instructions as needed (users of the first member of the pair @@ -3138,7 +3110,13 @@ namespace { if (!isa<StoreInst>(K)) K->mutateType(getVecTypeForPair(L->getType(), H->getType())); - combineMetadata(K, H); + unsigned KnownIDs[] = { + LLVMContext::MD_tbaa, + LLVMContext::MD_alias_scope, + LLVMContext::MD_noalias, + LLVMContext::MD_fpmath + }; + combineMetadata(K, H, KnownIDs); K->intersectOptionalDataWith(H); for (unsigned o = 0; o < NumOperands; ++o) diff --git a/lib/Transforms/Vectorize/LoopVectorize.cpp b/lib/Transforms/Vectorize/LoopVectorize.cpp index 79fcb09f8913..557304ed56c5 100644 --- a/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -55,7 +55,9 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/LoopPass.h" @@ -108,8 +110,8 @@ VectorizationFactor("force-vector-width", cl::init(0), cl::Hidden, cl::desc("Sets the SIMD width. Zero is autoselect.")); static cl::opt<unsigned> -VectorizationUnroll("force-vector-unroll", cl::init(0), cl::Hidden, - cl::desc("Sets the vectorization unroll count. " +VectorizationInterleave("force-vector-interleave", cl::init(0), cl::Hidden, + cl::desc("Sets the vectorization interleave count. " "Zero is autoselect.")); static cl::opt<bool> @@ -157,17 +159,17 @@ static cl::opt<unsigned> ForceTargetNumVectorRegs( "force-target-num-vector-regs", cl::init(0), cl::Hidden, cl::desc("A flag that overrides the target's number of vector registers.")); -/// Maximum vectorization unroll count. -static const unsigned MaxUnrollFactor = 16; +/// Maximum vectorization interleave count. 
+static const unsigned MaxInterleaveFactor = 16; -static cl::opt<unsigned> ForceTargetMaxScalarUnrollFactor( - "force-target-max-scalar-unroll", cl::init(0), cl::Hidden, - cl::desc("A flag that overrides the target's max unroll factor for scalar " - "loops.")); +static cl::opt<unsigned> ForceTargetMaxScalarInterleaveFactor( + "force-target-max-scalar-interleave", cl::init(0), cl::Hidden, + cl::desc("A flag that overrides the target's max interleave factor for " + "scalar loops.")); -static cl::opt<unsigned> ForceTargetMaxVectorUnrollFactor( - "force-target-max-vector-unroll", cl::init(0), cl::Hidden, - cl::desc("A flag that overrides the target's max unroll factor for " +static cl::opt<unsigned> ForceTargetMaxVectorInterleaveFactor( + "force-target-max-vector-interleave", cl::init(0), cl::Hidden, + cl::desc("A flag that overrides the target's max interleave factor for " "vectorized loops.")); static cl::opt<unsigned> ForceTargetInstructionCost( @@ -204,11 +206,17 @@ static cl::opt<bool> EnableCondStoresVectorization( "enable-cond-stores-vec", cl::init(false), cl::Hidden, cl::desc("Enable if predication of stores during vectorization.")); +static cl::opt<unsigned> MaxNestedScalarReductionUF( + "max-nested-scalar-reduction-unroll", cl::init(2), cl::Hidden, + cl::desc("The maximum unroll factor to use when unrolling a scalar " + "reduction in a nested loop.")); + namespace { // Forward declarations. class LoopVectorizationLegality; class LoopVectorizationCostModel; +class LoopVectorizeHints; /// Optimization analysis message produced during vectorization. Messages inform /// the user why vectorization did not occur. @@ -535,6 +543,8 @@ static void propagateMetadata(Instruction *To, const Instruction *From) { // non-speculated memory access when the condition was false, this would be // caught by the runtime overlap checks). if (Kind != LLVMContext::MD_tbaa && + Kind != LLVMContext::MD_alias_scope && + Kind != LLVMContext::MD_noalias && Kind != LLVMContext::MD_fpmath) continue; @@ -570,9 +580,10 @@ public: LoopVectorizationLegality(Loop *L, ScalarEvolution *SE, const DataLayout *DL, DominatorTree *DT, TargetLibraryInfo *TLI, - AliasAnalysis *AA, Function *F) + AliasAnalysis *AA, Function *F, + const TargetTransformInfo *TTI) : NumLoads(0), NumStores(0), NumPredStores(0), TheLoop(L), SE(SE), DL(DL), - DT(DT), TLI(TLI), AA(AA), TheFunction(F), Induction(nullptr), + DT(DT), TLI(TLI), AA(AA), TheFunction(F), TTI(TTI), Induction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false), MaxSafeDepDistBytes(-1U) { } @@ -758,6 +769,21 @@ public: } SmallPtrSet<Value *, 8>::iterator strides_end() { return StrideSet.end(); } + /// Returns true if the target machine supports masked store operation + /// for the given \p DataType and kind of access to \p Ptr. + bool isLegalMaskedStore(Type *DataType, Value *Ptr) { + return TTI->isLegalMaskedStore(DataType, isConsecutivePtr(Ptr)); + } + /// Returns true if the target machine supports masked load operation + /// for the given \p DataType and kind of access to \p Ptr. + bool isLegalMaskedLoad(Type *DataType, Value *Ptr) { + return TTI->isLegalMaskedLoad(DataType, isConsecutivePtr(Ptr)); + } + /// Returns true if vector representation of the instruction \p I + /// requires mask. + bool isMaskRequired(const Instruction* I) { + return (MaskedOp.count(I) != 0); + } private: /// Check if a single basic block loop is vectorizable. 
/// At this point we know that this is a loop with a constant trip count @@ -780,7 +806,7 @@ private: /// Return true if all of the instructions in the block can be speculatively /// executed. \p SafePtrs is a list of addresses that are known to be legal /// and we know that we can read from them without segfault. - bool blockCanBePredicated(BasicBlock *BB, SmallPtrSet<Value *, 8>& SafePtrs); + bool blockCanBePredicated(BasicBlock *BB, SmallPtrSetImpl<Value *> &SafePtrs); /// Returns True, if 'Phi' is the kind of reduction variable for type /// 'Kind'. If this is a reduction variable, it adds it to ReductionList. @@ -804,7 +830,7 @@ private: /// /// Looks for accesses like "a[i * StrideA]" where "StrideA" is loop /// invariant. - void collectStridedAcccess(Value *LoadOrStoreInst); + void collectStridedAccess(Value *LoadOrStoreInst); /// Report an analysis message to assist the user in diagnosing loops that are /// not vectorized. @@ -830,6 +856,8 @@ private: AliasAnalysis *AA; /// Parent function Function *TheFunction; + /// Target Transform Info + const TargetTransformInfo *TTI; // --- vectorization state --- // @@ -861,6 +889,10 @@ private: ValueToValueMap Strides; SmallPtrSet<Value *, 8> StrideSet; + + /// While vectorizing these instructions we have to generate a + /// call to the appropriate masked intrinsic + SmallPtrSet<const Instruction*, 8> MaskedOp; }; /// LoopVectorizationCostModel - estimates the expected speedups due to @@ -875,8 +907,13 @@ public: LoopVectorizationCostModel(Loop *L, ScalarEvolution *SE, LoopInfo *LI, LoopVectorizationLegality *Legal, const TargetTransformInfo &TTI, - const DataLayout *DL, const TargetLibraryInfo *TLI) - : TheLoop(L), SE(SE), LI(LI), Legal(Legal), TTI(TTI), DL(DL), TLI(TLI) {} + const DataLayout *DL, const TargetLibraryInfo *TLI, + AssumptionCache *AC, const Function *F, + const LoopVectorizeHints *Hints) + : TheLoop(L), SE(SE), LI(LI), Legal(Legal), TTI(TTI), DL(DL), TLI(TLI), + TheFunction(F), Hints(Hints) { + CodeMetrics::collectEphemeralValues(L, AC, EphValues); + } /// Information about vectorization costs struct VectorizationFactor { @@ -887,9 +924,7 @@ public: /// This method checks every power of two up to VF. If UserVF is not ZERO /// then this vectorization factor will be selected if vectorization is /// possible. - VectorizationFactor selectVectorizationFactor(bool OptForSize, - unsigned UserVF, - bool ForceVectorization); + VectorizationFactor selectVectorizationFactor(bool OptForSize); /// \return The size (in bits) of the widest type in the code that /// needs to be vectorized. We ignore values that remain scalar such as @@ -901,8 +936,7 @@ public: /// based on register pressure and other parameters. /// VF and LoopCost are the selected vectorization factor and the cost of the /// selected VF. - unsigned selectUnrollFactor(bool OptForSize, unsigned UserUF, unsigned VF, - unsigned LoopCost); + unsigned selectUnrollFactor(bool OptForSize, unsigned VF, unsigned LoopCost); /// \brief A struct that represents some properties of the register usage /// of a loop. @@ -938,6 +972,19 @@ private: /// as a vector operation. bool isConsecutiveLoadOrStore(Instruction *I); + /// Report an analysis message to assist the user in diagnosing loops that are + /// not vectorized. 
+ void emitAnalysis(Report &Message) { + DebugLoc DL = TheLoop->getStartLoc(); + if (Instruction *I = Message.getInstr()) + DL = I->getDebugLoc(); + emitOptimizationRemarkAnalysis(TheFunction->getContext(), DEBUG_TYPE, + *TheFunction, DL, Message.str()); + } + + /// Values used only by @llvm.assume calls. + SmallPtrSet<const Value *, 32> EphValues; + /// The loop that we evaluate. Loop *TheLoop; /// Scev analysis. @@ -952,11 +999,59 @@ private: const DataLayout *DL; /// Target Library Info. const TargetLibraryInfo *TLI; + const Function *TheFunction; + // Loop Vectorize Hint. + const LoopVectorizeHints *Hints; }; /// Utility class for getting and setting loop vectorizer hints in the form /// of loop metadata. +/// This class keeps a number of loop annotations locally (as member variables) +/// and can, upon request, write them back as metadata on the loop. It will +/// initially scan the loop for existing metadata, and will update the local +/// values based on information in the loop. +/// We cannot write all values to metadata, as the mere presence of some info, +/// for example 'force', means a decision has been made. So, we need to be +/// careful NOT to add them if the user hasn't specifically asked so. class LoopVectorizeHints { + enum HintKind { + HK_WIDTH, + HK_UNROLL, + HK_FORCE + }; + + /// Hint - associates name and validation with the hint value. + struct Hint { + const char * Name; + unsigned Value; // This may have to change for non-numeric values. + HintKind Kind; + + Hint(const char * Name, unsigned Value, HintKind Kind) + : Name(Name), Value(Value), Kind(Kind) { } + + bool validate(unsigned Val) { + switch (Kind) { + case HK_WIDTH: + return isPowerOf2_32(Val) && Val <= MaxVectorWidth; + case HK_UNROLL: + return isPowerOf2_32(Val) && Val <= MaxInterleaveFactor; + case HK_FORCE: + return (Val <= 1); + } + return false; + } + }; + + /// Vectorization width. + Hint Width; + /// Vectorization interleave factor. + Hint Interleave; + /// Vectorization forced + Hint Force; + + /// Return the loop metadata prefix. + static StringRef Prefix() { return "llvm.loop."; } + public: enum ForceKind { FK_Undefined = -1, ///< Not selected. @@ -964,90 +1059,57 @@ public: FK_Enabled = 1, ///< Forcing enabled. }; - LoopVectorizeHints(const Loop *L, bool DisableUnrolling) - : Width(VectorizationFactor), - Unroll(DisableUnrolling), - Force(FK_Undefined), - LoopID(L->getLoopID()) { - getHints(L); - // force-vector-unroll overrides DisableUnrolling. - if (VectorizationUnroll.getNumOccurrences() > 0) - Unroll = VectorizationUnroll; + LoopVectorizeHints(const Loop *L, bool DisableInterleaving) + : Width("vectorize.width", VectorizationFactor, HK_WIDTH), + Interleave("interleave.count", DisableInterleaving, HK_UNROLL), + Force("vectorize.enable", FK_Undefined, HK_FORCE), + TheLoop(L) { + // Populate values with existing loop metadata. + getHintsFromMetadata(); - DEBUG(if (DisableUnrolling && Unroll == 1) dbgs() - << "LV: Unrolling disabled by the pass manager\n"); - } - - /// Return the loop metadata prefix. - static StringRef Prefix() { return "llvm.loop."; } + // force-vector-interleave overrides DisableInterleaving. 
+ if (VectorizationInterleave.getNumOccurrences() > 0) + Interleave.Value = VectorizationInterleave; - MDNode *createHint(LLVMContext &Context, StringRef Name, unsigned V) const { - SmallVector<Value*, 2> Vals; - Vals.push_back(MDString::get(Context, Name)); - Vals.push_back(ConstantInt::get(Type::getInt32Ty(Context), V)); - return MDNode::get(Context, Vals); + DEBUG(if (DisableInterleaving && Interleave.Value == 1) dbgs() + << "LV: Interleaving disabled by the pass manager\n"); } /// Mark the loop L as already vectorized by setting the width to 1. - void setAlreadyVectorized(Loop *L) { - LLVMContext &Context = L->getHeader()->getContext(); - - Width = 1; - - // Create a new loop id with one more operand for the already_vectorized - // hint. If the loop already has a loop id then copy the existing operands. - SmallVector<Value*, 4> Vals(1); - if (LoopID) - for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) - Vals.push_back(LoopID->getOperand(i)); - - Vals.push_back( - createHint(Context, Twine(Prefix(), "vectorize.width").str(), Width)); - Vals.push_back( - createHint(Context, Twine(Prefix(), "interleave.count").str(), 1)); - - MDNode *NewLoopID = MDNode::get(Context, Vals); - // Set operand 0 to refer to the loop id itself. - NewLoopID->replaceOperandWith(0, NewLoopID); - - L->setLoopID(NewLoopID); - if (LoopID) - LoopID->replaceAllUsesWith(NewLoopID); - - LoopID = NewLoopID; + void setAlreadyVectorized() { + Width.Value = Interleave.Value = 1; + Hint Hints[] = {Width, Interleave}; + writeHintsToMetadata(Hints); } + /// Dumps all the hint information. std::string emitRemark() const { Report R; - R << "vectorization "; - switch (Force) { - case LoopVectorizeHints::FK_Disabled: - R << "is explicitly disabled"; - break; - case LoopVectorizeHints::FK_Enabled: - R << "is explicitly enabled"; - if (Width != 0 && Unroll != 0) - R << " with width " << Width << " and interleave count " << Unroll; - else if (Width != 0) - R << " with width " << Width; - else if (Unroll != 0) - R << " with interleave count " << Unroll; - break; - case LoopVectorizeHints::FK_Undefined: - R << "was not specified"; - break; + if (Force.Value == LoopVectorizeHints::FK_Disabled) + R << "vectorization is explicitly disabled"; + else { + R << "use -Rpass-analysis=loop-vectorize for more info"; + if (Force.Value == LoopVectorizeHints::FK_Enabled) { + R << " (Force=true"; + if (Width.Value != 0) + R << ", Vector Width=" << Width.Value; + if (Interleave.Value != 0) + R << ", Interleave Count=" << Interleave.Value; + R << ")"; + } } + return R.str(); } - unsigned getWidth() const { return Width; } - unsigned getUnroll() const { return Unroll; } - enum ForceKind getForce() const { return Force; } - MDNode *getLoopID() const { return LoopID; } + unsigned getWidth() const { return Width.Value; } + unsigned getInterleave() const { return Interleave.Value; } + enum ForceKind getForce() const { return (ForceKind)Force.Value; } private: - /// Find hints specified in the loop metadata. - void getHints(const Loop *L) { + /// Find hints specified in the loop metadata and update local values. + void getHintsFromMetadata() { + MDNode *LoopID = TheLoop->getLoopID(); if (!LoopID) return; @@ -1057,7 +1119,7 @@ private: for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { const MDString *S = nullptr; - SmallVector<Value*, 4> Args; + SmallVector<Metadata *, 4> Args; // The expected hint is either a MDString or a MDNode with the first // operand a MDString. 
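The hints parsed here line up with the source-level loop pragmas that the pass's own remark text points users at; an illustrative use (the loop body is assumed):

    // Emits llvm.loop.vectorize.width, llvm.loop.interleave.count and
    // llvm.loop.vectorize.enable metadata on the loop's loop ID.
    #pragma clang loop vectorize(enable) vectorize_width(4) interleave_count(2)
    for (int i = 0; i < n; ++i)
      a[i] = b[i] + c[i];

When such a loop still fails to vectorize, the remarks added by this patch (see emitRemark above) direct the user to -Rpass-analysis=loop-vectorize for the detailed Report messages.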
@@ -1076,52 +1138,88 @@ private: continue; // Check if the hint starts with the loop metadata prefix. - StringRef Hint = S->getString(); - if (!Hint.startswith(Prefix())) - continue; - // Remove the prefix. - Hint = Hint.substr(Prefix().size(), StringRef::npos); - + StringRef Name = S->getString(); if (Args.size() == 1) - getHint(Hint, Args[0]); + setHint(Name, Args[0]); } } - // Check string hint with one operand. - void getHint(StringRef Hint, Value *Arg) { - const ConstantInt *C = dyn_cast<ConstantInt>(Arg); + /// Checks string hint with one operand and set value if valid. + void setHint(StringRef Name, Metadata *Arg) { + if (!Name.startswith(Prefix())) + return; + Name = Name.substr(Prefix().size(), StringRef::npos); + + const ConstantInt *C = mdconst::dyn_extract<ConstantInt>(Arg); if (!C) return; unsigned Val = C->getZExtValue(); - if (Hint == "vectorize.width") { - if (isPowerOf2_32(Val) && Val <= MaxVectorWidth) - Width = Val; - else - DEBUG(dbgs() << "LV: ignoring invalid width hint metadata\n"); - } else if (Hint == "vectorize.enable") { - if (C->getBitWidth() == 1) - Force = Val == 1 ? LoopVectorizeHints::FK_Enabled - : LoopVectorizeHints::FK_Disabled; - else - DEBUG(dbgs() << "LV: ignoring invalid enable hint metadata\n"); - } else if (Hint == "interleave.count") { - if (isPowerOf2_32(Val) && Val <= MaxUnrollFactor) - Unroll = Val; - else - DEBUG(dbgs() << "LV: ignoring invalid unroll hint metadata\n"); - } else { - DEBUG(dbgs() << "LV: ignoring unknown hint " << Hint << '\n'); + Hint *Hints[] = {&Width, &Interleave, &Force}; + for (auto H : Hints) { + if (Name == H->Name) { + if (H->validate(Val)) + H->Value = Val; + else + DEBUG(dbgs() << "LV: ignoring invalid hint '" << Name << "'\n"); + break; + } } } - /// Vectorization width. - unsigned Width; - /// Vectorization unroll factor. - unsigned Unroll; - /// Vectorization forced - enum ForceKind Force; + /// Create a new hint from name / value pair. + MDNode *createHintMetadata(StringRef Name, unsigned V) const { + LLVMContext &Context = TheLoop->getHeader()->getContext(); + Metadata *MDs[] = {MDString::get(Context, Name), + ConstantAsMetadata::get( + ConstantInt::get(Type::getInt32Ty(Context), V))}; + return MDNode::get(Context, MDs); + } + + /// Matches metadata with hint name. + bool matchesHintMetadataName(MDNode *Node, ArrayRef<Hint> HintTypes) { + MDString* Name = dyn_cast<MDString>(Node->getOperand(0)); + if (!Name) + return false; - MDNode *LoopID; + for (auto H : HintTypes) + if (Name->getString().endswith(H.Name)) + return true; + return false; + } + + /// Sets current hints into loop metadata, keeping other values intact. + void writeHintsToMetadata(ArrayRef<Hint> HintTypes) { + if (HintTypes.size() == 0) + return; + + // Reserve the first element to LoopID (see below). + SmallVector<Metadata *, 4> MDs(1); + // If the loop already has metadata, then ignore the existing operands. + MDNode *LoopID = TheLoop->getLoopID(); + if (LoopID) { + for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { + MDNode *Node = cast<MDNode>(LoopID->getOperand(i)); + // If node in update list, ignore old value. + if (!matchesHintMetadataName(Node, HintTypes)) + MDs.push_back(Node); + } + } + + // Now, add the missing hints. + for (auto H : HintTypes) + MDs.push_back(createHintMetadata(Twine(Prefix(), H.Name).str(), H.Value)); + + // Replace current metadata node with new one. 
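// (Illustrative.) The rebuilt loop ID below is self-referential through
// operand 0; after setAlreadyVectorized() it looks roughly like:
//   !0 = !{!0, !{!"llvm.loop.vectorize.width", i32 1},
//              !{!"llvm.loop.interleave.count", i32 1}}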
+ LLVMContext &Context = TheLoop->getHeader()->getContext(); + MDNode *NewLoopID = MDNode::get(Context, MDs); + // Set operand 0 to refer to the loop id itself. + NewLoopID->replaceOperandWith(0, NewLoopID); + + TheLoop->setLoopID(NewLoopID); + } + + /// The loop these hints belong to. + const Loop *TheLoop; }; static void emitMissedWarning(Function *F, Loop *L, @@ -1134,7 +1232,7 @@ static void emitMissedWarning(Function *F, Loop *L, emitLoopVectorizeWarning( F->getContext(), *F, L->getStartLoc(), "failed explicitly specified loop vectorization"); - else if (LH.getUnroll() != 1) + else if (LH.getInterleave() != 1) emitLoopInterleaveWarning( F->getContext(), *F, L->getStartLoc(), "failed explicitly specified loop interleaving"); @@ -1169,6 +1267,7 @@ struct LoopVectorize : public FunctionPass { BlockFrequencyInfo *BFI; TargetLibraryInfo *TLI; AliasAnalysis *AA; + AssumptionCache *AC; bool DisableUnrolling; bool AlwaysVectorize; @@ -1184,6 +1283,7 @@ struct LoopVectorize : public FunctionPass { BFI = &getAnalysis<BlockFrequencyInfo>(); TLI = getAnalysisIfAvailable<TargetLibraryInfo>(); AA = &getAnalysis<AliasAnalysis>(); + AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); // Compute some weights outside of the loop over the loops. Compute this // using a BranchProbability to re-use its scaling math. @@ -1240,7 +1340,7 @@ struct LoopVectorize : public FunctionPass { : (Hints.getForce() == LoopVectorizeHints::FK_Enabled ? "enabled" : "?")) << " width=" << Hints.getWidth() - << " unroll=" << Hints.getUnroll() << "\n"); + << " unroll=" << Hints.getInterleave() << "\n"); // Function containing loop Function *F = L->getHeader()->getParent(); @@ -1267,7 +1367,7 @@ struct LoopVectorize : public FunctionPass { return false; } - if (Hints.getWidth() == 1 && Hints.getUnroll() == 1) { + if (Hints.getWidth() == 1 && Hints.getInterleave() == 1) { DEBUG(dbgs() << "LV: Not vectorizing: Disabled/already vectorized.\n"); emitOptimizationRemarkAnalysis( F->getContext(), DEBUG_TYPE, *F, L->getStartLoc(), @@ -1278,8 +1378,7 @@ struct LoopVectorize : public FunctionPass { // Check the loop for a trip count threshold: // do not vectorize loops with a tiny trip count. - BasicBlock *Latch = L->getLoopLatch(); - const unsigned TC = SE->getSmallConstantTripCount(L, Latch); + const unsigned TC = SE->getSmallConstantTripCount(L); if (TC > 0u && TC < TinyTripCountVectorThreshold) { DEBUG(dbgs() << "LV: Found a loop with a very small trip count. " << "This loop is not worth vectorizing."); @@ -1295,7 +1394,7 @@ struct LoopVectorize : public FunctionPass { } // Check if it is legal to vectorize the loop. - LoopVectorizationLegality LVL(L, SE, DL, DT, TLI, AA, F); + LoopVectorizationLegality LVL(L, SE, DL, DT, TLI, AA, F, TTI); if (!LVL.canVectorize()) { DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n"); emitMissedWarning(F, L, Hints); @@ -1303,7 +1402,8 @@ struct LoopVectorize : public FunctionPass { } // Use the cost model. - LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, DL, TLI); + LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, DL, TLI, AC, F, + &Hints); // Check the function attributes to find out if this function should be // optimized for size. @@ -1338,13 +1438,11 @@ struct LoopVectorize : public FunctionPass { // Select the optimal vectorization factor. 
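// (Illustrative paraphrase of what selectVectorizationFactor does internally:
// try each power-of-two width and keep the best cost per scalar lane. The
// names below are schematic, not the exact implementation.)
//   float Cost = expectedCost(1); unsigned Width = 1;
//   for (unsigned W = 2; W <= MaxVectorSize; W *= 2) {
//     float VectorCost = expectedCost(W) / (float)W; // cost per lane
//     if (VectorCost < Cost) { Cost = VectorCost; Width = W; }
//   }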
const LoopVectorizationCostModel::VectorizationFactor VF = - CM.selectVectorizationFactor(OptForSize, Hints.getWidth(), - Hints.getForce() == - LoopVectorizeHints::FK_Enabled); + CM.selectVectorizationFactor(OptForSize); // Select the unroll factor. const unsigned UF = - CM.selectUnrollFactor(OptForSize, Hints.getUnroll(), VF.Width, VF.Cost); + CM.selectUnrollFactor(OptForSize, VF.Width, VF.Cost); DEBUG(dbgs() << "LV: Found a vectorizable loop (" << VF.Width << ") in " << DebugLocStr << '\n'); @@ -1385,13 +1483,14 @@ struct LoopVectorize : public FunctionPass { } // Mark the loop as already vectorized to avoid vectorizing again. - Hints.setAlreadyVectorized(L); + Hints.setAlreadyVectorized(); DEBUG(verifyFunction(*L->getHeader()->getParent())); return true; } void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); AU.addRequiredID(LoopSimplifyID); AU.addRequiredID(LCSSAID); AU.addRequired<BlockFrequencyInfo>(); @@ -1683,7 +1782,8 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { unsigned ScalarAllocatedSize = DL->getTypeAllocSize(ScalarDataTy); unsigned VectorElementSize = DL->getTypeStoreSize(DataTy)/VF; - if (SI && Legal->blockNeedsPredication(SI->getParent())) + if (SI && Legal->blockNeedsPredication(SI->getParent()) && + !Legal->isMaskRequired(SI)) return scalarizeInstruction(Instr, true); if (ScalarAllocatedSize != VectorElementSize) @@ -1752,6 +1852,7 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { Ptr = Builder.CreateExtractElement(PtrVal[0], Zero); } + VectorParts Mask = createBlockInMask(Instr->getParent()); // Handle Stores: if (SI) { assert(!Legal->isUniform(SI->getPointerOperand()) && @@ -1760,7 +1861,7 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { // We don't want to update the value in the map as it might be used in // another expression. So don't use a reference type for "StoredVal". VectorParts StoredVal = getVectorValue(SI->getValueOperand()); - + for (unsigned Part = 0; Part < UF; ++Part) { // Calculate the pointer for the specific unroll-part. Value *PartPtr = Builder.CreateGEP(Ptr, Builder.getInt32(Part * VF)); @@ -1777,8 +1878,13 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { Value *VecPtr = Builder.CreateBitCast(PartPtr, DataTy->getPointerTo(AddressSpace)); - StoreInst *NewSI = - Builder.CreateAlignedStore(StoredVal[Part], VecPtr, Alignment); + + Instruction *NewSI; + if (Legal->isMaskRequired(SI)) + NewSI = Builder.CreateMaskedStore(StoredVal[Part], VecPtr, Alignment, + Mask[Part]); + else + NewSI = Builder.CreateAlignedStore(StoredVal[Part], VecPtr, Alignment); propagateMetadata(NewSI, SI); } return; @@ -1793,14 +1899,20 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { if (Reverse) { // If the address is consecutive but reversed, then the - // wide store needs to start at the last vector element. + // wide load needs to start at the last vector element. 
PartPtr = Builder.CreateGEP(Ptr, Builder.getInt32(-Part * VF)); PartPtr = Builder.CreateGEP(PartPtr, Builder.getInt32(1 - VF)); } + Instruction* NewLI; Value *VecPtr = Builder.CreateBitCast(PartPtr, DataTy->getPointerTo(AddressSpace)); - LoadInst *NewLI = Builder.CreateAlignedLoad(VecPtr, Alignment, "wide.load"); + if (Legal->isMaskRequired(LI)) + NewLI = Builder.CreateMaskedLoad(VecPtr, Alignment, Mask[Part], + UndefValue::get(DataTy), + "wide.masked.load"); + else + NewLI = Builder.CreateAlignedLoad(VecPtr, Alignment, "wide.load"); propagateMetadata(NewLI, LI); Entry[Part] = Reverse ? reverseVector(NewLI) : NewLI; } @@ -2487,7 +2599,7 @@ void InnerLoopVectorizer::createEmptyLoop() { LoopScalarBody = OldBasicBlock; LoopVectorizeHints Hints(Lp, true); - Hints.setAlreadyVectorized(Lp); + Hints.setAlreadyVectorized(); } /// This function returns the identity element (or neutral element) for @@ -2755,9 +2867,6 @@ void InnerLoopVectorizer::vectorizeLoop() { } // Fix the vector-loop phi. - // We created the induction variable so we know that the - // preheader is the first entry. - BasicBlock *VecPreheader = Induction->getIncomingBlock(0); // Reductions do not have to start at zero. They can start with // any loop invariant values. @@ -2769,7 +2878,8 @@ void InnerLoopVectorizer::vectorizeLoop() { // Make sure to add the reduction stat value only to the // first unroll part. Value *StartVal = (part == 0) ? VectorStart : Identity; - cast<PHINode>(VecRdxPhi[part])->addIncoming(StartVal, VecPreheader); + cast<PHINode>(VecRdxPhi[part])->addIncoming(StartVal, + LoopVectorPreHeader); cast<PHINode>(VecRdxPhi[part])->addIncoming(Val[part], LoopVectorBody.back()); } @@ -2901,7 +3011,7 @@ void InnerLoopVectorizer::fixLCSSAPHIs() { LCSSAPhi->addIncoming(UndefValue::get(LCSSAPhi->getType()), LoopMiddleBlock); } -} +} InnerLoopVectorizer::VectorParts InnerLoopVectorizer::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) { @@ -3168,18 +3278,8 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { for (unsigned Part = 0; Part < UF; ++Part) { Value *V = Builder.CreateBinOp(BinOp->getOpcode(), A[Part], B[Part]); - // Update the NSW, NUW and Exact flags. Notice: V can be an Undef. - BinaryOperator *VecOp = dyn_cast<BinaryOperator>(V); - if (VecOp && isa<OverflowingBinaryOperator>(BinOp)) { - VecOp->setHasNoSignedWrap(BinOp->hasNoSignedWrap()); - VecOp->setHasNoUnsignedWrap(BinOp->hasNoUnsignedWrap()); - } - if (VecOp && isa<PossiblyExactOperator>(VecOp)) - VecOp->setIsExact(BinOp->isExact()); - - // Copy the fast-math flags. - if (VecOp && isa<FPMathOperator>(V)) - VecOp->setFastMathFlags(it->getFastMathFlags()); + if (BinaryOperator *VecOp = dyn_cast<BinaryOperator>(V)) + VecOp->copyIRFlags(BinOp); Entry[Part] = V; } @@ -3292,6 +3392,7 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { Intrinsic::ID ID = getIntrinsicIDForCall(CI, TLI); assert(ID && "Not an intrinsic call!"); switch (ID) { + case Intrinsic::assume: case Intrinsic::lifetime_end: case Intrinsic::lifetime_start: scalarizeInstruction(it); @@ -3542,7 +3643,7 @@ static Type* getWiderType(const DataLayout &DL, Type *Ty0, Type *Ty1) { /// \brief Check that the instruction has outside loop users and is not an /// identified reduction variable. static bool hasOutsideLoopUser(const Loop *TheLoop, Instruction *Inst, - SmallPtrSet<Value *, 4> &Reductions) { + SmallPtrSetImpl<Value *> &Reductions) { // Reduction instructions are allowed to have exit users. 
All other // instructions must not have external users. if (!Reductions.count(Inst)) @@ -3597,12 +3698,12 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // identified reduction value with an outside user. if (!hasOutsideLoopUser(TheLoop, it, AllowedExit)) continue; - emitAnalysis(Report(it) << "value that could not be identified as " - "reduction is used outside the loop"); + emitAnalysis(Report(it) << "value could not be identified as " + "an induction or reduction variable"); return false; } - // We only allow if-converted PHIs with more than two incoming values. + // We only allow if-converted PHIs with exactly two incoming values. if (Phi->getNumIncomingValues() != 2) { emitAnalysis(Report(it) << "control flow not understood by vectorizer"); @@ -3683,7 +3784,8 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { continue; } - emitAnalysis(Report(it) << "unvectorizable operation"); + emitAnalysis(Report(it) << "value that could not be identified as " + "reduction is used outside the loop"); DEBUG(dbgs() << "LV: Found an unidentified PHI."<< *Phi <<"\n"); return false; }// end of PHI handling @@ -3727,12 +3829,12 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { return false; } if (EnableMemAccessVersioning) - collectStridedAcccess(ST); + collectStridedAccess(ST); } if (EnableMemAccessVersioning) if (LoadInst *LI = dyn_cast<LoadInst>(it)) - collectStridedAcccess(LI); + collectStridedAccess(LI); // Reduction instructions are allowed to have exit users. // All other instructions must not have external users. @@ -3870,7 +3972,7 @@ static Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, return Stride; } -void LoopVectorizationLegality::collectStridedAcccess(Value *MemAccess) { +void LoopVectorizationLegality::collectStridedAccess(Value *MemAccess) { Value *Ptr = nullptr; if (LoadInst *LI = dyn_cast<LoadInst>(MemAccess)) Ptr = LI->getPointerOperand(); @@ -3946,7 +4048,7 @@ public: /// \brief Register a load and whether it is only read from. void addLoad(AliasAnalysis::Location &Loc, bool IsReadOnly) { Value *Ptr = const_cast<Value*>(Loc.Ptr); - AST.add(Ptr, AliasAnalysis::UnknownSize, Loc.TBAATag); + AST.add(Ptr, AliasAnalysis::UnknownSize, Loc.AATags); Accesses.insert(MemAccessInfo(Ptr, false)); if (IsReadOnly) ReadOnlyPtr.insert(Ptr); @@ -3955,7 +4057,7 @@ public: /// \brief Register a store. void addStore(AliasAnalysis::Location &Loc) { Value *Ptr = const_cast<Value*>(Loc.Ptr); - AST.add(Ptr, AliasAnalysis::UnknownSize, Loc.TBAATag); + AST.add(Ptr, AliasAnalysis::UnknownSize, Loc.AATags); Accesses.insert(MemAccessInfo(Ptr, true)); } @@ -4166,57 +4268,66 @@ void AccessAnalysis::processMemAccesses() { bool UseDeferred = SetIteration > 0; PtrAccessSet &S = UseDeferred ? DeferredAccesses : Accesses; - for (auto A : AS) { - Value *Ptr = A.getValue(); - bool IsWrite = S.count(MemAccessInfo(Ptr, true)); + for (auto AV : AS) { + Value *Ptr = AV.getValue(); - // If we're using the deferred access set, then it contains only reads. - bool IsReadOnlyPtr = ReadOnlyPtr.count(Ptr) && !IsWrite; - if (UseDeferred && !IsReadOnlyPtr) - continue; - // Otherwise, the pointer must be in the PtrAccessSet, either as a read - // or a write. 
- assert(((IsReadOnlyPtr && UseDeferred) || IsWrite || - S.count(MemAccessInfo(Ptr, false))) && - "Alias-set pointer not in the access set?"); - - MemAccessInfo Access(Ptr, IsWrite); - DepCands.insert(Access); - - // Memorize read-only pointers for later processing and skip them in the - // first round (they need to be checked after we have seen all write - // pointers). Note: we also mark pointer that are not consecutive as - // "read-only" pointers (so that we check "a[b[i]] +="). Hence, we need - // the second check for "!IsWrite". - if (!UseDeferred && IsReadOnlyPtr) { - DeferredAccesses.insert(Access); - continue; - } + // For a single memory access in AliasSetTracker, Accesses may contain + // both read and write, and they both need to be handled for CheckDeps. + for (auto AC : S) { + if (AC.getPointer() != Ptr) + continue; - // If this is a write - check other reads and writes for conflicts. If - // this is a read only check other writes for conflicts (but only if - // there is no other write to the ptr - this is an optimization to - // catch "a[i] = a[i] + " without having to do a dependence check). - if ((IsWrite || IsReadOnlyPtr) && SetHasWrite) { - CheckDeps.insert(Access); - IsRTCheckNeeded = true; - } + bool IsWrite = AC.getInt(); + + // If we're using the deferred access set, then it contains only + // reads. + bool IsReadOnlyPtr = ReadOnlyPtr.count(Ptr) && !IsWrite; + if (UseDeferred && !IsReadOnlyPtr) + continue; + // Otherwise, the pointer must be in the PtrAccessSet, either as a + // read or a write. + assert(((IsReadOnlyPtr && UseDeferred) || IsWrite || + S.count(MemAccessInfo(Ptr, false))) && + "Alias-set pointer not in the access set?"); + + MemAccessInfo Access(Ptr, IsWrite); + DepCands.insert(Access); + + // Memorize read-only pointers for later processing and skip them in + // the first round (they need to be checked after we have seen all + // write pointers). Note: we also mark pointer that are not + // consecutive as "read-only" pointers (so that we check + // "a[b[i]] +="). Hence, we need the second check for "!IsWrite". + if (!UseDeferred && IsReadOnlyPtr) { + DeferredAccesses.insert(Access); + continue; + } + + // If this is a write - check other reads and writes for conflicts. If + // this is a read only check other writes for conflicts (but only if + // there is no other write to the ptr - this is an optimization to + // catch "a[i] = a[i] + " without having to do a dependence check). + if ((IsWrite || IsReadOnlyPtr) && SetHasWrite) { + CheckDeps.insert(Access); + IsRTCheckNeeded = true; + } - if (IsWrite) - SetHasWrite = true; - - // Create sets of pointers connected by a shared alias set and - // underlying object. - typedef SmallVector<Value*, 16> ValueVector; - ValueVector TempObjects; - GetUnderlyingObjects(Ptr, TempObjects, DL); - for (Value *UnderlyingObj : TempObjects) { - UnderlyingObjToAccessMap::iterator Prev = - ObjToLastAccess.find(UnderlyingObj); - if (Prev != ObjToLastAccess.end()) - DepCands.unionSets(Access, Prev->second); - - ObjToLastAccess[UnderlyingObj] = Access; + if (IsWrite) + SetHasWrite = true; + + // Create sets of pointers connected by a shared alias set and + // underlying object. 
+ typedef SmallVector<Value *, 16> ValueVector; + ValueVector TempObjects; + GetUnderlyingObjects(Ptr, TempObjects, DL); + for (Value *UnderlyingObj : TempObjects) { + UnderlyingObjToAccessMap::iterator Prev = + ObjToLastAccess.find(UnderlyingObj); + if (Prev != ObjToLastAccess.end()) + DepCands.unionSets(Access, Prev->second); + + ObjToLastAccess[UnderlyingObj] = Access; + } } } } @@ -4566,7 +4677,7 @@ bool MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx, // Bail out early if passed-in parameters make vectorization not feasible. unsigned ForcedFactor = VectorizationFactor ? VectorizationFactor : 1; - unsigned ForcedUnroll = VectorizationUnroll ? VectorizationUnroll : 1; + unsigned ForcedUnroll = VectorizationInterleave ? VectorizationInterleave : 1; // The distance must be bigger than the size needed for a vectorized version // of the operation and the size of the vectorized operation must not be @@ -4738,7 +4849,7 @@ bool LoopVectorizationLegality::canVectorizeMemory() { // If we did *not* see this pointer before, insert it to the read-write // list. At this phase it is only a 'write' list. - if (Seen.insert(Ptr)) { + if (Seen.insert(Ptr).second) { ++NumReadWrites; AliasAnalysis::Location Loc = AA->getLocation(ST); @@ -4746,7 +4857,7 @@ bool LoopVectorizationLegality::canVectorizeMemory() { // condition, so we cannot rely on it when determining whether or not we // need runtime pointer checks. if (blockNeedsPredication(ST->getParent())) - Loc.TBAATag = nullptr; + Loc.AATags.TBAA = nullptr; Accesses.addStore(Loc); } @@ -4771,7 +4882,8 @@ bool LoopVectorizationLegality::canVectorizeMemory() { // read a few words, modify, and write a few words, and some of the // words may be written to the same address. bool IsReadOnlyPtr = false; - if (Seen.insert(Ptr) || !isStridedPtr(SE, DL, Ptr, TheLoop, Strides)) { + if (Seen.insert(Ptr).second || + !isStridedPtr(SE, DL, Ptr, TheLoop, Strides)) { ++NumReads; IsReadOnlyPtr = true; } @@ -4781,7 +4893,7 @@ bool LoopVectorizationLegality::canVectorizeMemory() { // condition, so we cannot rely on it when determining whether or not we // need runtime pointer checks. if (blockNeedsPredication(LD->getParent())) - Loc.TBAATag = nullptr; + Loc.AATags.TBAA = nullptr; Accesses.addLoad(Loc, IsReadOnlyPtr); } @@ -4884,7 +4996,7 @@ bool LoopVectorizationLegality::canVectorizeMemory() { } static bool hasMultipleUsesOf(Instruction *I, - SmallPtrSet<Instruction *, 8> &Insts) { + SmallPtrSetImpl<Instruction *> &Insts) { unsigned NumUses = 0; for(User::op_iterator Use = I->op_begin(), E = I->op_end(); Use != E; ++Use) { if (Insts.count(dyn_cast<Instruction>(*Use))) @@ -4896,7 +5008,7 @@ static bool hasMultipleUsesOf(Instruction *I, return false; } -static bool areAllUsesIn(Instruction *I, SmallPtrSet<Instruction *, 8> &Set) { +static bool areAllUsesIn(Instruction *I, SmallPtrSetImpl<Instruction *> &Set) { for(User::op_iterator Use = I->op_begin(), E = I->op_end(); Use != E; ++Use) if (!Set.count(dyn_cast<Instruction>(*Use))) return false; @@ -5034,7 +5146,7 @@ bool LoopVectorizationLegality::AddReductionVar(PHINode *Phi, // value must only be used once, except by phi nodes and min/max // reductions which are represented as a cmp followed by a select. 
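// (Illustrative.) The cmp-followed-by-select shape referred to above is what
// a source-level min/max reduction lowers to, e.g.:
//   int m = a[0];
//   for (int i = 1; i < n; ++i)
//     m = (a[i] > m) ? a[i] : m; // an icmp feeding a select in IR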
ReductionInstDesc IgnoredVal(false, nullptr); - if (VisitedInsts.insert(UI)) { + if (VisitedInsts.insert(UI).second) { if (isa<PHINode>(UI)) PHIs.push_back(UI); else @@ -5136,7 +5248,7 @@ LoopVectorizationLegality::isReductionInstr(Instruction *I, ReductionKind Kind, ReductionInstDesc &Prev) { bool FP = I->getType()->isFloatingPointTy(); - bool FastMath = (FP && I->isCommutative() && I->isAssociative()); + bool FastMath = FP && I->hasUnsafeAlgebra(); switch (I->getOpcode()) { default: return ReductionInstDesc(false, I); @@ -5158,6 +5270,7 @@ LoopVectorizationLegality::isReductionInstr(Instruction *I, return ReductionInstDesc(Kind == RK_IntegerXor, I); case Instruction::FMul: return ReductionInstDesc(Kind == RK_FloatMult && FastMath, I); + case Instruction::FSub: case Instruction::FAdd: return ReductionInstDesc(Kind == RK_FloatAdd && FastMath, I); case Instruction::FCmp: @@ -5234,13 +5347,28 @@ bool LoopVectorizationLegality::blockNeedsPredication(BasicBlock *BB) { } bool LoopVectorizationLegality::blockCanBePredicated(BasicBlock *BB, - SmallPtrSet<Value *, 8>& SafePtrs) { + SmallPtrSetImpl<Value *> &SafePtrs) { + for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { + // Check that we don't have a constant expression that can trap as operand. + for (Instruction::op_iterator OI = it->op_begin(), OE = it->op_end(); + OI != OE; ++OI) { + if (Constant *C = dyn_cast<Constant>(*OI)) + if (C->canTrap()) + return false; + } // We might be able to hoist the load. if (it->mayReadFromMemory()) { LoadInst *LI = dyn_cast<LoadInst>(it); - if (!LI || !SafePtrs.count(LI->getPointerOperand())) + if (!LI) + return false; + if (!SafePtrs.count(LI->getPointerOperand())) { + if (isLegalMaskedLoad(LI->getType(), LI->getPointerOperand())) { + MaskedOp.insert(LI); + continue; + } return false; + } } // We don't predicate stores at the moment. @@ -5248,22 +5376,30 @@ bool LoopVectorizationLegality::blockCanBePredicated(BasicBlock *BB, StoreInst *SI = dyn_cast<StoreInst>(it); // We only support predication of stores in basic blocks with one // predecessor. - if (!SI || ++NumPredStores > NumberOfStoresToPredicate || - !SafePtrs.count(SI->getPointerOperand()) || - !SI->getParent()->getSinglePredecessor()) + if (!SI) + return false; + + bool isSafePtr = (SafePtrs.count(SI->getPointerOperand()) != 0); + bool isSinglePredecessor = SI->getParent()->getSinglePredecessor(); + + if (++NumPredStores > NumberOfStoresToPredicate || !isSafePtr || + !isSinglePredecessor) { + // Build a masked store if it is legal for the target, otherwise scalarize + // the block. + bool isLegalMaskedOp = + isLegalMaskedStore(SI->getValueOperand()->getType(), + SI->getPointerOperand()); + if (isLegalMaskedOp) { + --NumPredStores; + MaskedOp.insert(SI); + continue; + } return false; + } } if (it->mayThrow()) return false; - // Check that we don't have a constant expression that can trap as operand. - for (Instruction::op_iterator OI = it->op_begin(), OE = it->op_end(); - OI != OE; ++OI) { - if (Constant *C = dyn_cast<Constant>(*OI)) - if (C->canTrap()) - return false; - } - // The instructions below can trap. 
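// (Illustrative, and parenthetical to the trapping-opcode switch below.)
// Loads and stores diverted into MaskedOp above are later emitted by
// vectorizeMemoryInstruction as the new masked intrinsics, roughly:
//   call void @llvm.masked.store.v8i32(<8 x i32> %val, <8 x i32>* %p,
//                                      i32 4, <8 x i1> %mask)
//   %v = call <8 x i32> @llvm.masked.load.v8i32(<8 x i32>* %p, i32 4,
//                                               <8 x i1> %mask,
//                                               <8 x i32> undef)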
switch (it->getOpcode()) { default: continue; @@ -5271,7 +5407,7 @@ bool LoopVectorizationLegality::blockCanBePredicated(BasicBlock *BB, case Instruction::SDiv: case Instruction::URem: case Instruction::SRem: - return false; + return false; } } @@ -5279,23 +5415,23 @@ bool LoopVectorizationLegality::blockCanBePredicated(BasicBlock *BB, } LoopVectorizationCostModel::VectorizationFactor -LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize, - unsigned UserVF, - bool ForceVectorization) { +LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) { // Width 1 means no vectorize VectorizationFactor Factor = { 1U, 0U }; if (OptForSize && Legal->getRuntimePointerCheck()->Need) { + emitAnalysis(Report() << "runtime pointer checks needed. Enable vectorization of this loop with '#pragma clang loop vectorize(enable)' when compiling with -Os"); DEBUG(dbgs() << "LV: Aborting. Runtime ptr check is required in Os.\n"); return Factor; } if (!EnableCondStoresVectorization && Legal->NumPredStores) { + emitAnalysis(Report() << "store that is conditionally executed prevents vectorization"); DEBUG(dbgs() << "LV: No vectorization. There are conditional stores.\n"); return Factor; } // Find the trip count. - unsigned TC = SE->getSmallConstantTripCount(TheLoop, TheLoop->getLoopLatch()); + unsigned TC = SE->getSmallConstantTripCount(TheLoop); DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n'); unsigned WidestType = getWidestType(); @@ -5315,7 +5451,7 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize, MaxVectorSize = 1; } - assert(MaxVectorSize <= 32 && "Did not expect to pack so many elements" + assert(MaxVectorSize <= 64 && "Did not expect to pack so many elements" " into one vector!"); unsigned VF = MaxVectorSize; @@ -5324,6 +5460,7 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize, if (OptForSize) { // If we are unable to calculate the trip count then don't try to vectorize. if (TC < 2) { + emitAnalysis(Report() << "unable to calculate the loop count due to complex control flow"); DEBUG(dbgs() << "LV: Aborting. A tail loop is required in Os.\n"); return Factor; } @@ -5337,11 +5474,16 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize, // If the trip count that we found modulo the vectorization factor is not // zero then we require a tail. if (VF < 2) { + emitAnalysis(Report() << "cannot optimize for size and vectorize at the " + "same time. Enable vectorization of this loop " + "with '#pragma clang loop vectorize(enable)' " + "when compiling with -Os"); DEBUG(dbgs() << "LV: Aborting. A tail loop is required in Os.\n"); return Factor; } } + int UserVF = Hints->getWidth(); if (UserVF != 0) { assert(isPowerOf2_32(UserVF) && "VF needs to be a power of two"); DEBUG(dbgs() << "LV: Using user VF " << UserVF << ".\n"); @@ -5357,6 +5499,7 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize, unsigned Width = 1; DEBUG(dbgs() << "LV: Scalar loop costs: " << (int)ScalarCost << ".\n"); + bool ForceVectorization = Hints->getForce() == LoopVectorizeHints::FK_Enabled; // Ignore scalar width, because the user explicitly wants vectorization. if (ForceVectorization && VF > 1) { Width = 2; @@ -5397,6 +5540,10 @@ unsigned LoopVectorizationCostModel::getWidestType() { for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { Type *T = it->getType(); + // Ignore ephemeral values. + if (EphValues.count(it)) + continue; + // Only examine Loads, Stores and PHINodes. 
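// (Illustrative.) An "ephemeral" value, skipped just above, is one that only
// feeds @llvm.assume and so generates no real code, e.g. %cmp in:
//   %cmp = icmp sgt i32 %n, 0
//   call void @llvm.assume(i1 %cmp)
// Counting such values would inflate the width and register-pressure
// estimates for no benefit.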
if (!isa<LoadInst>(it) && !isa<StoreInst>(it) && !isa<PHINode>(it)) continue; @@ -5426,29 +5573,29 @@ unsigned LoopVectorizationCostModel::getWidestType() { unsigned LoopVectorizationCostModel::selectUnrollFactor(bool OptForSize, - unsigned UserUF, unsigned VF, unsigned LoopCost) { // -- The unroll heuristics -- // We unroll the loop in order to expose ILP and reduce the loop overhead. // There are many micro-architectural considerations that we can't predict - // at this level. For example frontend pressure (on decode or fetch) due to + // at this level. For example, frontend pressure (on decode or fetch) due to // code size, or the number and capabilities of the execution ports. // // We use the following heuristics to select the unroll factor: - // 1. If the code has reductions the we unroll in order to break the cross + // 1. If the code has reductions, then we unroll in order to break the cross // iteration dependency. - // 2. If the loop is really small then we unroll in order to reduce the loop + // 2. If the loop is really small, then we unroll in order to reduce the loop // overhead. // 3. We don't unroll if we think that we will spill registers to memory due // to the increased register pressure. // Use the user preference, unless 'auto' is selected. + int UserUF = Hints->getInterleave(); if (UserUF != 0) return UserUF; - // When we optimize for size we don't unroll. + // When we optimize for size, we don't unroll. if (OptForSize) return 1; @@ -5457,8 +5604,7 @@ LoopVectorizationCostModel::selectUnrollFactor(bool OptForSize, return 1; // Do not unroll loops with a relatively small trip count. - unsigned TC = SE->getSmallConstantTripCount(TheLoop, - TheLoop->getLoopLatch()); + unsigned TC = SE->getSmallConstantTripCount(TheLoop); if (TC > 1 && TC < TinyTripCountUnrollThreshold) return 1; @@ -5497,15 +5643,15 @@ LoopVectorizationCostModel::selectUnrollFactor(bool OptForSize, std::max(1U, (R.MaxLocalUsers - 1))); // Clamp the unroll factor ranges to reasonable factors. - unsigned MaxUnrollSize = TTI.getMaximumUnrollFactor(); + unsigned MaxInterleaveSize = TTI.getMaxInterleaveFactor(); // Check if the user has overridden the unroll max. if (VF == 1) { - if (ForceTargetMaxScalarUnrollFactor.getNumOccurrences() > 0) - MaxUnrollSize = ForceTargetMaxScalarUnrollFactor; + if (ForceTargetMaxScalarInterleaveFactor.getNumOccurrences() > 0) + MaxInterleaveSize = ForceTargetMaxScalarInterleaveFactor; } else { - if (ForceTargetMaxVectorUnrollFactor.getNumOccurrences() > 0) - MaxUnrollSize = ForceTargetMaxVectorUnrollFactor; + if (ForceTargetMaxVectorInterleaveFactor.getNumOccurrences() > 0) + MaxInterleaveSize = ForceTargetMaxVectorInterleaveFactor; } // If we did not calculate the cost for VF (because the user selected the VF) @@ -5515,8 +5661,8 @@ LoopVectorizationCostModel::selectUnrollFactor(bool OptForSize, // Clamp the calculated UF to be between the 1 and the max unroll factor // that the target allows. - if (UF > MaxUnrollSize) - UF = MaxUnrollSize; + if (UF > MaxInterleaveSize) + UF = MaxInterleaveSize; else if (UF < 1) UF = 1; @@ -5547,6 +5693,18 @@ LoopVectorizationCostModel::selectUnrollFactor(bool OptForSize, unsigned StoresUF = UF / (Legal->NumStores ? Legal->NumStores : 1); unsigned LoadsUF = UF / (Legal->NumLoads ? Legal->NumLoads : 1); + // If we have a scalar reduction (vector reductions are already dealt with + // by this point), we can increase the critical path length if the loop + // we're unrolling is inside another loop. 
Limit, by default to 2, so the + // critical path only gets increased by one reduction operation. + if (Legal->getReductionVars()->size() && + TheLoop->getLoopDepth() > 1) { + unsigned F = static_cast<unsigned>(MaxNestedScalarReductionUF); + SmallUF = std::min(SmallUF, F); + StoresUF = std::min(StoresUF, F); + LoadsUF = std::min(LoadsUF, F); + } + if (EnableLoadStoreRuntimeUnroll && std::max(StoresUF, LoadsUF) > SmallUF) { DEBUG(dbgs() << "LV: Unrolling to saturate store or load ports.\n"); return std::max(StoresUF, LoadsUF); @@ -5648,6 +5806,10 @@ LoopVectorizationCostModel::calculateRegisterUsage() { // Ignore instructions that are never used within the loop. if (!Ends.count(I)) continue; + // Ignore ephemeral values. + if (EphValues.count(I)) + continue; + // Remove all of the instructions that end at this location. InstrList &List = TransposeEnds[i]; for (unsigned int j=0, e = List.size(); j < e; ++j) @@ -5688,6 +5850,10 @@ unsigned LoopVectorizationCostModel::expectedCost(unsigned VF) { if (isa<DbgInfoIntrinsic>(it)) continue; + // Ignore ephemeral values. + if (EphValues.count(it)) + continue; + unsigned C = getInstructionCost(it, VF); // Check if we should override the cost. @@ -5821,18 +5987,31 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) { TargetTransformInfo::OK_AnyValue; TargetTransformInfo::OperandValueKind Op2VK = TargetTransformInfo::OK_AnyValue; + TargetTransformInfo::OperandValueProperties Op1VP = + TargetTransformInfo::OP_None; + TargetTransformInfo::OperandValueProperties Op2VP = + TargetTransformInfo::OP_None; Value *Op2 = I->getOperand(1); // Check for a splat of a constant or for a non uniform vector of constants. - if (isa<ConstantInt>(Op2)) + if (isa<ConstantInt>(Op2)) { + ConstantInt *CInt = cast<ConstantInt>(Op2); + if (CInt && CInt->getValue().isPowerOf2()) + Op2VP = TargetTransformInfo::OP_PowerOf2; Op2VK = TargetTransformInfo::OK_UniformConstantValue; - else if (isa<ConstantVector>(Op2) || isa<ConstantDataVector>(Op2)) { + } else if (isa<ConstantVector>(Op2) || isa<ConstantDataVector>(Op2)) { Op2VK = TargetTransformInfo::OK_NonUniformConstantValue; - if (cast<Constant>(Op2)->getSplatValue() != nullptr) + Constant *SplatValue = cast<Constant>(Op2)->getSplatValue(); + if (SplatValue) { + ConstantInt *CInt = dyn_cast<ConstantInt>(SplatValue); + if (CInt && CInt->getValue().isPowerOf2()) + Op2VP = TargetTransformInfo::OP_PowerOf2; Op2VK = TargetTransformInfo::OK_UniformConstantValue; + } } - return TTI.getArithmeticInstrCost(I->getOpcode(), VectorTy, Op1VK, Op2VK); + return TTI.getArithmeticInstrCost(I->getOpcode(), VectorTy, Op1VK, Op2VK, + Op1VP, Op2VP); } case Instruction::Select: { SelectInst *SI = cast<SelectInst>(I); @@ -5975,6 +6154,7 @@ static const char lv_name[] = "Loop Vectorization"; INITIALIZE_PASS_BEGIN(LoopVectorize, LV_NAME, lv_name, false, false) INITIALIZE_AG_DEPENDENCY(TargetTransformInfo) INITIALIZE_AG_DEPENDENCY(AliasAnalysis) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfo) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(ScalarEvolution) diff --git a/lib/Transforms/Vectorize/SLPVectorizer.cpp b/lib/Transforms/Vectorize/SLPVectorizer.cpp index 53a43d9851e9..4834782ecc14 100644 --- a/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -19,7 +19,11 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/Optional.h" 
+#include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" @@ -42,12 +46,15 @@ #include "llvm/Transforms/Utils/VectorUtils.h" #include <algorithm> #include <map> +#include <memory> using namespace llvm; #define SV_NAME "slp-vectorizer" #define DEBUG_TYPE "SLP" +STATISTIC(NumVectorInstructions, "Number of vector instructions generated"); + static cl::opt<int> SLPCostThreshold("slp-threshold", cl::init(0), cl::Hidden, cl::desc("Only vectorize if you gain more than this " @@ -68,53 +75,6 @@ static const unsigned MinVecRegSize = 128; static const unsigned RecursionMaxDepth = 12; -/// A helper class for numbering instructions in multiple blocks. -/// Numbers start at zero for each basic block. -struct BlockNumbering { - - BlockNumbering(BasicBlock *Bb) : BB(Bb), Valid(false) {} - - void numberInstructions() { - unsigned Loc = 0; - InstrIdx.clear(); - InstrVec.clear(); - // Number the instructions in the block. - for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { - InstrIdx[it] = Loc++; - InstrVec.push_back(it); - assert(InstrVec[InstrIdx[it]] == it && "Invalid allocation"); - } - Valid = true; - } - - int getIndex(Instruction *I) { - assert(I->getParent() == BB && "Invalid instruction"); - if (!Valid) - numberInstructions(); - assert(InstrIdx.count(I) && "Unknown instruction"); - return InstrIdx[I]; - } - - Instruction *getInstruction(unsigned loc) { - if (!Valid) - numberInstructions(); - assert(InstrVec.size() > loc && "Invalid Index"); - return InstrVec[loc]; - } - - void forget() { Valid = false; } - -private: - /// The block we are numbering. - BasicBlock *BB; - /// Is the block numbered. - bool Valid; - /// Maps instructions to numbers and back. - SmallDenseMap<Instruction *, int> InstrIdx; - /// Maps integers to Instructions. - SmallVector<Instruction *, 32> InstrVec; -}; - /// \returns the parent basic block if all of the instructions in \p VL /// are in the same block or null otherwise. static BasicBlock *getSameBlock(ArrayRef<Value *> VL) { @@ -209,6 +169,23 @@ static unsigned getSameOpcode(ArrayRef<Value *> VL) { return Opcode; } +/// Get the intersection (logical and) of all of the potential IR flags +/// of each scalar operation (VL) that will be converted into a vector (I). +/// Flag set: NSW, NUW, exact, and all of fast-math. +static void propagateIRFlags(Value *I, ArrayRef<Value *> VL) { + if (auto *VecOp = dyn_cast<BinaryOperator>(I)) { + if (auto *Intersection = dyn_cast<BinaryOperator>(VL[0])) { + // Intersection is initialized to the 0th scalar, + // so start counting from index '1'. + for (int i = 1, e = VL.size(); i < e; ++i) { + if (auto *Scalar = dyn_cast<BinaryOperator>(VL[i])) + Intersection->andIRFlags(Scalar); + } + VecOp->copyIRFlags(Intersection); + } + } +} + /// \returns \p I after propagating metadata from \p VL. 
static Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL) {
 Instruction *I0 = cast<Instruction>(VL[0]);
@@ -230,6 +207,10 @@ static Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL) {
 case LLVMContext::MD_tbaa:
 MD = MDNode::getMostGenericTBAA(MD, IMD);
 break;
+ case LLVMContext::MD_alias_scope:
+ case LLVMContext::MD_noalias:
+ MD = MDNode::intersect(MD, IMD);
+ break;
 case LLVMContext::MD_fpmath:
 MD = MDNode::getMostGenericFPMath(MD, IMD);
 break;
@@ -381,6 +362,42 @@ static void reorderInputsAccordingToOpcode(ArrayRef<Value *> VL,
 }
 }

+/// \returns True if the in-tree use also needs an extract. This refers to a
+/// possible scalar operand in the vectorized instruction.
+static bool InTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst,
+ TargetLibraryInfo *TLI) {
+
+ unsigned Opcode = UserInst->getOpcode();
+ switch (Opcode) {
+ case Instruction::Load: {
+ LoadInst *LI = cast<LoadInst>(UserInst);
+ return (LI->getPointerOperand() == Scalar);
+ }
+ case Instruction::Store: {
+ StoreInst *SI = cast<StoreInst>(UserInst);
+ return (SI->getPointerOperand() == Scalar);
+ }
+ case Instruction::Call: {
+ CallInst *CI = cast<CallInst>(UserInst);
+ Intrinsic::ID ID = getIntrinsicIDForCall(CI, TLI);
+ if (hasVectorInstrinsicScalarOpd(ID, 1)) {
+ return (CI->getArgOperand(1) == Scalar);
+ }
+ }
+ default:
+ return false;
+ }
+}
+
+/// \returns the AA location that is being accessed by the instruction.
+static AliasAnalysis::Location getLocation(Instruction *I, AliasAnalysis *AA) {
+ if (StoreInst *SI = dyn_cast<StoreInst>(I))
+ return AA->getLocation(SI);
+ if (LoadInst *LI = dyn_cast<LoadInst>(I))
+ return AA->getLocation(LI);
+ return AliasAnalysis::Location();
+}
+
 /// Bottom Up SLP Vectorizer.
 class BoUpSLP {
 public:
@@ -391,14 +408,21 @@ public:
 BoUpSLP(Function *Func, ScalarEvolution *Se, const DataLayout *Dl,
 TargetTransformInfo *Tti, TargetLibraryInfo *TLi, AliasAnalysis *Aa,
- LoopInfo *Li, DominatorTree *Dt)
- : F(Func), SE(Se), DL(Dl), TTI(Tti), TLI(TLi), AA(Aa), LI(Li), DT(Dt),
- Builder(Se->getContext()) {}
+ LoopInfo *Li, DominatorTree *Dt, AssumptionCache *AC)
+ : NumLoadsWantToKeepOrder(0), NumLoadsWantToChangeOrder(0), F(Func),
+ SE(Se), DL(Dl), TTI(Tti), TLI(TLi), AA(Aa), LI(Li), DT(Dt),
+ Builder(Se->getContext()) {
+ CodeMetrics::collectEphemeralValues(F, AC, EphValues);
+ }

 /// \brief Vectorize the tree that starts with the elements in \p VL.
 /// Returns the vectorized root.
 Value *vectorizeTree();

+ /// \returns the cost incurred by unwanted spills and fills, caused by
+ /// holding live values over call sites.
+ int getSpillCost();
+
 /// \returns the vectorization cost of the subtree that starts at \p VL.
 /// A negative number means that this is profitable.
 int getTreeCost();
@@ -414,7 +438,12 @@ public:
 ScalarToTreeEntry.clear();
 MustGather.clear();
 ExternalUses.clear();
- MemBarrierIgnoreList.clear();
+ NumLoadsWantToKeepOrder = 0;
+ NumLoadsWantToChangeOrder = 0;
+ for (auto &Iter : BlocksSchedules) {
+ BlockScheduling *BS = Iter.second.get();
+ BS->clear();
+ }
 }

 /// \returns true if the memory operations A and B are consecutive.
@@ -423,6 +452,11 @@ public:
 /// \brief Perform LICM and CSE on the newly generated gather sequences.
 void optimizeGatherSequence();

+ /// \returns true if it is beneficial to reverse the vector order.
+ bool shouldReorder() const {
+ return NumLoadsWantToChangeOrder > NumLoadsWantToKeepOrder;
+ }
+
 private:
 struct TreeEntry;

@@ -459,20 +493,6 @@ private:
 /// roots.
This method calculates the cost of extracting the values. int getGatherCost(ArrayRef<Value *> VL); - /// \returns the AA location that is being access by the instruction. - AliasAnalysis::Location getLocation(Instruction *I); - - /// \brief Checks if it is possible to sink an instruction from - /// \p Src to \p Dst. - /// \returns the pointer to the barrier instruction if we can't sink. - Value *getSinkBarrier(Instruction *Src, Instruction *Dst); - - /// \returns the index of the last instruction in the BB from \p VL. - int getLastIndex(ArrayRef<Value *> VL); - - /// \returns the Instruction in the bundle \p VL. - Instruction *getLastInstruction(ArrayRef<Value *> VL); - /// \brief Set the Builder insert point to one after the last instruction in /// the bundle void setInsertPointAfterBundle(ArrayRef<Value *> VL); @@ -485,7 +505,7 @@ private: bool isFullyVectorizableTinyTree(); struct TreeEntry { - TreeEntry() : Scalars(), VectorizedValue(nullptr), LastScalarIndex(0), + TreeEntry() : Scalars(), VectorizedValue(nullptr), NeedToGather(0) {} /// \returns true if the scalars in VL are equal to this entry. @@ -500,9 +520,6 @@ private: /// The Scalars are vectorized into this value. It is initialized to Null. Value *VectorizedValue; - /// The index in the basic block of the last scalar. - int LastScalarIndex; - /// Do we need to gather this sequence ? bool NeedToGather; }; @@ -515,18 +532,16 @@ private: Last->Scalars.insert(Last->Scalars.begin(), VL.begin(), VL.end()); Last->NeedToGather = !Vectorized; if (Vectorized) { - Last->LastScalarIndex = getLastIndex(VL); for (int i = 0, e = VL.size(); i != e; ++i) { assert(!ScalarToTreeEntry.count(VL[i]) && "Scalar already in tree!"); ScalarToTreeEntry[VL[i]] = idx; } } else { - Last->LastScalarIndex = 0; MustGather.insert(VL.begin(), VL.end()); } return Last; } - + /// -- Vectorization State -- /// Holds all of the tree entries. std::vector<TreeEntry> VectorizableTree; @@ -550,32 +565,369 @@ private: }; typedef SmallVector<ExternalUser, 16> UserList; + /// Checks if two instructions may access the same memory. + /// + /// \p Loc1 is the location of \p Inst1. It is passed explicitly because it + /// is invariant in the calling loop. + bool isAliased(const AliasAnalysis::Location &Loc1, Instruction *Inst1, + Instruction *Inst2) { + + // First check if the result is already in the cache. + AliasCacheKey key = std::make_pair(Inst1, Inst2); + Optional<bool> &result = AliasCache[key]; + if (result.hasValue()) { + return result.getValue(); + } + AliasAnalysis::Location Loc2 = getLocation(Inst2, AA); + bool aliased = true; + if (Loc1.Ptr && Loc2.Ptr) { + // Do the alias check. + aliased = AA->alias(Loc1, Loc2); + } + // Store the result in the cache. + result = aliased; + return aliased; + } + + typedef std::pair<Instruction *, Instruction *> AliasCacheKey; + + /// Cache for alias results. + /// TODO: consider moving this to the AliasAnalysis itself. + DenseMap<AliasCacheKey, Optional<bool>> AliasCache; + + /// Removes an instruction from its block and eventually deletes it. + /// It's like Instruction::eraseFromParent() except that the actual deletion + /// is delayed until BoUpSLP is destructed. + /// This is required to ensure that there are no incorrect collisions in the + /// AliasCache, which can happen if a new instruction is allocated at the + /// same address as a previously deleted instruction. 
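+ /// A reduced, hypothetical sketch of the idiom (stand-in names, not the
+ /// surrounding LLVM classes): unlinking is immediate, reclamation waits,
+ /// so a pointer-keyed cache can never observe a recycled address.
+ /// \code
+ ///   struct DelayedEraser {
+ ///     std::vector<std::unique_ptr<Instruction>> Graveyard;
+ ///     void erase(Instruction *I) {
+ ///       I->removeFromParent();     // unlink from the IR now
+ ///       Graveyard.emplace_back(I); // free only when the eraser dies
+ ///     }
+ ///   };
+ /// \endcode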
+ void eraseInstruction(Instruction *I) {
+ I->removeFromParent();
+ I->dropAllReferences();
+ DeletedInstructions.push_back(std::unique_ptr<Instruction>(I));
+ }
+
+ /// Temporary store for deleted instructions. Instructions will be deleted
+ /// eventually when the BoUpSLP is destructed.
+ SmallVector<std::unique_ptr<Instruction>, 8> DeletedInstructions;
+
 /// A list of values that need to be extracted out of the tree.
 /// This list holds pairs of (Internal Scalar : External User).
 UserList ExternalUses;

- /// A list of instructions to ignore while sinking
- /// memory instructions. This map must be reset between runs of getCost.
- ValueSet MemBarrierIgnoreList;
+ /// Values used only by @llvm.assume calls.
+ SmallPtrSet<const Value *, 32> EphValues;

 /// Holds all of the instructions that we gathered.
 SetVector<Instruction *> GatherSeq;
 /// A list of blocks that we are going to CSE.
 SetVector<BasicBlock *> CSEBlocks;

- /// Numbers instructions in different blocks.
- DenseMap<BasicBlock *, BlockNumbering> BlocksNumbers;
+ /// Contains all scheduling-relevant data for an instruction.
+ /// A ScheduleData either represents a single instruction or a member of an
+ /// instruction bundle (= a group of instructions which is combined into a
+ /// vector instruction).
+ struct ScheduleData {
+
+ // The initial value for the dependency counters. It means that the
+ // dependencies are not calculated yet.
+ enum { InvalidDeps = -1 };
+
+ ScheduleData()
+ : Inst(nullptr), FirstInBundle(nullptr), NextInBundle(nullptr),
+ NextLoadStore(nullptr), SchedulingRegionID(0), SchedulingPriority(0),
+ Dependencies(InvalidDeps), UnscheduledDeps(InvalidDeps),
+ UnscheduledDepsInBundle(InvalidDeps), IsScheduled(false) {}
+
+ void init(int BlockSchedulingRegionID) {
+ FirstInBundle = this;
+ NextInBundle = nullptr;
+ NextLoadStore = nullptr;
+ IsScheduled = false;
+ SchedulingRegionID = BlockSchedulingRegionID;
+ UnscheduledDepsInBundle = UnscheduledDeps;
+ clearDependencies();
+ }
+
+ /// Returns true if the dependency information has been calculated.
+ bool hasValidDependencies() const { return Dependencies != InvalidDeps; }
+
+ /// Returns true for single instructions and for bundle representatives
+ /// (= the head of a bundle).
+ bool isSchedulingEntity() const { return FirstInBundle == this; }
+
+ /// Returns true if it represents an instruction bundle and not only a
+ /// single instruction.
+ bool isPartOfBundle() const {
+ return NextInBundle != nullptr || FirstInBundle != this;
+ }
+
+ /// Returns true if it is ready for scheduling, i.e. it has no more
+ /// unscheduled dependent instructions/bundles.
+ bool isReady() const {
+ assert(isSchedulingEntity() &&
+ "can't consider non-scheduling entity for ready list");
+ return UnscheduledDepsInBundle == 0 && !IsScheduled;
+ }
+
+ /// Modifies the number of unscheduled dependencies, also updating it for
+ /// the whole bundle.
+ int incrementUnscheduledDeps(int Incr) {
+ UnscheduledDeps += Incr;
+ return FirstInBundle->UnscheduledDepsInBundle += Incr;
+ }
+
+ /// Sets the number of unscheduled dependencies to the number of
+ /// dependencies.
+ void resetUnscheduledDeps() {
+ incrementUnscheduledDeps(Dependencies - UnscheduledDeps);
+ }
+
+ /// Clears all dependency information.
+ void clearDependencies() {
+ Dependencies = InvalidDeps;
+ resetUnscheduledDeps();
+ MemoryDependencies.clear();
+ }
+
+ void dump(raw_ostream &os) const {
+ if (!isSchedulingEntity()) {
+ os << "/ " << *Inst;
+ } else if (NextInBundle) {
+ os << '[' << *Inst;
+ ScheduleData *SD = NextInBundle;
+ while (SD) {
+ os << ';' << *SD->Inst;
+ SD = SD->NextInBundle;
+ }
+ os << ']';
+ } else {
+ os << *Inst;
+ }
+ }
+
+ Instruction *Inst;

- /// \brief Get the corresponding instruction numbering list for a given
- /// BasicBlock. The list is allocated lazily.
- BlockNumbering &getBlockNumbering(BasicBlock *BB) {
- auto I = BlocksNumbers.insert(std::make_pair(BB, BlockNumbering(BB)));
- return I.first->second;
- }
+ /// Points to the head in an instruction bundle (and always to this for
+ /// single instructions).
+ ScheduleData *FirstInBundle;
+
+ /// Singly linked list of all instructions in a bundle. Null if it is a
+ /// single instruction.
+ ScheduleData *NextInBundle;
+
+ /// Singly linked list of all memory instructions (e.g. load, store, call)
+ /// in the block - until the end of the scheduling region.
+ ScheduleData *NextLoadStore;
+
+ /// The dependent memory instructions.
+ /// This list is derived on demand in calculateDependencies().
+ SmallVector<ScheduleData *, 4> MemoryDependencies;
+
+ /// This ScheduleData is in the current scheduling region if this matches
+ /// the current SchedulingRegionID of BlockScheduling.
+ int SchedulingRegionID;
+
+ /// Used for getting a "good" final ordering of instructions.
+ int SchedulingPriority;
+
+ /// The number of dependencies. It consists of the number of users of the
+ /// instruction plus the number of dependent memory instructions (if any).
+ /// This value is calculated on demand.
+ /// If InvalidDeps, the number of dependencies is not calculated yet.
+ ///
+ int Dependencies;
+
+ /// The number of dependencies minus the number of dependencies of scheduled
+ /// instructions. As soon as this is zero, the instruction/bundle gets ready
+ /// for scheduling.
+ /// Note that this is negative as long as Dependencies is not calculated.
+ int UnscheduledDeps;
+
+ /// The sum of UnscheduledDeps in a bundle. Equals UnscheduledDeps for
+ /// single instructions.
+ int UnscheduledDepsInBundle;
+
+ /// True if this instruction is scheduled (or considered as scheduled in the
+ /// dry-run).
+ bool IsScheduled;
+ };
+
+#ifndef NDEBUG
+ friend raw_ostream &operator<<(raw_ostream &os,
+ const BoUpSLP::ScheduleData &SD);
+#endif
+
+ /// Contains all scheduling data for a basic block.
+ ///
+ struct BlockScheduling {
+
+ BlockScheduling(BasicBlock *BB)
+ : BB(BB), ChunkSize(BB->size()), ChunkPos(ChunkSize),
+ ScheduleStart(nullptr), ScheduleEnd(nullptr),
+ FirstLoadStoreInRegion(nullptr), LastLoadStoreInRegion(nullptr),
+ // Make sure that the initial SchedulingRegionID is greater than the
+ // initial SchedulingRegionID in ScheduleData (which is 0).
+ SchedulingRegionID(1) {}
+
+ void clear() {
+ ReadyInsts.clear();
+ ScheduleStart = nullptr;
+ ScheduleEnd = nullptr;
+ FirstLoadStoreInRegion = nullptr;
+ LastLoadStoreInRegion = nullptr;
+
+ // Make a new scheduling region, i.e. all existing ScheduleData is not
+ // in the new region yet.
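+ // (This is effectively an O(1) clear: getScheduleData() below treats any
+ // entry stamped with an older region ID as absent, so no per-instruction
+ // cleanup pass is needed.)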
+ ++SchedulingRegionID;
+ }
+
+ ScheduleData *getScheduleData(Value *V) {
+ ScheduleData *SD = ScheduleDataMap[V];
+ if (SD && SD->SchedulingRegionID == SchedulingRegionID)
+ return SD;
+ return nullptr;
+ }
+
+ bool isInSchedulingRegion(ScheduleData *SD) {
+ return SD->SchedulingRegionID == SchedulingRegionID;
+ }
+
+ /// Marks an instruction as scheduled and puts all dependent ready
+ /// instructions into the ready-list.
+ template <typename ReadyListType>
+ void schedule(ScheduleData *SD, ReadyListType &ReadyList) {
+ SD->IsScheduled = true;
+ DEBUG(dbgs() << "SLP: schedule " << *SD << "\n");
+
+ ScheduleData *BundleMember = SD;
+ while (BundleMember) {
+ // Handle the def-use chain dependencies.
+ for (Use &U : BundleMember->Inst->operands()) {
+ ScheduleData *OpDef = getScheduleData(U.get());
+ if (OpDef && OpDef->hasValidDependencies() &&
+ OpDef->incrementUnscheduledDeps(-1) == 0) {
+ // There are no more unscheduled dependencies after decrementing,
+ // so we can put the dependent instruction into the ready list.
+ ScheduleData *DepBundle = OpDef->FirstInBundle;
+ assert(!DepBundle->IsScheduled &&
+ "already scheduled bundle gets ready");
+ ReadyList.insert(DepBundle);
+ DEBUG(dbgs() << "SLP: gets ready (def): " << *DepBundle << "\n");
+ }
+ }
+ // Handle the memory dependencies.
+ for (ScheduleData *MemoryDepSD : BundleMember->MemoryDependencies) {
+ if (MemoryDepSD->incrementUnscheduledDeps(-1) == 0) {
+ // There are no more unscheduled dependencies after decrementing,
+ // so we can put the dependent instruction into the ready list.
+ ScheduleData *DepBundle = MemoryDepSD->FirstInBundle;
+ assert(!DepBundle->IsScheduled &&
+ "already scheduled bundle gets ready");
+ ReadyList.insert(DepBundle);
+ DEBUG(dbgs() << "SLP: gets ready (mem): " << *DepBundle << "\n");
+ }
+ }
+ BundleMember = BundleMember->NextInBundle;
+ }
+ }
+
+ /// Put all instructions that are ready for scheduling into the ReadyList.
+ template <typename ReadyListType>
+ void initialFillReadyList(ReadyListType &ReadyList) {
+ for (auto *I = ScheduleStart; I != ScheduleEnd; I = I->getNextNode()) {
+ ScheduleData *SD = getScheduleData(I);
+ if (SD->isSchedulingEntity() && SD->isReady()) {
+ ReadyList.insert(SD);
+ DEBUG(dbgs() << "SLP: initially in ready list: " << *I << "\n");
+ }
+ }
+ }
+
+ /// Checks if a bundle of instructions can be scheduled, i.e. has no
+ /// cyclic dependencies. This is only a dry run; no instructions are
+ /// actually moved at this stage.
+ bool tryScheduleBundle(ArrayRef<Value *> VL, BoUpSLP *SLP);
+
+ /// Un-bundles a group of instructions.
+ void cancelScheduling(ArrayRef<Value *> VL);
+
+ /// Extends the scheduling region so that V is inside the region.
+ void extendSchedulingRegion(Value *V);
+
+ /// Initialize the ScheduleData structures for new instructions in the
+ /// scheduling region.
+ void initScheduleData(Instruction *FromI, Instruction *ToI,
+ ScheduleData *PrevLoadStore,
+ ScheduleData *NextLoadStore);
+
+ /// Updates the dependency information of a bundle and of all instructions/
+ /// bundles which depend on the original bundle.
+ void calculateDependencies(ScheduleData *SD, bool InsertInReadyList,
+ BoUpSLP *SLP);
+
+ /// Sets all instructions in the scheduling region to un-scheduled.
+ void resetSchedule();
+
+ BasicBlock *BB;
+
+ /// Simple memory allocation for ScheduleData.
+ std::vector<std::unique_ptr<ScheduleData[]>> ScheduleDataChunks;
+
+ /// The size of a ScheduleData array in ScheduleDataChunks.
+ int ChunkSize; + + /// The allocator position in the current chunk, which is the last entry + /// of ScheduleDataChunks. + int ChunkPos; + + /// Attaches ScheduleData to Instruction. + /// Note that the mapping survives during all vectorization iterations, i.e. + /// ScheduleData structures are recycled. + DenseMap<Value *, ScheduleData *> ScheduleDataMap; + + struct ReadyList : SmallVector<ScheduleData *, 8> { + void insert(ScheduleData *SD) { push_back(SD); } + }; + + /// The ready-list for scheduling (only used for the dry-run). + ReadyList ReadyInsts; + + /// The first instruction of the scheduling region. + Instruction *ScheduleStart; + + /// The first instruction _after_ the scheduling region. + Instruction *ScheduleEnd; + + /// The first memory accessing instruction in the scheduling region + /// (can be null). + ScheduleData *FirstLoadStoreInRegion; + + /// The last memory accessing instruction in the scheduling region + /// (can be null). + ScheduleData *LastLoadStoreInRegion; + + /// The ID of the scheduling region. For a new vectorization iteration this + /// is incremented which "removes" all ScheduleData from the region. + int SchedulingRegionID; + }; + + /// Attaches the BlockScheduling structures to basic blocks. + MapVector<BasicBlock *, std::unique_ptr<BlockScheduling>> BlocksSchedules; + + /// Performs the "real" scheduling. Done before vectorization is actually + /// performed in a basic block. + void scheduleBlock(BlockScheduling *BS); /// List of users to ignore during scheduling and that don't need extracting. ArrayRef<Value *> UserIgnoreList; + // Number of load-bundles, which contain consecutive loads. + int NumLoadsWantToKeepOrder; + + // Number of load-bundles of size 2, which are consecutive loads if reversed. + int NumLoadsWantToChangeOrder; + // Analysis and block reference. Function *F; ScalarEvolution *SE; @@ -589,6 +941,13 @@ private: IRBuilder<> Builder; }; +#ifndef NDEBUG +raw_ostream &operator<<(raw_ostream &os, const BoUpSLP::ScheduleData &SD) { + SD.dump(os); + return os; +} +#endif + void BoUpSLP::buildTree(ArrayRef<Value *> Roots, ArrayRef<Value *> UserIgnoreLst) { deleteTree(); @@ -612,18 +971,27 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots, for (User *U : Scalar->users()) { DEBUG(dbgs() << "SLP: Checking user:" << *U << ".\n"); - // Skip in-tree scalars that become vectors. - if (ScalarToTreeEntry.count(U)) { - DEBUG(dbgs() << "SLP: \tInternal user will be removed:" << - *U << ".\n"); - int Idx = ScalarToTreeEntry[U]; (void) Idx; - assert(!VectorizableTree[Idx].NeedToGather && "Bad state"); - continue; - } Instruction *UserInst = dyn_cast<Instruction>(U); if (!UserInst) continue; + // Skip in-tree scalars that become vectors + if (ScalarToTreeEntry.count(U)) { + int Idx = ScalarToTreeEntry[U]; + TreeEntry *UseEntry = &VectorizableTree[Idx]; + Value *UseScalar = UseEntry->Scalars[0]; + // Some in-tree scalars will remain as scalar in vectorized + // instructions. If that is the case, the one in Lane 0 will + // be used. + if (UseScalar != U || + !InTreeUserNeedToExtract(Scalar, UserInst, TLI)) { + DEBUG(dbgs() << "SLP: \tInternal user will be removed:" << *U + << ".\n"); + assert(!VectorizableTree[Idx].NeedToGather && "Bad state"); + continue; + } + } + // Ignore users in the user ignore list. 
if (std::find(UserIgnoreList.begin(), UserIgnoreList.end(), UserInst) != UserIgnoreList.end()) @@ -683,6 +1051,16 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { // We now know that this is a vector of instructions of the same type from // the same block. + // Don't vectorize ephemeral values. + for (unsigned i = 0, e = VL.size(); i != e; ++i) { + if (EphValues.count(VL[i])) { + DEBUG(dbgs() << "SLP: The instruction (" << *VL[i] << + ") is ephemeral.\n"); + newTreeEntry(VL, false); + return; + } + } + // Check if this is a duplicate of another entry. if (ScalarToTreeEntry.count(VL[0])) { int Idx = ScalarToTreeEntry[VL[0]]; @@ -709,11 +1087,11 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { } } - // If any of the scalars appears in the table OR it is marked as a value that - // needs to stat scalar then we need to gather the scalars. + // If any of the scalars is marked as a value that needs to stay scalar then + // we need to gather the scalars. for (unsigned i = 0, e = VL.size(); i != e; ++i) { - if (ScalarToTreeEntry.count(VL[i]) || MustGather.count(VL[i])) { - DEBUG(dbgs() << "SLP: Gathering due to gathered scalar. \n"); + if (MustGather.count(VL[i])) { + DEBUG(dbgs() << "SLP: Gathering due to gathered scalar.\n"); newTreeEntry(VL, false); return; } @@ -722,69 +1100,16 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { // Check that all of the users of the scalars that we want to vectorize are // schedulable. Instruction *VL0 = cast<Instruction>(VL[0]); - int MyLastIndex = getLastIndex(VL); BasicBlock *BB = cast<Instruction>(VL0)->getParent(); - for (unsigned i = 0, e = VL.size(); i != e; ++i) { - Instruction *Scalar = cast<Instruction>(VL[i]); - DEBUG(dbgs() << "SLP: Checking users of " << *Scalar << ". \n"); - for (User *U : Scalar->users()) { - DEBUG(dbgs() << "SLP: \tUser " << *U << ". \n"); - Instruction *UI = dyn_cast<Instruction>(U); - if (!UI) { - DEBUG(dbgs() << "SLP: Gathering due unknown user. \n"); - newTreeEntry(VL, false); - return; - } - - // We don't care if the user is in a different basic block. - BasicBlock *UserBlock = UI->getParent(); - if (UserBlock != BB) { - DEBUG(dbgs() << "SLP: User from a different basic block " - << *UI << ". \n"); - continue; - } - - // If this is a PHINode within this basic block then we can place the - // extract wherever we want. - if (isa<PHINode>(*UI)) { - DEBUG(dbgs() << "SLP: \tWe can schedule PHIs:" << *UI << ". \n"); - continue; - } - - // Check if this is a safe in-tree user. - if (ScalarToTreeEntry.count(UI)) { - int Idx = ScalarToTreeEntry[UI]; - int VecLocation = VectorizableTree[Idx].LastScalarIndex; - if (VecLocation <= MyLastIndex) { - DEBUG(dbgs() << "SLP: Gathering due to unschedulable vector. \n"); - newTreeEntry(VL, false); - return; - } - DEBUG(dbgs() << "SLP: In-tree user (" << *UI << ") at #" << - VecLocation << " vector value (" << *Scalar << ") at #" - << MyLastIndex << ".\n"); - continue; - } - - // Ignore users in the user ignore list. - if (std::find(UserIgnoreList.begin(), UserIgnoreList.end(), UI) != - UserIgnoreList.end()) - continue; - - // Make sure that we can schedule this unknown user. - BlockNumbering &BN = getBlockNumbering(BB); - int UserIndex = BN.getIndex(UI); - if (UserIndex < MyLastIndex) { - - DEBUG(dbgs() << "SLP: Can't schedule extractelement for " - << *UI << ". \n"); - newTreeEntry(VL, false); - return; - } - } + if (!DT->isReachableFromEntry(BB)) { + // Don't go into unreachable blocks. 
They may contain instructions with
+ dependency cycles which confuse the final scheduling.
+ DEBUG(dbgs() << "SLP: bundle in unreachable block.\n");
+ newTreeEntry(VL, false);
+ return;
 }

 // Check that every instruction appears once in this bundle.
 for (unsigned i = 0, e = VL.size(); i < e; ++i)
 for (unsigned j = i+1; j < e; ++j)
@@ -794,38 +1119,19 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
 return;
 }

- // Check that instructions in this bundle don't reference other instructions.
- // The runtime of this check is O(N * N-1 * uses(N)) and a typical N is 4.
- for (unsigned i = 0, e = VL.size(); i < e; ++i) {
- for (User *U : VL[i]->users()) {
- for (unsigned j = 0; j < e; ++j) {
- if (i != j && U == VL[j]) {
- DEBUG(dbgs() << "SLP: Intra-bundle dependencies!" << *U << ". \n");
- newTreeEntry(VL, false);
- return;
- }
- }
- }
+ auto &BSRef = BlocksSchedules[BB];
+ if (!BSRef) {
+ BSRef = llvm::make_unique<BlockScheduling>(BB);
 }
+ BlockScheduling &BS = *BSRef.get();

- DEBUG(dbgs() << "SLP: We are able to schedule this bundle.\n");
-
- // Check if it is safe to sink the loads or the stores.
- if (Opcode == Instruction::Load || Opcode == Instruction::Store) {
- Instruction *Last = getLastInstruction(VL);
-
- for (unsigned i = 0, e = VL.size(); i < e; ++i) {
- if (VL[i] == Last)
- continue;
- Value *Barrier = getSinkBarrier(cast<Instruction>(VL[i]), Last);
- if (Barrier) {
- DEBUG(dbgs() << "SLP: Can't sink " << *VL[i] << "\n down to " << *Last
- << "\n because of " << *Barrier << ". Gathering.\n");
- newTreeEntry(VL, false);
- return;
- }
- }
+ if (!BS.tryScheduleBundle(VL, this)) {
+ DEBUG(dbgs() << "SLP: We are not able to schedule this bundle!\n");
+ BS.cancelScheduling(VL);
+ newTreeEntry(VL, false);
+ return;
 }
+ DEBUG(dbgs() << "SLP: We are able to schedule this bundle.\n");

 switch (Opcode) {
 case Instruction::PHI: {
@@ -838,6 +1144,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
 cast<PHINode>(VL[j])->getIncomingValueForBlock(PH->getIncomingBlock(i)));
 if (Term) {
 DEBUG(dbgs() << "SLP: Need to swizzle PHINodes (TerminatorInst use).\n");
+ BS.cancelScheduling(VL);
 newTreeEntry(VL, false);
 return;
 }
@@ -861,6 +1168,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
 bool Reuse = CanReuseExtract(VL);
 if (Reuse) {
 DEBUG(dbgs() << "SLP: Reusing extract sequence.\n");
+ } else {
+ BS.cancelScheduling(VL);
 }
 newTreeEntry(VL, Reuse);
 return;
@@ -869,12 +1178,23 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
 // Check if the loads are consecutive or if we need to swizzle them.
for (unsigned i = 0, e = VL.size() - 1; i < e; ++i) {
 LoadInst *L = cast<LoadInst>(VL[i]);
- if (!L->isSimple() || !isConsecutiveAccess(VL[i], VL[i + 1])) {
+ if (!L->isSimple()) {
+ BS.cancelScheduling(VL);
 newTreeEntry(VL, false);
- DEBUG(dbgs() << "SLP: Need to swizzle loads.\n");
+ DEBUG(dbgs() << "SLP: Gathering non-simple loads.\n");
+ return;
+ }
+ if (!isConsecutiveAccess(VL[i], VL[i + 1])) {
+ if (VL.size() == 2 && isConsecutiveAccess(VL[1], VL[0])) {
+ ++NumLoadsWantToChangeOrder;
+ }
+ BS.cancelScheduling(VL);
+ newTreeEntry(VL, false);
+ DEBUG(dbgs() << "SLP: Gathering non-consecutive loads.\n");
 return;
 }
 }
+ ++NumLoadsWantToKeepOrder;
 newTreeEntry(VL, true);
 DEBUG(dbgs() << "SLP: added a vector of loads.\n");
 return;
@@ -895,6 +1215,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
 for (unsigned i = 0; i < VL.size(); ++i) {
 Type *Ty = cast<Instruction>(VL[i])->getOperand(0)->getType();
 if (Ty != SrcTy || Ty->isAggregateType() || Ty->isVectorTy()) {
+ BS.cancelScheduling(VL);
 newTreeEntry(VL, false);
 DEBUG(dbgs() << "SLP: Gathering casts with different src types.\n");
 return;
@@ -922,6 +1243,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
 CmpInst *Cmp = cast<CmpInst>(VL[i]);
 if (Cmp->getPredicate() != P0 ||
 Cmp->getOperand(0)->getType() != ComparedTy) {
+ BS.cancelScheduling(VL);
 newTreeEntry(VL, false);
 DEBUG(dbgs() << "SLP: Gathering cmp with different predicate.\n");
 return;
@@ -968,20 +1290,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
 if (isa<BinaryOperator>(VL0) && VL0->isCommutative()) {
 ValueList Left, Right;
 reorderInputsAccordingToOpcode(VL, Left, Right);
- BasicBlock *LeftBB = getSameBlock(Left);
- BasicBlock *RightBB = getSameBlock(Right);
- // If we have common uses on separate paths in the tree make sure we
- // process the one with greater common depth first.
- // We can use block numbering to determine the subtree traversal as
- // earler user has to come in between the common use and the later user.
- if (LeftBB && RightBB && LeftBB == RightBB &&
- getLastIndex(Right) > getLastIndex(Left)) {
- buildTree_rec(Right, Depth + 1);
- buildTree_rec(Left, Depth + 1);
- } else {
- buildTree_rec(Left, Depth + 1);
- buildTree_rec(Right, Depth + 1);
- }
+ buildTree_rec(Left, Depth + 1);
+ buildTree_rec(Right, Depth + 1);
 return;
 }

@@ -1000,6 +1310,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
 for (unsigned j = 0; j < VL.size(); ++j) {
 if (cast<Instruction>(VL[j])->getNumOperands() != 2) {
 DEBUG(dbgs() << "SLP: not-vectorizable GEP (nested indexes).\n");
+ BS.cancelScheduling(VL);
 newTreeEntry(VL, false);
 return;
 }
@@ -1012,6 +1323,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
 Type *CurTy = cast<Instruction>(VL[j])->getOperand(0)->getType();
 if (Ty0 != CurTy) {
 DEBUG(dbgs() << "SLP: not-vectorizable GEP (different types).\n");
+ BS.cancelScheduling(VL);
 newTreeEntry(VL, false);
 return;
 }
@@ -1023,6 +1335,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
 if (!isa<ConstantInt>(Op)) {
 DEBUG(
 dbgs() << "SLP: not-vectorizable GEP (non-constant indexes).\n");
+ BS.cancelScheduling(VL);
 newTreeEntry(VL, false);
 return;
 }
@@ -1044,6 +1357,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
 // Check if the stores are consecutive or if we need to swizzle them.
for (unsigned i = 0, e = VL.size() - 1; i < e; ++i)
 if (!isConsecutiveAccess(VL[i], VL[i + 1])) {
+ BS.cancelScheduling(VL);
 newTreeEntry(VL, false);
 DEBUG(dbgs() << "SLP: Non-consecutive store.\n");
 return;
@@ -1056,8 +1370,6 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
 for (unsigned j = 0; j < VL.size(); ++j)
 Operands.push_back(cast<Instruction>(VL[j])->getOperand(0));

- // We can ignore these values because we are sinking them down.
- MemBarrierIgnoreList.insert(VL.begin(), VL.end());
 buildTree_rec(Operands, Depth + 1);
 return;
 }
@@ -1068,6 +1380,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
 // represented by an intrinsic call
 Intrinsic::ID ID = getIntrinsicIDForCall(CI, TLI);
 if (!isTriviallyVectorizable(ID)) {
+ BS.cancelScheduling(VL);
 newTreeEntry(VL, false);
 DEBUG(dbgs() << "SLP: Non-vectorizable call.\n");
 return;
@@ -1080,6 +1393,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
 CallInst *CI2 = dyn_cast<CallInst>(VL[i]);
 if (!CI2 || CI2->getCalledFunction() != Int ||
 getIntrinsicIDForCall(CI2, TLI) != ID) {
+ BS.cancelScheduling(VL);
 newTreeEntry(VL, false);
 DEBUG(dbgs() << "SLP: mismatched calls:" << *CI << "!=" << *VL[i]
 << "\n");
@@ -1090,6 +1404,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
 if (hasVectorInstrinsicScalarOpd(ID, 1)) {
 Value *A1J = CI2->getArgOperand(1);
 if (A1I != A1J) {
+ BS.cancelScheduling(VL);
 newTreeEntry(VL, false);
 DEBUG(dbgs() << "SLP: mismatched arguments in call:" << *CI
 << " argument "<< A1I<<"!=" << A1J
@@ -1115,6 +1430,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
 // If this is not an alternate sequence of opcodes like add-sub
 // then do not vectorize this instruction.
 if (!isAltShuffle) {
+ BS.cancelScheduling(VL);
 newTreeEntry(VL, false);
 DEBUG(dbgs() << "SLP: ShuffleVector are not vectorized.\n");
 return;
@@ -1132,6 +1448,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
 return;
 }
 default:
+ BS.cancelScheduling(VL);
 newTreeEntry(VL, false);
 DEBUG(dbgs() << "SLP: Gathering unknown instruction.\n");
 return;
@@ -1234,6 +1551,10 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
 TargetTransformInfo::OK_AnyValue;
 TargetTransformInfo::OperandValueKind Op2VK =
 TargetTransformInfo::OK_UniformConstantValue;
+ TargetTransformInfo::OperandValueProperties Op1VP =
+ TargetTransformInfo::OP_None;
+ TargetTransformInfo::OperandValueProperties Op2VP =
+ TargetTransformInfo::OP_None;

 // If all operands are exactly the same ConstantInt then set the
 // operand kind to OK_UniformConstantValue.
@@ -1255,11 +1576,17 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
 CInt != cast<ConstantInt>(I->getOperand(1)))
 Op2VK = TargetTransformInfo::OK_NonUniformConstantValue;
 }
+ // FIXME: Currently the cost-model modification for division by a
+ // power of 2 is handled only for X86. Add support for other targets.
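+ // (The OP_PowerOf2 hint matters because a signed divide by a power-of-two
+ // constant can be lowered to an add-and-shift sequence, e.g. x / 8 becomes
+ // (x + ((x >> 31) & 7)) >> 3 on a 32-bit lane, which is far cheaper than a
+ // hardware division. The expression is illustrative only, not taken from
+ // the target lowering code.)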
+ if (Op2VK == TargetTransformInfo::OK_UniformConstantValue && CInt && + CInt->getValue().isPowerOf2()) + Op2VP = TargetTransformInfo::OP_PowerOf2; - ScalarCost = - VecTy->getNumElements() * - TTI->getArithmeticInstrCost(Opcode, ScalarTy, Op1VK, Op2VK); - VecCost = TTI->getArithmeticInstrCost(Opcode, VecTy, Op1VK, Op2VK); + ScalarCost = VecTy->getNumElements() * + TTI->getArithmeticInstrCost(Opcode, ScalarTy, Op1VK, Op2VK, + Op1VP, Op2VP); + VecCost = TTI->getArithmeticInstrCost(Opcode, VecTy, Op1VK, Op2VK, + Op1VP, Op2VP); } return VecCost - ScalarCost; } @@ -1364,6 +1691,68 @@ bool BoUpSLP::isFullyVectorizableTinyTree() { return true; } +int BoUpSLP::getSpillCost() { + // Walk from the bottom of the tree to the top, tracking which values are + // live. When we see a call instruction that is not part of our tree, + // query TTI to see if there is a cost to keeping values live over it + // (for example, if spills and fills are required). + unsigned BundleWidth = VectorizableTree.front().Scalars.size(); + int Cost = 0; + + SmallPtrSet<Instruction*, 4> LiveValues; + Instruction *PrevInst = nullptr; + + for (unsigned N = 0; N < VectorizableTree.size(); ++N) { + Instruction *Inst = dyn_cast<Instruction>(VectorizableTree[N].Scalars[0]); + if (!Inst) + continue; + + if (!PrevInst) { + PrevInst = Inst; + continue; + } + + DEBUG( + dbgs() << "SLP: #LV: " << LiveValues.size(); + for (auto *X : LiveValues) + dbgs() << " " << X->getName(); + dbgs() << ", Looking at "; + Inst->dump(); + ); + + // Update LiveValues. + LiveValues.erase(PrevInst); + for (auto &J : PrevInst->operands()) { + if (isa<Instruction>(&*J) && ScalarToTreeEntry.count(&*J)) + LiveValues.insert(cast<Instruction>(&*J)); + } + + // Now find the sequence of instructions between PrevInst and Inst. + BasicBlock::reverse_iterator InstIt(Inst), PrevInstIt(PrevInst); + --PrevInstIt; + while (InstIt != PrevInstIt) { + if (PrevInstIt == PrevInst->getParent()->rend()) { + PrevInstIt = Inst->getParent()->rbegin(); + continue; + } + + if (isa<CallInst>(&*PrevInstIt) && &*PrevInstIt != PrevInst) { + SmallVector<Type*, 4> V; + for (auto *II : LiveValues) + V.push_back(VectorType::get(II->getType(), BundleWidth)); + Cost += TTI->getCostOfKeepingLiveOverCall(V); + } + + ++PrevInstIt; + } + + PrevInst = Inst; + } + + DEBUG(dbgs() << "SLP: SpillCost=" << Cost << "\n"); + return Cost; +} + int BoUpSLP::getTreeCost() { int Cost = 0; DEBUG(dbgs() << "SLP: Calculating cost for tree of size " << @@ -1391,7 +1780,13 @@ int BoUpSLP::getTreeCost() { for (UserList::iterator I = ExternalUses.begin(), E = ExternalUses.end(); I != E; ++I) { // We only add extract cost once for the same scalar. - if (!ExtractCostCalculated.insert(I->Scalar)) + if (!ExtractCostCalculated.insert(I->Scalar).second) + continue; + + // Uses by ephemeral values are free (because the ephemeral value will be + // removed prior to code generation, and so the extraction will be + // removed as well). 
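+ // ("Ephemeral" means the value is live only as an argument to an
+ // @llvm.assume intrinsic; EphValues was seeded from
+ // CodeMetrics::collectEphemeralValues in the constructor.)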
+ if (EphValues.count(I->User)) continue; VectorType *VecTy = VectorType::get(I->Scalar->getType(), BundleWidth); @@ -1399,6 +1794,8 @@ int BoUpSLP::getTreeCost() { I->Lane); } + Cost += getSpillCost(); + DEBUG(dbgs() << "SLP: Total Cost " << Cost + ExtractCost<< ".\n"); return Cost + ExtractCost; } @@ -1420,14 +1817,6 @@ int BoUpSLP::getGatherCost(ArrayRef<Value *> VL) { return getGatherCost(VecTy); } -AliasAnalysis::Location BoUpSLP::getLocation(Instruction *I) { - if (StoreInst *SI = dyn_cast<StoreInst>(I)) - return AA->getLocation(SI); - if (LoadInst *LI = dyn_cast<LoadInst>(I)) - return AA->getLocation(LI); - return AliasAnalysis::Location(); -} - Value *BoUpSLP::getPointerOperand(Value *I) { if (LoadInst *LI = dyn_cast<LoadInst>(I)) return LI->getPointerOperand(); @@ -1485,59 +1874,9 @@ bool BoUpSLP::isConsecutiveAccess(Value *A, Value *B) { return X == PtrSCEVB; } -Value *BoUpSLP::getSinkBarrier(Instruction *Src, Instruction *Dst) { - assert(Src->getParent() == Dst->getParent() && "Not the same BB"); - BasicBlock::iterator I = Src, E = Dst; - /// Scan all of the instruction from SRC to DST and check if - /// the source may alias. - for (++I; I != E; ++I) { - // Ignore store instructions that are marked as 'ignore'. - if (MemBarrierIgnoreList.count(I)) - continue; - if (Src->mayWriteToMemory()) /* Write */ { - if (!I->mayReadOrWriteMemory()) - continue; - } else /* Read */ { - if (!I->mayWriteToMemory()) - continue; - } - AliasAnalysis::Location A = getLocation(&*I); - AliasAnalysis::Location B = getLocation(Src); - - if (!A.Ptr || !B.Ptr || AA->alias(A, B)) - return I; - } - return nullptr; -} - -int BoUpSLP::getLastIndex(ArrayRef<Value *> VL) { - BasicBlock *BB = cast<Instruction>(VL[0])->getParent(); - assert(BB == getSameBlock(VL) && "Invalid block"); - BlockNumbering &BN = getBlockNumbering(BB); - - int MaxIdx = BN.getIndex(BB->getFirstNonPHI()); - for (unsigned i = 0, e = VL.size(); i < e; ++i) - MaxIdx = std::max(MaxIdx, BN.getIndex(cast<Instruction>(VL[i]))); - return MaxIdx; -} - -Instruction *BoUpSLP::getLastInstruction(ArrayRef<Value *> VL) { - BasicBlock *BB = cast<Instruction>(VL[0])->getParent(); - assert(BB == getSameBlock(VL) && "Invalid block"); - BlockNumbering &BN = getBlockNumbering(BB); - - int MaxIdx = BN.getIndex(cast<Instruction>(VL[0])); - for (unsigned i = 1, e = VL.size(); i < e; ++i) - MaxIdx = std::max(MaxIdx, BN.getIndex(cast<Instruction>(VL[i]))); - Instruction *I = BN.getInstruction(MaxIdx); - assert(I && "bad location"); - return I; -} - void BoUpSLP::setInsertPointAfterBundle(ArrayRef<Value *> VL) { Instruction *VL0 = cast<Instruction>(VL[0]); - Instruction *LastInst = getLastInstruction(VL); - BasicBlock::iterator NextInst = LastInst; + BasicBlock::iterator NextInst = VL0; ++NextInst; Builder.SetInsertPoint(VL0->getParent(), NextInst); Builder.SetCurrentDebugLocation(VL0->getDebugLoc()); @@ -1620,6 +1959,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { setInsertPointAfterBundle(E->Scalars); return Gather(E->Scalars, VecTy); } + unsigned Opcode = getSameOpcode(E->Scalars); switch (Opcode) { @@ -1638,7 +1978,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { ValueList Operands; BasicBlock *IBB = PH->getIncomingBlock(i); - if (!VisitedBBs.insert(IBB)) { + if (!VisitedBBs.insert(IBB).second) { NewPhi->addIncoming(NewPhi->getIncomingValueForBlock(IBB), IBB); continue; } @@ -1693,6 +2033,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { CastInst *CI = dyn_cast<CastInst>(VL0); Value *V = Builder.CreateCast(CI->getOpcode(), InVec, VecTy); 
E->VectorizedValue = V; + ++NumVectorInstructions; return V; } case Instruction::FCmp: @@ -1719,6 +2060,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { V = Builder.CreateICmp(P0, L, R); E->VectorizedValue = V; + ++NumVectorInstructions; return V; } case Instruction::Select: { @@ -1740,6 +2082,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Value *V = Builder.CreateSelect(Cond, True, False); E->VectorizedValue = V; + ++NumVectorInstructions; return V; } case Instruction::Add: @@ -1784,6 +2127,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { BinaryOperator *BinOp = cast<BinaryOperator>(VL0); Value *V = Builder.CreateBinOp(BinOp->getOpcode(), LHS, RHS); E->VectorizedValue = V; + propagateIRFlags(E->VectorizedValue, E->Scalars); + ++NumVectorInstructions; if (Instruction *I = dyn_cast<Instruction>(V)) return propagateMetadata(I, E->Scalars); @@ -1796,16 +2141,26 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { setInsertPointAfterBundle(E->Scalars); LoadInst *LI = cast<LoadInst>(VL0); + Type *ScalarLoadTy = LI->getType(); unsigned AS = LI->getPointerAddressSpace(); Value *VecPtr = Builder.CreateBitCast(LI->getPointerOperand(), VecTy->getPointerTo(AS)); + + // The pointer operand uses an in-tree scalar so we add the new BitCast to + // ExternalUses list to make sure that an extract will be generated in the + // future. + if (ScalarToTreeEntry.count(LI->getPointerOperand())) + ExternalUses.push_back( + ExternalUser(LI->getPointerOperand(), cast<User>(VecPtr), 0)); + unsigned Alignment = LI->getAlignment(); LI = Builder.CreateLoad(VecPtr); if (!Alignment) - Alignment = DL->getABITypeAlignment(LI->getPointerOperand()->getType()); + Alignment = DL->getABITypeAlignment(ScalarLoadTy); LI->setAlignment(Alignment); E->VectorizedValue = LI; + ++NumVectorInstructions; return propagateMetadata(LI, E->Scalars); } case Instruction::Store: { @@ -1823,10 +2178,19 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Value *VecPtr = Builder.CreateBitCast(SI->getPointerOperand(), VecTy->getPointerTo(AS)); StoreInst *S = Builder.CreateStore(VecValue, VecPtr); + + // The pointer operand uses an in-tree scalar so we add the new BitCast to + // ExternalUses list to make sure that an extract will be generated in the + // future. + if (ScalarToTreeEntry.count(SI->getPointerOperand())) + ExternalUses.push_back( + ExternalUser(SI->getPointerOperand(), cast<User>(VecPtr), 0)); + if (!Alignment) - Alignment = DL->getABITypeAlignment(SI->getPointerOperand()->getType()); + Alignment = DL->getABITypeAlignment(SI->getValueOperand()->getType()); S->setAlignment(Alignment); E->VectorizedValue = S; + ++NumVectorInstructions; return propagateMetadata(S, E->Scalars); } case Instruction::GetElementPtr: { @@ -1851,6 +2215,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Value *V = Builder.CreateGEP(Op0, OpVecs); E->VectorizedValue = V; + ++NumVectorInstructions; if (Instruction *I = dyn_cast<Instruction>(V)) return propagateMetadata(I, E->Scalars); @@ -1862,6 +2227,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { setInsertPointAfterBundle(E->Scalars); Function *FI; Intrinsic::ID IID = Intrinsic::not_intrinsic; + Value *ScalarArg = nullptr; if (CI && (FI = CI->getCalledFunction())) { IID = (Intrinsic::ID) FI->getIntrinsicID(); } @@ -1872,6 +2238,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { // a scalar. This argument should not be vectorized. 
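 // (For example, the exponent operand of @llvm.powi.* applies to every
 // lane at once, so the value from lane 0 is reused for the whole vector
 // call rather than being widened.)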
if (hasVectorInstrinsicScalarOpd(IID, 1) && j == 1) { CallInst *CEI = cast<CallInst>(E->Scalars[0]); + ScalarArg = CEI->getArgOperand(j); OpVecs.push_back(CEI->getArgOperand(j)); continue; } @@ -1890,7 +2257,15 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Type *Tys[] = { VectorType::get(CI->getType(), E->Scalars.size()) }; Function *CF = Intrinsic::getDeclaration(M, ID, Tys); Value *V = Builder.CreateCall(CF, OpVecs); + + // The scalar argument uses an in-tree scalar so we add the new vectorized + // call to ExternalUses list to make sure that an extract will be + // generated in the future. + if (ScalarArg && ScalarToTreeEntry.count(ScalarArg)) + ExternalUses.push_back(ExternalUser(ScalarArg, cast<User>(V), 0)); + E->VectorizedValue = V; + ++NumVectorInstructions; return V; } case Instruction::ShuffleVector: { @@ -1916,21 +2291,29 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { BinaryOperator *BinOp1 = cast<BinaryOperator>(VL1); Value *V1 = Builder.CreateBinOp(BinOp1->getOpcode(), LHS, RHS); - // Create appropriate shuffle to take alternative operations from - // the vector. - std::vector<Constant *> Mask(E->Scalars.size()); + // Create shuffle to take alternate operations from the vector. + // Also, gather up odd and even scalar ops to propagate IR flags to + // each vector operation. + ValueList OddScalars, EvenScalars; unsigned e = E->Scalars.size(); + SmallVector<Constant *, 8> Mask(e); for (unsigned i = 0; i < e; ++i) { - if (i & 1) + if (i & 1) { Mask[i] = Builder.getInt32(e + i); - else + OddScalars.push_back(E->Scalars[i]); + } else { Mask[i] = Builder.getInt32(i); + EvenScalars.push_back(E->Scalars[i]); + } } Value *ShuffleMask = ConstantVector::get(Mask); + propagateIRFlags(V0, EvenScalars); + propagateIRFlags(V1, OddScalars); Value *V = Builder.CreateShuffleVector(V0, V1, ShuffleMask); E->VectorizedValue = V; + ++NumVectorInstructions; if (Instruction *I = dyn_cast<Instruction>(V)) return propagateMetadata(I, E->Scalars); @@ -1943,6 +2326,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } Value *BoUpSLP::vectorizeTree() { + + // All blocks must be scheduled before any instructions are inserted. + for (auto &BSIter : BlocksSchedules) { + scheduleBlock(BSIter.second.get()); + } + Builder.SetInsertPoint(F->getEntryBlock().begin()); vectorizeTree(&VectorizableTree[0]); @@ -2027,13 +2416,10 @@ Value *BoUpSLP::vectorizeTree() { Scalar->replaceAllUsesWith(Undef); } DEBUG(dbgs() << "SLP: \tErasing scalar:" << *Scalar << ".\n"); - cast<Instruction>(Scalar)->eraseFromParent(); + eraseInstruction(cast<Instruction>(Scalar)); } } - for (auto &BN : BlocksNumbers) - BN.second.forget(); - Builder.ClearInsertionPoint(); return VectorizableTree[0].VectorizedValue; @@ -2112,7 +2498,7 @@ void BoUpSLP::optimizeGatherSequence() { if (In->isIdenticalTo(*v) && DT->dominates((*v)->getParent(), In->getParent())) { In->replaceAllUsesWith(*v); - In->eraseFromParent(); + eraseInstruction(In); In = nullptr; break; } @@ -2127,6 +2513,354 @@ void BoUpSLP::optimizeGatherSequence() { GatherSeq.clear(); } +// Groups the instructions to a bundle (which is then a single scheduling entity) +// and schedules instructions until the bundle gets ready. +bool BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, + BoUpSLP *SLP) { + if (isa<PHINode>(VL[0])) + return true; + + // Initialize the instruction bundle. 
+ Instruction *OldScheduleEnd = ScheduleEnd;
+ ScheduleData *PrevInBundle = nullptr;
+ ScheduleData *Bundle = nullptr;
+ bool ReSchedule = false;
+ DEBUG(dbgs() << "SLP: bundle: " << *VL[0] << "\n");
+ for (Value *V : VL) {
+ extendSchedulingRegion(V);
+ ScheduleData *BundleMember = getScheduleData(V);
+ assert(BundleMember &&
+ "no ScheduleData for bundle member (maybe not in same basic block)");
+ if (BundleMember->IsScheduled) {
+ // A bundle member was scheduled as a single instruction before and now
+ // needs to be scheduled as part of the bundle. We just get rid of the
+ // existing schedule.
+ DEBUG(dbgs() << "SLP: reset schedule because " << *BundleMember
+ << " was already scheduled\n");
+ ReSchedule = true;
+ }
+ assert(BundleMember->isSchedulingEntity() &&
+ "bundle member already part of other bundle");
+ if (PrevInBundle) {
+ PrevInBundle->NextInBundle = BundleMember;
+ } else {
+ Bundle = BundleMember;
+ }
+ BundleMember->UnscheduledDepsInBundle = 0;
+ Bundle->UnscheduledDepsInBundle += BundleMember->UnscheduledDeps;
+
+ // Group the instructions into a bundle.
+ BundleMember->FirstInBundle = Bundle;
+ PrevInBundle = BundleMember;
+ }
+ if (ScheduleEnd != OldScheduleEnd) {
+ // The scheduling region got new instructions at the lower end (or it is a
+ // new region for the first bundle). This makes it necessary to
+ // recalculate all dependencies.
+ // It is seldom that this needs to be done a second time after adding the
+ // initial bundle to the region.
+ for (auto *I = ScheduleStart; I != ScheduleEnd; I = I->getNextNode()) {
+ ScheduleData *SD = getScheduleData(I);
+ SD->clearDependencies();
+ }
+ ReSchedule = true;
+ }
+ if (ReSchedule) {
+ resetSchedule();
+ initialFillReadyList(ReadyInsts);
+ }
+
+ DEBUG(dbgs() << "SLP: try schedule bundle " << *Bundle << " in block "
+ << BB->getName() << "\n");
+
+ calculateDependencies(Bundle, true, SLP);
+
+ // Now try to schedule the new bundle. As soon as the bundle is "ready" it
+ // means that there are no cyclic dependencies and we can schedule it.
+ // Note that it's important that we don't "schedule" the bundle yet (see
+ // cancelScheduling).
+ while (!Bundle->isReady() && !ReadyInsts.empty()) {
+
+ ScheduleData *pickedSD = ReadyInsts.back();
+ ReadyInsts.pop_back();
+
+ if (pickedSD->isSchedulingEntity() && pickedSD->isReady()) {
+ schedule(pickedSD, ReadyInsts);
+ }
+ }
+ return Bundle->isReady();
+}
+
+void BoUpSLP::BlockScheduling::cancelScheduling(ArrayRef<Value *> VL) {
+ if (isa<PHINode>(VL[0]))
+ return;
+
+ ScheduleData *Bundle = getScheduleData(VL[0]);
+ DEBUG(dbgs() << "SLP: cancel scheduling of " << *Bundle << "\n");
+ assert(!Bundle->IsScheduled &&
+ "Can't cancel bundle which is already scheduled");
+ assert(Bundle->isSchedulingEntity() && Bundle->isPartOfBundle() &&
+ "tried to unbundle something which is not a bundle");
+
+ // Un-bundle: make single instructions out of the bundle.
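+ // (Each member becomes its own scheduling entity again: FirstInBundle
+ // points back at itself, the bundle links are cut, and a member re-enters
+ // the ready list as soon as its own dependencies are satisfied.)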
+ ScheduleData *BundleMember = Bundle; + while (BundleMember) { + assert(BundleMember->FirstInBundle == Bundle && "corrupt bundle links"); + BundleMember->FirstInBundle = BundleMember; + ScheduleData *Next = BundleMember->NextInBundle; + BundleMember->NextInBundle = nullptr; + BundleMember->UnscheduledDepsInBundle = BundleMember->UnscheduledDeps; + if (BundleMember->UnscheduledDepsInBundle == 0) { + ReadyInsts.insert(BundleMember); + } + BundleMember = Next; + } +} + +void BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V) { + if (getScheduleData(V)) + return; + Instruction *I = dyn_cast<Instruction>(V); + assert(I && "bundle member must be an instruction"); + assert(!isa<PHINode>(I) && "phi nodes don't need to be scheduled"); + if (!ScheduleStart) { + // It's the first instruction in the new region. + initScheduleData(I, I->getNextNode(), nullptr, nullptr); + ScheduleStart = I; + ScheduleEnd = I->getNextNode(); + assert(ScheduleEnd && "tried to vectorize a TerminatorInst?"); + DEBUG(dbgs() << "SLP: initialize schedule region to " << *I << "\n"); + return; + } + // Search up and down at the same time, because we don't know if the new + // instruction is above or below the existing scheduling region. + BasicBlock::reverse_iterator UpIter(ScheduleStart); + BasicBlock::reverse_iterator UpperEnd = BB->rend(); + BasicBlock::iterator DownIter(ScheduleEnd); + BasicBlock::iterator LowerEnd = BB->end(); + for (;;) { + if (UpIter != UpperEnd) { + if (&*UpIter == I) { + initScheduleData(I, ScheduleStart, nullptr, FirstLoadStoreInRegion); + ScheduleStart = I; + DEBUG(dbgs() << "SLP: extend schedule region start to " << *I << "\n"); + return; + } + UpIter++; + } + if (DownIter != LowerEnd) { + if (&*DownIter == I) { + initScheduleData(ScheduleEnd, I->getNextNode(), LastLoadStoreInRegion, + nullptr); + ScheduleEnd = I->getNextNode(); + assert(ScheduleEnd && "tried to vectorize a TerminatorInst?"); + DEBUG(dbgs() << "SLP: extend schedule region end to " << *I << "\n"); + return; + } + DownIter++; + } + assert((UpIter != UpperEnd || DownIter != LowerEnd) && + "instruction not found in block"); + } +} + +void BoUpSLP::BlockScheduling::initScheduleData(Instruction *FromI, + Instruction *ToI, + ScheduleData *PrevLoadStore, + ScheduleData *NextLoadStore) { + ScheduleData *CurrentLoadStore = PrevLoadStore; + for (Instruction *I = FromI; I != ToI; I = I->getNextNode()) { + ScheduleData *SD = ScheduleDataMap[I]; + if (!SD) { + // Allocate a new ScheduleData for the instruction. + if (ChunkPos >= ChunkSize) { + ScheduleDataChunks.push_back( + llvm::make_unique<ScheduleData[]>(ChunkSize)); + ChunkPos = 0; + } + SD = &(ScheduleDataChunks.back()[ChunkPos++]); + ScheduleDataMap[I] = SD; + SD->Inst = I; + } + assert(!isInSchedulingRegion(SD) && + "new ScheduleData already in scheduling region"); + SD->init(SchedulingRegionID); + + if (I->mayReadOrWriteMemory()) { + // Update the linked list of memory accessing instructions. 
+ if (CurrentLoadStore) {
+ CurrentLoadStore->NextLoadStore = SD;
+ } else {
+ FirstLoadStoreInRegion = SD;
+ }
+ CurrentLoadStore = SD;
+ }
+ }
+ if (NextLoadStore) {
+ if (CurrentLoadStore)
+ CurrentLoadStore->NextLoadStore = NextLoadStore;
+ } else {
+ LastLoadStoreInRegion = CurrentLoadStore;
+ }
+}
+
+void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD,
+ bool InsertInReadyList,
+ BoUpSLP *SLP) {
+ assert(SD->isSchedulingEntity());
+
+ SmallVector<ScheduleData *, 10> WorkList;
+ WorkList.push_back(SD);
+
+ while (!WorkList.empty()) {
+ ScheduleData *SD = WorkList.back();
+ WorkList.pop_back();
+
+ ScheduleData *BundleMember = SD;
+ while (BundleMember) {
+ assert(isInSchedulingRegion(BundleMember));
+ if (!BundleMember->hasValidDependencies()) {
+
+ DEBUG(dbgs() << "SLP: update deps of " << *BundleMember << "\n");
+ BundleMember->Dependencies = 0;
+ BundleMember->resetUnscheduledDeps();
+
+ // Handle def-use chain dependencies.
+ for (User *U : BundleMember->Inst->users()) {
+ if (isa<Instruction>(U)) {
+ ScheduleData *UseSD = getScheduleData(U);
+ if (UseSD && isInSchedulingRegion(UseSD->FirstInBundle)) {
+ BundleMember->Dependencies++;
+ ScheduleData *DestBundle = UseSD->FirstInBundle;
+ if (!DestBundle->IsScheduled) {
+ BundleMember->incrementUnscheduledDeps(1);
+ }
+ if (!DestBundle->hasValidDependencies()) {
+ WorkList.push_back(DestBundle);
+ }
+ }
+ } else {
+ // I'm not sure if this can ever happen. But we need to be safe.
+ // This prevents the instruction/bundle from ever being scheduled,
+ // eventually disabling vectorization.
+ BundleMember->Dependencies++;
+ BundleMember->incrementUnscheduledDeps(1);
+ }
+ }
+
+ // Handle the memory dependencies.
+ ScheduleData *DepDest = BundleMember->NextLoadStore;
+ if (DepDest) {
+ Instruction *SrcInst = BundleMember->Inst;
+ AliasAnalysis::Location SrcLoc = getLocation(SrcInst, SLP->AA);
+ bool SrcMayWrite = BundleMember->Inst->mayWriteToMemory();
+
+ while (DepDest) {
+ assert(isInSchedulingRegion(DepDest));
+ if (SrcMayWrite || DepDest->Inst->mayWriteToMemory()) {
+ if (SLP->isAliased(SrcLoc, SrcInst, DepDest->Inst)) {
+ DepDest->MemoryDependencies.push_back(BundleMember);
+ BundleMember->Dependencies++;
+ ScheduleData *DestBundle = DepDest->FirstInBundle;
+ if (!DestBundle->IsScheduled) {
+ BundleMember->incrementUnscheduledDeps(1);
+ }
+ if (!DestBundle->hasValidDependencies()) {
+ WorkList.push_back(DestBundle);
+ }
+ }
+ }
+ DepDest = DepDest->NextLoadStore;
+ }
+ }
+ }
+ BundleMember = BundleMember->NextInBundle;
+ }
+ if (InsertInReadyList && SD->isReady()) {
+ ReadyInsts.push_back(SD);
+ DEBUG(dbgs() << "SLP: gets ready on update: " << *SD->Inst << "\n");
+ }
+ }
+}
+
+void BoUpSLP::BlockScheduling::resetSchedule() {
+ assert(ScheduleStart &&
+ "tried to reset schedule on block which has not been scheduled");
+ for (Instruction *I = ScheduleStart; I != ScheduleEnd; I = I->getNextNode()) {
+ ScheduleData *SD = getScheduleData(I);
+ assert(isInSchedulingRegion(SD));
+ SD->IsScheduled = false;
+ SD->resetUnscheduledDeps();
+ }
+ ReadyInsts.clear();
+}
+
+void BoUpSLP::scheduleBlock(BlockScheduling *BS) {
+
+ if (!BS->ScheduleStart)
+ return;
+
+ DEBUG(dbgs() << "SLP: schedule block " << BS->BB->getName() << "\n");
+
+ BS->resetSchedule();
+
+ // For the real scheduling we use a more sophisticated ready-list: it is
+ // sorted by the original instruction location. This lets the final schedule
+ // be as close as possible to the original instruction order.
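+ // Note: the comparator sorts the set in descending SchedulingPriority
+ // order, so *ReadyInsts.begin() is the ready bundle that occurs last in
+ // the original instruction order, matching the bottom-up order in which
+ // instructions are re-inserted before ScheduleEnd below.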
+ struct ScheduleDataCompare {
+ bool operator()(ScheduleData *SD1, ScheduleData *SD2) {
+ return SD2->SchedulingPriority < SD1->SchedulingPriority;
+ }
+ };
+ std::set<ScheduleData *, ScheduleDataCompare> ReadyInsts;
+
+ // Ensure that all dependency data is updated and fill the ready-list with
+ // initial instructions.
+ int Idx = 0;
+ int NumToSchedule = 0;
+ for (auto *I = BS->ScheduleStart; I != BS->ScheduleEnd;
+ I = I->getNextNode()) {
+ ScheduleData *SD = BS->getScheduleData(I);
+ assert(
+ SD->isPartOfBundle() == (ScalarToTreeEntry.count(SD->Inst) != 0) &&
+ "scheduler and vectorizer have different opinion on what is a bundle");
+ SD->FirstInBundle->SchedulingPriority = Idx++;
+ if (SD->isSchedulingEntity()) {
+ BS->calculateDependencies(SD, false, this);
+ NumToSchedule++;
+ }
+ }
+ BS->initialFillReadyList(ReadyInsts);
+
+ Instruction *LastScheduledInst = BS->ScheduleEnd;
+
+ // Do the "real" scheduling.
+ while (!ReadyInsts.empty()) {
+ ScheduleData *picked = *ReadyInsts.begin();
+ ReadyInsts.erase(ReadyInsts.begin());
+
+ // Move the scheduled instruction(s) to their dedicated places, if not
+ // there yet.
+ ScheduleData *BundleMember = picked;
+ while (BundleMember) {
+ Instruction *pickedInst = BundleMember->Inst;
+ if (LastScheduledInst->getNextNode() != pickedInst) {
+ BS->BB->getInstList().remove(pickedInst);
+ BS->BB->getInstList().insert(LastScheduledInst, pickedInst);
+ }
+ LastScheduledInst = pickedInst;
+ BundleMember = BundleMember->NextInBundle;
+ }
+
+ BS->schedule(picked, ReadyInsts);
+ NumToSchedule--;
+ }
+ assert(NumToSchedule == 0 && "could not schedule all instructions");
+
+ // Avoid duplicate scheduling of the block.
+ BS->ScheduleStart = nullptr;
+}
+
 /// The SLPVectorizer Pass.
 struct SLPVectorizer : public FunctionPass {
 typedef SmallVector<StoreInst *, 8> StoreList;
@@ -2146,6 +2880,7 @@ struct SLPVectorizer : public FunctionPass {
 AliasAnalysis *AA;
 LoopInfo *LI;
 DominatorTree *DT;
+ AssumptionCache *AC;
 bool runOnFunction(Function &F) override {
 if (skipOptnoneFunction(F))
@@ -2159,6 +2894,7 @@ struct SLPVectorizer : public FunctionPass {
 AA = &getAnalysis<AliasAnalysis>();
 LI = &getAnalysis<LoopInfo>();
 DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
 StoreRefs.clear();
 bool Changed = false;
@@ -2181,7 +2917,10 @@ struct SLPVectorizer : public FunctionPass {
 // Use the bottom up slp vectorizer to construct chains that start with
 // store instructions.
- BoUpSLP R(&F, SE, DL, TTI, TLI, AA, LI, DT);
+ BoUpSLP R(&F, SE, DL, TTI, TLI, AA, LI, DT, AC);
+
+ // A general note: the vectorizer must use BoUpSLP::eraseInstruction() to
+ // delete instructions.
 // Scan the blocks in the function in post order.
 for (po_iterator<BasicBlock*> it = po_begin(&F.getEntryBlock()),
@@ -2208,6 +2947,7 @@ struct SLPVectorizer : public FunctionPass {
 void getAnalysisUsage(AnalysisUsage &AU) const override {
 FunctionPass::getAnalysisUsage(AU);
+ AU.addRequired<AssumptionCacheTracker>();
 AU.addRequired<ScalarEvolution>();
 AU.addRequired<AliasAnalysis>();
 AU.addRequired<TargetTransformInfo>();
@@ -2234,7 +2974,8 @@ private:
 /// scheduling and that don't need extracting.
 /// \returns true if a value was vectorized.
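+ /// \param allowReorder if true, a two-element bundle may be retried with
+ /// its operands swapped when buildTree() signals via shouldReorder() that
+ /// the reversed order is preferable; so far only tryToVectorizePair()
+ /// passes true.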
 bool tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
- ArrayRef<Value *> BuildVector = None);
+ ArrayRef<Value *> BuildVector = None,
+ bool allowReorder = false);
 /// \brief Try to vectorize a chain that may start at the operands of \p V.
 bool tryToVectorize(BinaryOperator *V, BoUpSLP &R);
@@ -2404,11 +3145,12 @@ bool SLPVectorizer::tryToVectorizePair(Value *A, Value *B, BoUpSLP &R) {
 if (!A || !B)
 return false;
 Value *VL[] = { A, B };
- return tryToVectorizeList(VL, R);
+ return tryToVectorizeList(VL, R, None, true);
 }
 bool SLPVectorizer::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
- ArrayRef<Value *> BuildVector) {
+ ArrayRef<Value *> BuildVector,
+ bool allowReorder) {
 if (VL.size() < 2)
 return false;
@@ -2463,6 +3205,14 @@ bool SLPVectorizer::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
 BuildVectorSlice = BuildVector.slice(i, OpsWidth);
 R.buildTree(Ops, BuildVectorSlice);
+ // TODO: check if we can allow reordering for cases other than
+ // tryToVectorizePair().
+ if (allowReorder && R.shouldReorder()) {
+ assert(Ops.size() == 2);
+ assert(BuildVectorSlice.empty());
+ Value *ReorderedOps[] = { Ops[1], Ops[0] };
+ R.buildTree(ReorderedOps, None);
+ }
 int Cost = R.getTreeCost();
 if (Cost < -SLPCostThreshold) {
@@ -2514,11 +3264,9 @@ bool SLPVectorizer::tryToVectorize(BinaryOperator *V, BoUpSLP &R) {
 BinaryOperator *B0 = dyn_cast<BinaryOperator>(B->getOperand(0));
 BinaryOperator *B1 = dyn_cast<BinaryOperator>(B->getOperand(1));
 if (tryToVectorizePair(A, B0, R)) {
- B->moveBefore(V);
 return true;
 }
 if (tryToVectorizePair(A, B1, R)) {
- B->moveBefore(V);
 return true;
 }
 }
@@ -2528,11 +3276,9 @@ bool SLPVectorizer::tryToVectorize(BinaryOperator *V, BoUpSLP &R) {
 BinaryOperator *A0 = dyn_cast<BinaryOperator>(A->getOperand(0));
 BinaryOperator *A1 = dyn_cast<BinaryOperator>(A->getOperand(1));
 if (tryToVectorizePair(A0, B, R)) {
- A->moveBefore(V);
 return true;
 }
 if (tryToVectorizePair(A1, B, R)) {
- A->moveBefore(V);
 return true;
 }
 }
@@ -2728,8 +3474,7 @@ public:
 unsigned i = 0;
 for (; i < NumReducedVals - ReduxWidth + 1; i += ReduxWidth) {
- ArrayRef<Value *> ValsToReduce(&ReducedVals[i], ReduxWidth);
- V.buildTree(ValsToReduce, ReductionOps);
+ V.buildTree(makeArrayRef(&ReducedVals[i], ReduxWidth), ReductionOps);
 // Estimate cost.
 int Cost = V.getTreeCost() + getReductionCost(TTI, ReducedVals[i]);
@@ -2807,11 +3552,10 @@ private:
 /// \brief Emit a horizontal reduction of the vectorized value.
 Value *emitReduction(Value *VectorizedValue, IRBuilder<> &Builder) {
 assert(VectorizedValue && "Need to have a vectorized tree node");
- Instruction *ValToReduce = dyn_cast<Instruction>(VectorizedValue);
 assert(isPowerOf2_32(ReduxWidth) &&
 "We only handle power-of-two reductions for now");
- Value *TmpVec = ValToReduce;
+ Value *TmpVec = VectorizedValue;
 for (unsigned i = ReduxWidth / 2; i != 0; i >>= 1) {
 if (IsPairwiseReduction) {
 Value *LeftMask =
@@ -2921,8 +3665,7 @@ bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
 // Try to vectorize them.
 unsigned NumElts = (SameTypeIt - IncIt);
 DEBUG(errs() << "SLP: Trying to vectorize starting at PHIs (" << NumElts << ")\n");
- if (NumElts > 1 &&
- tryToVectorizeList(ArrayRef<Value *>(IncIt, NumElts), R)) {
+ if (NumElts > 1 && tryToVectorizeList(makeArrayRef(IncIt, NumElts), R)) {
 // Success: start over, because instructions might have been changed.
 HaveVectorizedPhiNodes = true;
 Changed = true;
@@ -2938,7 +3681,7 @@ bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
 for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; it++) {
 // We may go through BB multiple times so skip the one we have checked.
- if (!VisitedInstrs.insert(it))
+ if (!VisitedInstrs.insert(it).second)
 continue;
 if (isa<DbgInfoIntrinsic>(it))
@@ -3002,6 +3745,21 @@ bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
 }
 }
+ // Try to vectorize horizontal reductions feeding into a return.
+ if (ReturnInst *RI = dyn_cast<ReturnInst>(it))
+ if (RI->getNumOperands() != 0)
+ if (BinaryOperator *BinOp =
+ dyn_cast<BinaryOperator>(RI->getOperand(0))) {
+ DEBUG(dbgs() << "SLP: Found a return to vectorize.\n");
+ if (tryToVectorizePair(BinOp->getOperand(0),
+ BinOp->getOperand(1), R)) {
+ Changed = true;
+ it = BB->begin();
+ e = BB->end();
+ continue;
+ }
+ }
+
 // Try to vectorize trees that start at compare instructions.
 if (CmpInst *CI = dyn_cast<CmpInst>(it)) {
 if (tryToVectorizePair(CI->getOperand(0), CI->getOperand(1), R)) {
@@ -3014,15 +3772,15 @@ bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
 }
 for (int i = 0; i < 2; ++i) {
- if (BinaryOperator *BI = dyn_cast<BinaryOperator>(CI->getOperand(i))) {
- if (tryToVectorizePair(BI->getOperand(0), BI->getOperand(1), R)) {
- Changed = true;
- // We would like to start over since some instructions are deleted
- // and the iterator may become invalid value.
- it = BB->begin();
- e = BB->end();
- }
- }
+ if (BinaryOperator *BI = dyn_cast<BinaryOperator>(CI->getOperand(i))) {
+ if (tryToVectorizePair(BI->getOperand(0), BI->getOperand(1), R)) {
+ Changed = true;
+ // We would like to start over since some instructions are deleted
+ // and the iterator may have been invalidated.
+ it = BB->begin();
+ e = BB->end();
+ }
+ }
 }
 continue;
 }
@@ -3064,8 +3822,8 @@ bool SLPVectorizer::vectorizeStoreChains(BoUpSLP &R) {
 // Process the stores in chunks of 16.
 for (unsigned CI = 0, CE = it->second.size(); CI < CE; CI+=16) {
 unsigned Len = std::min<unsigned>(CE - CI, 16);
- ArrayRef<StoreInst *> Chunk(&it->second[CI], Len);
- Changed |= vectorizeStores(Chunk, -SLPCostThreshold, R);
+ Changed |= vectorizeStores(makeArrayRef(&it->second[CI], Len),
+ -SLPCostThreshold, R);
 }
 }
 return Changed;
@@ -3078,6 +3836,7 @@ static const char lv_name[] = "SLP Vectorizer";
 INITIALIZE_PASS_BEGIN(SLPVectorizer, SV_NAME, lv_name, false, false)
 INITIALIZE_AG_DEPENDENCY(AliasAnalysis)
 INITIALIZE_AG_DEPENDENCY(TargetTransformInfo)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
 INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
 INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
 INITIALIZE_PASS_END(SLPVectorizer, SV_NAME, lv_name, false, false)