diff options
Diffstat (limited to 'llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp')
-rw-r--r-- | llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp | 41 |
1 file changed, 16 insertions(+), 25 deletions(-)
diff --git a/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp b/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp index 31b2dafa29b4..4a792fce51d1 100644 --- a/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp +++ b/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp @@ -11,8 +11,10 @@ // //===----------------------------------------------------------------------===// #include "llvm/Config/config.h" +#include "llvm/Support/Casting.h" #if defined(LLVM_HAVE_TF_API) +#include "llvm/ADT/BitVector.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/InlineSizeEstimatorAnalysis.h" #include "llvm/Analysis/MLInlineAdvisor.h" @@ -111,7 +113,7 @@ private: StringRef LogFileName; const ModelUnderTrainingRunner *const MUTR; std::unique_ptr<Logger> L; - std::vector<bool> Effects; + BitVector Effects; /// There's at least one output. We'll set this to a different value if MUTR /// is avaliable. size_t OutputCount = 1; @@ -150,7 +152,7 @@ public: DevelopmentModeMLInlineAdvisor( Module &M, ModuleAnalysisManager &MAM, std::unique_ptr<MLModelRunner> ModelRunner, - std::function<bool(CallBase &)> GetDefaultAdvice, bool IsDoingInference, + std::function<bool(CallBase &)> GetDefaultAdvice, std::unique_ptr<TrainingLogger> Logger); size_t getTotalSizeEstimate(); @@ -341,10 +343,11 @@ void TrainingLogger::print() { DevelopmentModeMLInlineAdvisor::DevelopmentModeMLInlineAdvisor( Module &M, ModuleAnalysisManager &MAM, std::unique_ptr<MLModelRunner> ModelRunner, - std::function<bool(CallBase &)> GetDefaultAdvice, bool IsDoingInference, + std::function<bool(CallBase &)> GetDefaultAdvice, std::unique_ptr<TrainingLogger> Logger) : MLInlineAdvisor(M, MAM, std::move(ModelRunner)), - GetDefaultAdvice(GetDefaultAdvice), IsDoingInference(IsDoingInference), + GetDefaultAdvice(GetDefaultAdvice), + IsDoingInference(isa<ModelUnderTrainingRunner>(getModelRunner())), Logger(std::move(Logger)), InitialNativeSize(isLogging() ? 
getTotalSizeEstimate() : 0), CurrentNativeSize(InitialNativeSize) { @@ -410,8 +413,6 @@ size_t DevelopmentModeMLInlineAdvisor::getTotalSizeEstimate() { for (auto &F : M) { if (F.isDeclaration()) continue; - if (isFunctionDeleted(&F)) - continue; Ret += *getNativeSizeEstimate(F); } return Ret; @@ -422,30 +423,20 @@ std::unique_ptr<InlineAdvisor> llvm::getDevelopmentModeAdvisor( std::function<bool(CallBase &)> GetDefaultAdvice) { auto &Ctx = M.getContext(); std::unique_ptr<MLModelRunner> Runner; - ModelUnderTrainingRunner *MUTRPtr = nullptr; - bool IsDoingInference = false; if (TFModelUnderTrainingPath.empty()) Runner.reset(new NoInferenceModelRunner(Ctx, getInputFeatures())); - else { - std::unique_ptr<ModelUnderTrainingRunner> MUTR; - if (auto MaybeOutputSpecs = loadOutputSpecs( - Ctx, DecisionName, TFModelUnderTrainingPath, TFOutputSpecOverride)) - MUTR = std::make_unique<ModelUnderTrainingRunner>( - Ctx, TFModelUnderTrainingPath, getInputFeatures(), *MaybeOutputSpecs); - if (!MUTR || !MUTR->isValid()) { - Ctx.emitError("Could not load the policy model from the provided path"); - return nullptr; - } - IsDoingInference = true; - MUTRPtr = MUTR.get(); - Runner = std::move(MUTR); - } + else + Runner = ModelUnderTrainingRunner::createAndEnsureValid( + Ctx, TFModelUnderTrainingPath, DecisionName, getInputFeatures(), + TFOutputSpecOverride); + if (!Runner) + return nullptr; std::unique_ptr<TrainingLogger> Logger; if (!TrainingLog.empty()) - Logger = std::make_unique<TrainingLogger>(TrainingLog, MUTRPtr); + Logger = std::make_unique<TrainingLogger>( + TrainingLog, dyn_cast<ModelUnderTrainingRunner>(Runner.get())); return std::make_unique<DevelopmentModeMLInlineAdvisor>( - M, MAM, std::move(Runner), GetDefaultAdvice, IsDoingInference, - std::move(Logger)); + M, MAM, std::move(Runner), GetDefaultAdvice, std::move(Logger)); } #endif // defined(LLVM_HAVE_TF_API) |