1 //===- ModelUnderTrainingRunner.cpp - 'development' mode runner -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implementation of a MLModelRunner for 'development' mode, i.e. evaluation
10 // happens off a model that's provided from the command line and is interpreted.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "llvm/ADT/STLExtras.h"
15 #include "llvm/Config/config.h"
16 #if defined(LLVM_HAVE_TFLITE)
#include "llvm/Analysis/ModelUnderTrainingRunner.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
#include <optional>
21
22 using namespace llvm;
23 namespace {
24 struct LoggedFeatureSpec {
25 TensorSpec Spec;
26 std::optional<std::string> LoggingName;
27 };
28
29 std::optional<std::vector<LoggedFeatureSpec>>
loadOutputSpecs(LLVMContext & Ctx,StringRef ExpectedDecisionName,StringRef ModelPath,StringRef SpecFileOverride)30 loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName,
31 StringRef ModelPath, StringRef SpecFileOverride) {
32 SmallVector<char, 128> OutputSpecsPath;
33 StringRef FileName = SpecFileOverride;
34 if (FileName.empty()) {
35 llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json");
36 FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()};
37 }
38
39 auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName);
40 if (!BufferOrError) {
41 Ctx.emitError("Error opening output specs file: " + FileName + " : " +
42 BufferOrError.getError().message());
43 return std::nullopt;
44 }
45 auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer());
46 if (!ParsedJSONValues) {
47 Ctx.emitError("Could not parse specs file: " + FileName);
48 return std::nullopt;
49 }
50 auto ValuesArray = ParsedJSONValues->getAsArray();
51 if (!ValuesArray) {
52 Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, "
53 "logging_name:<name>} dictionaries");
54 return std::nullopt;
55 }
56 std::vector<LoggedFeatureSpec> Ret;
57 for (const auto &Value : *ValuesArray)
58 if (const auto *Obj = Value.getAsObject())
59 if (const auto *SpecPart = Obj->get("tensor_spec"))
60 if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart))
61 if (auto LoggingName = Obj->getString("logging_name")) {
62 if (!TensorSpec->isElementType<int64_t>() &&
63 !TensorSpec->isElementType<int32_t>() &&
64 !TensorSpec->isElementType<float>()) {
65 Ctx.emitError(
66 "Only int64, int32, and float tensors are supported. "
67 "Found unsupported type for tensor named " +
68 TensorSpec->name());
69 return std::nullopt;
70 }
71 Ret.push_back({*TensorSpec, LoggingName->str()});
72 }
73
74 if (ValuesArray->size() != Ret.size()) {
75 Ctx.emitError(
76 "Unable to parse output spec. It should be a json file containing an "
77 "array of dictionaries. Each dictionary must have a 'tensor_spec' key, "
78 "with a json object describing a TensorSpec; and a 'logging_name' key, "
79 "which is a string to use as name when logging this tensor in the "
80 "training log.");
81 return std::nullopt;
82 }
83 if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) {
84 Ctx.emitError("The first output spec must describe the decision tensor, "
85 "and must have the logging_name " +
86 StringRef(ExpectedDecisionName));
87 return std::nullopt;
88 }
89 return Ret;
90 }
91 } // namespace
92
ModelUnderTrainingRunner(LLVMContext & Ctx,const std::string & ModelPath,const std::vector<TensorSpec> & InputSpecs,const std::vector<TensorSpec> & OutputSpecs,const std::vector<TensorSpec> & ExtraOutputsForLogging)93 ModelUnderTrainingRunner::ModelUnderTrainingRunner(
94 LLVMContext &Ctx, const std::string &ModelPath,
95 const std::vector<TensorSpec> &InputSpecs,
96 const std::vector<TensorSpec> &OutputSpecs,
97 const std::vector<TensorSpec> &ExtraOutputsForLogging)
98 : MLModelRunner(Ctx, MLModelRunner::Kind::Development, InputSpecs.size()),
99 OutputSpecs(OutputSpecs), ExtraOutputsForLogging(ExtraOutputsForLogging) {
100 Evaluator =
101 std::make_unique<TFModelEvaluator>(ModelPath, InputSpecs, OutputSpecs);
102 if (!Evaluator || !Evaluator->isValid()) {
103 Ctx.emitError("Failed to create saved model evaluator");
104 Evaluator.reset();
105 return;
106 }
107
108 for (size_t I = 0, E = InputSpecs.size(); I < E; ++I) {
109 setUpBufferForTensor(I, InputSpecs[I], Evaluator->getUntypedInput(I));
110 }
111 }
112
evaluateUntyped()113 void *ModelUnderTrainingRunner::evaluateUntyped() {
114 LastEvaluationResult = Evaluator->evaluate();
115 if (!LastEvaluationResult.has_value()) {
116 Ctx.emitError("Error evaluating model.");
117 return nullptr;
118 }
119 return LastEvaluationResult->getUntypedTensorValue(0);
120 }
121
122 std::unique_ptr<ModelUnderTrainingRunner>
createAndEnsureValid(LLVMContext & Ctx,const std::string & ModelPath,StringRef DecisionName,const std::vector<TensorSpec> & InputSpecs,StringRef OutputSpecsPathOverride)123 ModelUnderTrainingRunner::createAndEnsureValid(
124 LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
125 const std::vector<TensorSpec> &InputSpecs,
126 StringRef OutputSpecsPathOverride) {
127 if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath,
128 OutputSpecsPathOverride)) {
129 std::unique_ptr<ModelUnderTrainingRunner> MUTR;
130 std::vector<TensorSpec> OutputSpecs;
131 std::vector<TensorSpec> ExtraOutputsForLogging;
132 append_range(OutputSpecs,
133 map_range(*MaybeOutputSpecs, [](const LoggedFeatureSpec &LFS) {
134 return LFS.Spec;
135 }));
136 append_range(ExtraOutputsForLogging,
137 map_range(drop_begin(*MaybeOutputSpecs),
138 [](const LoggedFeatureSpec &LFS) {
139 return TensorSpec(LFS.LoggingName
140 ? *LFS.LoggingName
141 : LFS.Spec.name(),
142 LFS.Spec);
143 }));
144
145 MUTR.reset(new ModelUnderTrainingRunner(
146 Ctx, ModelPath, InputSpecs, OutputSpecs, ExtraOutputsForLogging));
147 if (MUTR && MUTR->isValid())
148 return MUTR;
149
150 Ctx.emitError("Could not load or create model evaluator.");
151 return nullptr;
152 }
153 Ctx.emitError("Could not load the policy model from the provided path");
154 return nullptr;
155 }
156
157 #endif // defined(LLVM_HAVE_TFLITE)
158