10eae32dcSDimitry Andric //===- ModelUnderTrainingRunner.cpp - 'development' mode runner -----------===// 20eae32dcSDimitry Andric // 30eae32dcSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 40eae32dcSDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 50eae32dcSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 60eae32dcSDimitry Andric // 70eae32dcSDimitry Andric //===----------------------------------------------------------------------===// 80eae32dcSDimitry Andric // 90eae32dcSDimitry Andric // Implementation of a MLModelRunner for 'development' mode, i.e. evaluation 100eae32dcSDimitry Andric // happens off a model that's provided from the command line and is interpreted. 110eae32dcSDimitry Andric // 120eae32dcSDimitry Andric //===----------------------------------------------------------------------===// 130eae32dcSDimitry Andric 140eae32dcSDimitry Andric #include "llvm/Config/config.h" 150eae32dcSDimitry Andric #if defined(LLVM_HAVE_TF_API) 160eae32dcSDimitry Andric 170eae32dcSDimitry Andric #include "llvm/Analysis/ModelUnderTrainingRunner.h" 180eae32dcSDimitry Andric 190eae32dcSDimitry Andric using namespace llvm; 200eae32dcSDimitry Andric 210eae32dcSDimitry Andric ModelUnderTrainingRunner::ModelUnderTrainingRunner( 220eae32dcSDimitry Andric LLVMContext &Ctx, const std::string &ModelPath, 230eae32dcSDimitry Andric const std::vector<TensorSpec> &InputSpecs, 240eae32dcSDimitry Andric const std::vector<LoggedFeatureSpec> &OutputSpecs) 25*81ad6265SDimitry Andric : MLModelRunner(Ctx, MLModelRunner::Kind::Development, InputSpecs.size()), 2604eeddc0SDimitry Andric OutputSpecs(OutputSpecs) { 270eae32dcSDimitry Andric Evaluator = std::make_unique<TFModelEvaluator>( 280eae32dcSDimitry Andric ModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I].Spec; }, 290eae32dcSDimitry Andric OutputSpecs.size()); 300eae32dcSDimitry Andric if (!Evaluator || !Evaluator->isValid()) { 3104eeddc0SDimitry Andric Ctx.emitError("Failed to create saved model evaluator"); 320eae32dcSDimitry Andric Evaluator.reset(); 330eae32dcSDimitry Andric return; 340eae32dcSDimitry Andric } 35*81ad6265SDimitry Andric 36*81ad6265SDimitry Andric for (size_t I = 0, E = InputSpecs.size(); I < E; ++I) { 37*81ad6265SDimitry Andric setUpBufferForTensor(I, InputSpecs[I], Evaluator->getUntypedInput(I)); 38*81ad6265SDimitry Andric } 390eae32dcSDimitry Andric } 400eae32dcSDimitry Andric 410eae32dcSDimitry Andric void *ModelUnderTrainingRunner::evaluateUntyped() { 420eae32dcSDimitry Andric LastEvaluationResult = Evaluator->evaluate(); 430eae32dcSDimitry Andric if (!LastEvaluationResult.hasValue()) { 440eae32dcSDimitry Andric Ctx.emitError("Error evaluating model."); 450eae32dcSDimitry Andric return nullptr; 460eae32dcSDimitry Andric } 470eae32dcSDimitry Andric return LastEvaluationResult->getUntypedTensorValue(0); 480eae32dcSDimitry Andric } 490eae32dcSDimitry Andric 50*81ad6265SDimitry Andric std::unique_ptr<ModelUnderTrainingRunner> 51*81ad6265SDimitry Andric ModelUnderTrainingRunner::createAndEnsureValid( 52*81ad6265SDimitry Andric LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName, 53*81ad6265SDimitry Andric const std::vector<TensorSpec> &InputSpecs, 54*81ad6265SDimitry Andric StringRef OutputSpecsPathOverride) { 55*81ad6265SDimitry Andric if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath, 56*81ad6265SDimitry Andric OutputSpecsPathOverride)) 57*81ad6265SDimitry Andric return createAndEnsureValid(Ctx, ModelPath, DecisionName, InputSpecs, 58*81ad6265SDimitry Andric *MaybeOutputSpecs); 59*81ad6265SDimitry Andric Ctx.emitError("Could not load the policy model from the provided path"); 60*81ad6265SDimitry Andric return nullptr; 610eae32dcSDimitry Andric } 620eae32dcSDimitry Andric 6304eeddc0SDimitry Andric std::unique_ptr<ModelUnderTrainingRunner> 6404eeddc0SDimitry Andric ModelUnderTrainingRunner::createAndEnsureValid( 6504eeddc0SDimitry Andric LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName, 6604eeddc0SDimitry Andric const std::vector<TensorSpec> &InputSpecs, 67*81ad6265SDimitry Andric const std::vector<LoggedFeatureSpec> &OutputSpecs) { 6804eeddc0SDimitry Andric std::unique_ptr<ModelUnderTrainingRunner> MUTR; 69*81ad6265SDimitry Andric MUTR.reset( 70*81ad6265SDimitry Andric new ModelUnderTrainingRunner(Ctx, ModelPath, InputSpecs, OutputSpecs)); 7104eeddc0SDimitry Andric if (MUTR && MUTR->isValid()) 7204eeddc0SDimitry Andric return MUTR; 7304eeddc0SDimitry Andric 74*81ad6265SDimitry Andric Ctx.emitError("Could not load or create model evaluator."); 7504eeddc0SDimitry Andric return nullptr; 7604eeddc0SDimitry Andric } 7704eeddc0SDimitry Andric 780eae32dcSDimitry Andric #endif // defined(LLVM_HAVE_TF_API) 79