aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp')
-rw-r--r--llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp41
1 files changed, 16 insertions, 25 deletions
diff --git a/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp b/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp
index 31b2dafa29b4..4a792fce51d1 100644
--- a/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp
+++ b/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp
@@ -11,8 +11,10 @@
//
//===----------------------------------------------------------------------===//
#include "llvm/Config/config.h"
+#include "llvm/Support/Casting.h"
#if defined(LLVM_HAVE_TF_API)
+#include "llvm/ADT/BitVector.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/InlineSizeEstimatorAnalysis.h"
#include "llvm/Analysis/MLInlineAdvisor.h"
@@ -111,7 +113,7 @@ private:
StringRef LogFileName;
const ModelUnderTrainingRunner *const MUTR;
std::unique_ptr<Logger> L;
- std::vector<bool> Effects;
+ BitVector Effects;
/// There's at least one output. We'll set this to a different value if MUTR
/// is avaliable.
size_t OutputCount = 1;
@@ -150,7 +152,7 @@ public:
DevelopmentModeMLInlineAdvisor(
Module &M, ModuleAnalysisManager &MAM,
std::unique_ptr<MLModelRunner> ModelRunner,
- std::function<bool(CallBase &)> GetDefaultAdvice, bool IsDoingInference,
+ std::function<bool(CallBase &)> GetDefaultAdvice,
std::unique_ptr<TrainingLogger> Logger);
size_t getTotalSizeEstimate();
@@ -341,10 +343,11 @@ void TrainingLogger::print() {
DevelopmentModeMLInlineAdvisor::DevelopmentModeMLInlineAdvisor(
Module &M, ModuleAnalysisManager &MAM,
std::unique_ptr<MLModelRunner> ModelRunner,
- std::function<bool(CallBase &)> GetDefaultAdvice, bool IsDoingInference,
+ std::function<bool(CallBase &)> GetDefaultAdvice,
std::unique_ptr<TrainingLogger> Logger)
: MLInlineAdvisor(M, MAM, std::move(ModelRunner)),
- GetDefaultAdvice(GetDefaultAdvice), IsDoingInference(IsDoingInference),
+ GetDefaultAdvice(GetDefaultAdvice),
+ IsDoingInference(isa<ModelUnderTrainingRunner>(getModelRunner())),
Logger(std::move(Logger)),
InitialNativeSize(isLogging() ? getTotalSizeEstimate() : 0),
CurrentNativeSize(InitialNativeSize) {
@@ -410,8 +413,6 @@ size_t DevelopmentModeMLInlineAdvisor::getTotalSizeEstimate() {
for (auto &F : M) {
if (F.isDeclaration())
continue;
- if (isFunctionDeleted(&F))
- continue;
Ret += *getNativeSizeEstimate(F);
}
return Ret;
@@ -422,30 +423,20 @@ std::unique_ptr<InlineAdvisor> llvm::getDevelopmentModeAdvisor(
std::function<bool(CallBase &)> GetDefaultAdvice) {
auto &Ctx = M.getContext();
std::unique_ptr<MLModelRunner> Runner;
- ModelUnderTrainingRunner *MUTRPtr = nullptr;
- bool IsDoingInference = false;
if (TFModelUnderTrainingPath.empty())
Runner.reset(new NoInferenceModelRunner(Ctx, getInputFeatures()));
- else {
- std::unique_ptr<ModelUnderTrainingRunner> MUTR;
- if (auto MaybeOutputSpecs = loadOutputSpecs(
- Ctx, DecisionName, TFModelUnderTrainingPath, TFOutputSpecOverride))
- MUTR = std::make_unique<ModelUnderTrainingRunner>(
- Ctx, TFModelUnderTrainingPath, getInputFeatures(), *MaybeOutputSpecs);
- if (!MUTR || !MUTR->isValid()) {
- Ctx.emitError("Could not load the policy model from the provided path");
- return nullptr;
- }
- IsDoingInference = true;
- MUTRPtr = MUTR.get();
- Runner = std::move(MUTR);
- }
+ else
+ Runner = ModelUnderTrainingRunner::createAndEnsureValid(
+ Ctx, TFModelUnderTrainingPath, DecisionName, getInputFeatures(),
+ TFOutputSpecOverride);
+ if (!Runner)
+ return nullptr;
std::unique_ptr<TrainingLogger> Logger;
if (!TrainingLog.empty())
- Logger = std::make_unique<TrainingLogger>(TrainingLog, MUTRPtr);
+ Logger = std::make_unique<TrainingLogger>(
+ TrainingLog, dyn_cast<ModelUnderTrainingRunner>(Runner.get()));
return std::make_unique<DevelopmentModeMLInlineAdvisor>(
- M, MAM, std::move(Runner), GetDefaultAdvice, IsDoingInference,
- std::move(Logger));
+ M, MAM, std::move(Runner), GetDefaultAdvice, std::move(Logger));
}
#endif // defined(LLVM_HAVE_TF_API)