xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/ModelUnderTrainingRunner.cpp (revision 0eae32dcef82f6f06de6419a0d623d7def0cc8f6)
1*0eae32dcSDimitry Andric //===- ModelUnderTrainingRunner.cpp - 'development' mode runner -----------===//
2*0eae32dcSDimitry Andric //
3*0eae32dcSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*0eae32dcSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*0eae32dcSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*0eae32dcSDimitry Andric //
7*0eae32dcSDimitry Andric //===----------------------------------------------------------------------===//
8*0eae32dcSDimitry Andric //
// Implementation of an MLModelRunner for 'development' mode, i.e. evaluation
// is performed by a model that is provided on the command line and interpreted.
11*0eae32dcSDimitry Andric //
12*0eae32dcSDimitry Andric //===----------------------------------------------------------------------===//
13*0eae32dcSDimitry Andric 
14*0eae32dcSDimitry Andric #include "llvm/Config/config.h"
15*0eae32dcSDimitry Andric #if defined(LLVM_HAVE_TF_API)
16*0eae32dcSDimitry Andric 
17*0eae32dcSDimitry Andric #include "llvm/Analysis/ModelUnderTrainingRunner.h"
18*0eae32dcSDimitry Andric 
19*0eae32dcSDimitry Andric using namespace llvm;
20*0eae32dcSDimitry Andric 
21*0eae32dcSDimitry Andric ModelUnderTrainingRunner::ModelUnderTrainingRunner(
22*0eae32dcSDimitry Andric     LLVMContext &Ctx, const std::string &ModelPath,
23*0eae32dcSDimitry Andric     const std::vector<TensorSpec> &InputSpecs,
24*0eae32dcSDimitry Andric     const std::vector<LoggedFeatureSpec> &OutputSpecs)
25*0eae32dcSDimitry Andric     : MLModelRunner(Ctx), OutputSpecs(OutputSpecs) {
26*0eae32dcSDimitry Andric   Evaluator = std::make_unique<TFModelEvaluator>(
27*0eae32dcSDimitry Andric       ModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I].Spec; },
28*0eae32dcSDimitry Andric       OutputSpecs.size());
29*0eae32dcSDimitry Andric   if (!Evaluator || !Evaluator->isValid()) {
30*0eae32dcSDimitry Andric     Ctx.emitError("Failed to create inliner saved model evaluator");
31*0eae32dcSDimitry Andric     Evaluator.reset();
32*0eae32dcSDimitry Andric     return;
33*0eae32dcSDimitry Andric   }
34*0eae32dcSDimitry Andric }
35*0eae32dcSDimitry Andric 
36*0eae32dcSDimitry Andric void *ModelUnderTrainingRunner::evaluateUntyped() {
37*0eae32dcSDimitry Andric   LastEvaluationResult = Evaluator->evaluate();
38*0eae32dcSDimitry Andric   if (!LastEvaluationResult.hasValue()) {
39*0eae32dcSDimitry Andric     Ctx.emitError("Error evaluating model.");
40*0eae32dcSDimitry Andric     return nullptr;
41*0eae32dcSDimitry Andric   }
42*0eae32dcSDimitry Andric   return LastEvaluationResult->getUntypedTensorValue(0);
43*0eae32dcSDimitry Andric }
44*0eae32dcSDimitry Andric 
45*0eae32dcSDimitry Andric void *ModelUnderTrainingRunner::getTensorUntyped(size_t Index) {
46*0eae32dcSDimitry Andric   return Evaluator->getUntypedInput(Index);
47*0eae32dcSDimitry Andric }
48*0eae32dcSDimitry Andric 
49*0eae32dcSDimitry Andric #endif // defined(LLVM_HAVE_TF_API)
50