1*0eae32dcSDimitry Andric //===- ModelUnderTrainingRunner.cpp - 'development' mode runner -----------===// 2*0eae32dcSDimitry Andric // 3*0eae32dcSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*0eae32dcSDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5*0eae32dcSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*0eae32dcSDimitry Andric // 7*0eae32dcSDimitry Andric //===----------------------------------------------------------------------===// 8*0eae32dcSDimitry Andric // 9*0eae32dcSDimitry Andric // Implementation of a MLModelRunner for 'development' mode, i.e. evaluation 10*0eae32dcSDimitry Andric // happens off a model that's provided from the command line and is interpreted. 11*0eae32dcSDimitry Andric // 12*0eae32dcSDimitry Andric //===----------------------------------------------------------------------===// 13*0eae32dcSDimitry Andric 14*0eae32dcSDimitry Andric #include "llvm/Config/config.h" 15*0eae32dcSDimitry Andric #if defined(LLVM_HAVE_TF_API) 16*0eae32dcSDimitry Andric 17*0eae32dcSDimitry Andric #include "llvm/Analysis/ModelUnderTrainingRunner.h" 18*0eae32dcSDimitry Andric 19*0eae32dcSDimitry Andric using namespace llvm; 20*0eae32dcSDimitry Andric 21*0eae32dcSDimitry Andric ModelUnderTrainingRunner::ModelUnderTrainingRunner( 22*0eae32dcSDimitry Andric LLVMContext &Ctx, const std::string &ModelPath, 23*0eae32dcSDimitry Andric const std::vector<TensorSpec> &InputSpecs, 24*0eae32dcSDimitry Andric const std::vector<LoggedFeatureSpec> &OutputSpecs) 25*0eae32dcSDimitry Andric : MLModelRunner(Ctx), OutputSpecs(OutputSpecs) { 26*0eae32dcSDimitry Andric Evaluator = std::make_unique<TFModelEvaluator>( 27*0eae32dcSDimitry Andric ModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I].Spec; }, 28*0eae32dcSDimitry Andric OutputSpecs.size()); 29*0eae32dcSDimitry Andric if (!Evaluator || !Evaluator->isValid()) { 30*0eae32dcSDimitry Andric Ctx.emitError("Failed to create inliner saved model evaluator"); 31*0eae32dcSDimitry Andric Evaluator.reset(); 32*0eae32dcSDimitry Andric return; 33*0eae32dcSDimitry Andric } 34*0eae32dcSDimitry Andric } 35*0eae32dcSDimitry Andric 36*0eae32dcSDimitry Andric void *ModelUnderTrainingRunner::evaluateUntyped() { 37*0eae32dcSDimitry Andric LastEvaluationResult = Evaluator->evaluate(); 38*0eae32dcSDimitry Andric if (!LastEvaluationResult.hasValue()) { 39*0eae32dcSDimitry Andric Ctx.emitError("Error evaluating model."); 40*0eae32dcSDimitry Andric return nullptr; 41*0eae32dcSDimitry Andric } 42*0eae32dcSDimitry Andric return LastEvaluationResult->getUntypedTensorValue(0); 43*0eae32dcSDimitry Andric } 44*0eae32dcSDimitry Andric 45*0eae32dcSDimitry Andric void *ModelUnderTrainingRunner::getTensorUntyped(size_t Index) { 46*0eae32dcSDimitry Andric return Evaluator->getUntypedInput(Index); 47*0eae32dcSDimitry Andric } 48*0eae32dcSDimitry Andric 49*0eae32dcSDimitry Andric #endif // defined(LLVM_HAVE_TF_API) 50