xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp (revision 81ad626541db97eb356e2c1d4a20eb2a26a766ab)
10eae32dcSDimitry Andric //===- ModelUnderTrainingRunner.cpp - 'development' mode runner -----------===//
20eae32dcSDimitry Andric //
30eae32dcSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40eae32dcSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
50eae32dcSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60eae32dcSDimitry Andric //
70eae32dcSDimitry Andric //===----------------------------------------------------------------------===//
80eae32dcSDimitry Andric //
90eae32dcSDimitry Andric // Implementation of a MLModelRunner for 'development' mode, i.e. evaluation
100eae32dcSDimitry Andric // happens off a model that's provided from the command line and is interpreted.
110eae32dcSDimitry Andric //
120eae32dcSDimitry Andric //===----------------------------------------------------------------------===//
130eae32dcSDimitry Andric 
140eae32dcSDimitry Andric #include "llvm/Config/config.h"
150eae32dcSDimitry Andric #if defined(LLVM_HAVE_TF_API)
160eae32dcSDimitry Andric 
170eae32dcSDimitry Andric #include "llvm/Analysis/ModelUnderTrainingRunner.h"
180eae32dcSDimitry Andric 
190eae32dcSDimitry Andric using namespace llvm;
200eae32dcSDimitry Andric 
210eae32dcSDimitry Andric ModelUnderTrainingRunner::ModelUnderTrainingRunner(
220eae32dcSDimitry Andric     LLVMContext &Ctx, const std::string &ModelPath,
230eae32dcSDimitry Andric     const std::vector<TensorSpec> &InputSpecs,
240eae32dcSDimitry Andric     const std::vector<LoggedFeatureSpec> &OutputSpecs)
25*81ad6265SDimitry Andric     : MLModelRunner(Ctx, MLModelRunner::Kind::Development, InputSpecs.size()),
2604eeddc0SDimitry Andric       OutputSpecs(OutputSpecs) {
270eae32dcSDimitry Andric   Evaluator = std::make_unique<TFModelEvaluator>(
280eae32dcSDimitry Andric       ModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I].Spec; },
290eae32dcSDimitry Andric       OutputSpecs.size());
300eae32dcSDimitry Andric   if (!Evaluator || !Evaluator->isValid()) {
3104eeddc0SDimitry Andric     Ctx.emitError("Failed to create saved model evaluator");
320eae32dcSDimitry Andric     Evaluator.reset();
330eae32dcSDimitry Andric     return;
340eae32dcSDimitry Andric   }
35*81ad6265SDimitry Andric 
36*81ad6265SDimitry Andric   for (size_t I = 0, E = InputSpecs.size(); I < E; ++I) {
37*81ad6265SDimitry Andric     setUpBufferForTensor(I, InputSpecs[I], Evaluator->getUntypedInput(I));
38*81ad6265SDimitry Andric   }
390eae32dcSDimitry Andric }
400eae32dcSDimitry Andric 
410eae32dcSDimitry Andric void *ModelUnderTrainingRunner::evaluateUntyped() {
420eae32dcSDimitry Andric   LastEvaluationResult = Evaluator->evaluate();
430eae32dcSDimitry Andric   if (!LastEvaluationResult.hasValue()) {
440eae32dcSDimitry Andric     Ctx.emitError("Error evaluating model.");
450eae32dcSDimitry Andric     return nullptr;
460eae32dcSDimitry Andric   }
470eae32dcSDimitry Andric   return LastEvaluationResult->getUntypedTensorValue(0);
480eae32dcSDimitry Andric }
490eae32dcSDimitry Andric 
50*81ad6265SDimitry Andric std::unique_ptr<ModelUnderTrainingRunner>
51*81ad6265SDimitry Andric ModelUnderTrainingRunner::createAndEnsureValid(
52*81ad6265SDimitry Andric     LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
53*81ad6265SDimitry Andric     const std::vector<TensorSpec> &InputSpecs,
54*81ad6265SDimitry Andric     StringRef OutputSpecsPathOverride) {
55*81ad6265SDimitry Andric   if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath,
56*81ad6265SDimitry Andric                                               OutputSpecsPathOverride))
57*81ad6265SDimitry Andric     return createAndEnsureValid(Ctx, ModelPath, DecisionName, InputSpecs,
58*81ad6265SDimitry Andric                                 *MaybeOutputSpecs);
59*81ad6265SDimitry Andric   Ctx.emitError("Could not load the policy model from the provided path");
60*81ad6265SDimitry Andric   return nullptr;
610eae32dcSDimitry Andric }
620eae32dcSDimitry Andric 
6304eeddc0SDimitry Andric std::unique_ptr<ModelUnderTrainingRunner>
6404eeddc0SDimitry Andric ModelUnderTrainingRunner::createAndEnsureValid(
6504eeddc0SDimitry Andric     LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
6604eeddc0SDimitry Andric     const std::vector<TensorSpec> &InputSpecs,
67*81ad6265SDimitry Andric     const std::vector<LoggedFeatureSpec> &OutputSpecs) {
6804eeddc0SDimitry Andric   std::unique_ptr<ModelUnderTrainingRunner> MUTR;
69*81ad6265SDimitry Andric   MUTR.reset(
70*81ad6265SDimitry Andric       new ModelUnderTrainingRunner(Ctx, ModelPath, InputSpecs, OutputSpecs));
7104eeddc0SDimitry Andric   if (MUTR && MUTR->isValid())
7204eeddc0SDimitry Andric     return MUTR;
7304eeddc0SDimitry Andric 
74*81ad6265SDimitry Andric   Ctx.emitError("Could not load or create model evaluator.");
7504eeddc0SDimitry Andric   return nullptr;
7604eeddc0SDimitry Andric }
7704eeddc0SDimitry Andric 
780eae32dcSDimitry Andric #endif // defined(LLVM_HAVE_TF_API)
79