diff options
Diffstat (limited to 'llvm/include/llvm/Analysis/MLInlineAdvisor.h')
-rw-r--r-- | llvm/include/llvm/Analysis/MLInlineAdvisor.h | 27 |
1 files changed, 19 insertions, 8 deletions
diff --git a/llvm/include/llvm/Analysis/MLInlineAdvisor.h b/llvm/include/llvm/Analysis/MLInlineAdvisor.h index a218561e61c7..05411d9c99a2 100644 --- a/llvm/include/llvm/Analysis/MLInlineAdvisor.h +++ b/llvm/include/llvm/Analysis/MLInlineAdvisor.h @@ -9,13 +9,13 @@ #ifndef LLVM_ANALYSIS_MLINLINEADVISOR_H #define LLVM_ANALYSIS_MLINLINEADVISOR_H -#include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/InlineAdvisor.h" +#include "llvm/Analysis/LazyCallGraph.h" #include "llvm/Analysis/MLModelRunner.h" #include "llvm/IR/PassManager.h" +#include <deque> #include <memory> -#include <unordered_map> namespace llvm { class Module; @@ -26,10 +26,10 @@ public: MLInlineAdvisor(Module &M, ModuleAnalysisManager &MAM, std::unique_ptr<MLModelRunner> ModelRunner); - CallGraph *callGraph() const { return CG.get(); } virtual ~MLInlineAdvisor() = default; void onPassEntry() override; + void onPassExit(LazyCallGraph::SCC *SCC) override; int64_t getIRSize(const Function &F) const { return F.getInstructionCount(); } void onSuccessfulInlining(const MLInlineAdvice &Advice, @@ -38,7 +38,6 @@ public: bool isForcedToStop() const { return ForceStop; } int64_t getLocalCalls(Function &F); const MLModelRunner &getModelRunner() const { return *ModelRunner.get(); } - void onModuleInvalidated() override { Invalid = true; } protected: std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override; @@ -51,20 +50,32 @@ protected: virtual std::unique_ptr<MLInlineAdvice> getAdviceFromModel(CallBase &CB, OptimizationRemarkEmitter &ORE); + // Get the initial 'level' of the function, or 0 if the function has been + // introduced afterwards. + // TODO: should we keep this updated? + unsigned getInitialFunctionLevel(const Function &F) const; + std::unique_ptr<MLModelRunner> ModelRunner; private: int64_t getModuleIRSize() const; - bool Invalid = true; - std::unique_ptr<CallGraph> CG; + void print(raw_ostream &OS) const override { + OS << "[MLInlineAdvisor] Nodes: " << NodeCount << " Edges: " << EdgeCount + << "\n"; + } + + LazyCallGraph &CG; int64_t NodeCount = 0; int64_t EdgeCount = 0; - std::map<const Function *, unsigned> FunctionLevels; + int64_t EdgesOfLastSeenNodes = 0; + + std::map<const LazyCallGraph::Node *, unsigned> FunctionLevels; const int32_t InitialIRSize = 0; int32_t CurrentIRSize = 0; - + std::deque<const LazyCallGraph::Node *> NodesInLastSCC; + DenseSet<const LazyCallGraph::Node *> AllNodes; bool ForceStop = false; }; |