aboutsummaryrefslogtreecommitdiff
path: root/llvm/include/llvm/Analysis/MLInlineAdvisor.h
blob: 05411d9c99a29a821937fedc890e29ab373d18d8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
//===- MLInlineAdvisor.h - ML-based InlineAdvisor factories -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_ANALYSIS_MLINLINEADVISOR_H
#define LLVM_ANALYSIS_MLINLINEADVISOR_H

#include "llvm/ADT/DenseSet.h"
#include "llvm/Analysis/InlineAdvisor.h"
#include "llvm/Analysis/LazyCallGraph.h"
#include "llvm/Analysis/MLModelRunner.h"
#include "llvm/IR/PassManager.h"

#include <cstdint>
#include <deque>
#include <map>
#include <memory>

namespace llvm {
class Module;
class MLInlineAdvice;

/// An InlineAdvisor that owns an MLModelRunner. Besides the runner, it keeps
/// call-graph counters (nodes/edges), per-node function levels and module IR
/// size figures as private state. Most member functions are only declared
/// here; their definitions live out of line.
class MLInlineAdvisor : public InlineAdvisor {
public:
  /// Construct the advisor for module \p M, taking ownership of
  /// \p ModelRunner.
  MLInlineAdvisor(Module &M, ModuleAnalysisManager &MAM,
                  std::unique_ptr<MLModelRunner> ModelRunner);

  virtual ~MLInlineAdvisor() = default;

  // Pass-boundary hooks overridden from InlineAdvisor.
  void onPassEntry() override;
  void onPassExit(LazyCallGraph::SCC *SCC) override;

  /// The IR size of \p F, measured as its instruction count.
  int64_t getIRSize(const Function &F) const { return F.getInstructionCount(); }

  /// Notification that the inlining described by \p Advice was carried out;
  /// \p CalleeWasDeleted tells whether the callee was removed as a result.
  void onSuccessfulInlining(const MLInlineAdvice &Advice,
                            bool CalleeWasDeleted);

  /// Current value of the ForceStop flag. MLInlineAdvice records zeros instead
  /// of querying sizes/edges once this returns true.
  bool isForcedToStop() const { return ForceStop; }

  // NOTE(review): presumably the number of direct calls made by F to functions
  // defined in this module - confirm against the implementation.
  int64_t getLocalCalls(Function &F);

  /// Read-only access to the owned model runner.
  const MLModelRunner &getModelRunner() const { return *ModelRunner; }

protected:
  std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override;

  std::unique_ptr<InlineAdvice> getMandatoryAdvice(CallBase &CB,
                                                   bool Advice) override;

  // Extension points for subclasses: both are virtual and produce
  // MLInlineAdvice objects for the call site CB.
  virtual std::unique_ptr<MLInlineAdvice> getMandatoryAdviceImpl(CallBase &CB);

  virtual std::unique_ptr<MLInlineAdvice>
  getAdviceFromModel(CallBase &CB, OptimizationRemarkEmitter &ORE);

  // Get the initial 'level' of the function, or 0 if the function has been
  // introduced afterwards.
  // TODO: should we keep this updated?
  unsigned getInitialFunctionLevel(const Function &F) const;

  // Owned model runner; protected so subclasses can reach it directly.
  std::unique_ptr<MLModelRunner> ModelRunner;

private:
  int64_t getModuleIRSize() const;

  /// Print a one-line summary of the node and edge counters to \p OS.
  void print(raw_ostream &OS) const override {
    OS << "[MLInlineAdvisor] Nodes: " << NodeCount << " Edges: " << EdgeCount
       << "\n";
  }

  LazyCallGraph &CG;

  int64_t NodeCount = 0;
  int64_t EdgeCount = 0;
  int64_t EdgesOfLastSeenNodes = 0;

  // NOTE(review): presumably the data backing getInitialFunctionLevel() -
  // confirm against the implementation.
  std::map<const LazyCallGraph::Node *, unsigned> FunctionLevels;

  // IR size figures are 64-bit so that they match the int64_t values produced
  // by getModuleIRSize() / getIRSize() and are not implicitly narrowed.
  // Declaration order is significant: members are initialized in this order.
  const int64_t InitialIRSize = 0;
  int64_t CurrentIRSize = 0;

  std::deque<const LazyCallGraph::Node *> NodesInLastSCC;
  DenseSet<const LazyCallGraph::Node *> AllNodes;
  bool ForceStop = false;
};

/// InlineAdvice that tracks changes post inlining. For that reason, it only
/// overrides the "successful inlining" extension points.
///
/// At construction it snapshots the caller's and callee's IR sizes and their
/// combined local call count, as reported by the owning MLInlineAdvisor.
class MLInlineAdvice : public InlineAdvice {
public:
  /// \p Advisor is dereferenced unconditionally and so must be non-null.
  /// If the advisor has already been forced to stop, all three snapshots are
  /// recorded as 0 and the advisor is not queried for sizes or call counts.
  /// Caller and Callee are not declared in this class; presumably they are
  /// set up by the InlineAdvice base constructor, which runs first.
  MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB,
                 OptimizationRemarkEmitter &ORE, bool Recommendation)
      : InlineAdvice(Advisor, CB, ORE, Recommendation),
        CallerIRSize(Advisor->isForcedToStop() ? 0
                                               : Advisor->getIRSize(*Caller)),
        CalleeIRSize(Advisor->isForcedToStop() ? 0
                                               : Advisor->getIRSize(*Callee)),
        CallerAndCalleeEdges(Advisor->isForcedToStop()
                                 ? 0
                                 : (Advisor->getLocalCalls(*Caller) +
                                    Advisor->getLocalCalls(*Callee))) {}
  virtual ~MLInlineAdvice() = default;

  // Outcome-recording hooks overridden from InlineAdvice.
  void recordInliningImpl() override;
  void recordInliningWithCalleeDeletedImpl() override;
  void recordUnsuccessfulInliningImpl(const InlineResult &Result) override;
  void recordUnattemptedInliningImpl() override;

  Function *getCaller() const { return Caller; }
  Function *getCallee() const { return Callee; }

  // Snapshots taken in the constructor (0 when the advisor was forced to
  // stop). Initialized in declaration order.
  const int64_t CallerIRSize;
  const int64_t CalleeIRSize;
  const int64_t CallerAndCalleeEdges;

private:
  void reportContextForRemark(DiagnosticInfoOptimizationBase &OR);

  /// The owning advisor, downcast to its concrete type. The constructor only
  /// accepts an MLInlineAdvisor, which is what makes the static_cast valid.
  MLInlineAdvisor *getAdvisor() const {
    return static_cast<MLInlineAdvisor *>(Advisor);
  }
};

} // namespace llvm

#endif // LLVM_ANALYSIS_MLINLINEADVISOR_H