//===- ReleaseModeModelRunner.h - Fast, precompiled model runner ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a model runner wrapping an AOT compiled ML model.
// Only inference is supported.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_ANALYSIS_RELEASEMODEMODELRUNNER_H
#define LLVM_ANALYSIS_RELEASEMODEMODELRUNNER_H

#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/MLModelRunner.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MD5.h"

#include <cassert>
#include <cstdint>
#include <memory>

namespace llvm {

/// ReleaseModeModelRunner - production mode implementation of the
/// MLModelRunner. It uses an AOT-compiled SavedModel for efficient execution.
29 struct EmbeddedModelRunnerOptions {
30 /// Feed and Fetch feature prefixes - i.e. a feature named "foo" will be
31 /// looked up as {FeedPrefix}_foo; and the output named "bar" will be looked
32 /// up as {FetchPrefix}_bar
33 StringRef FeedPrefix = "feed_";
34 StringRef FetchPrefix = "fetch_";
35
36 /// ModelSelector is the name (recognized by the AOT-ed model) of a sub-model
37 /// to use. "" is allowed if the model doesn't support sub-models.
38 StringRef ModelSelector = "";
39
setFeedPrefixEmbeddedModelRunnerOptions40 EmbeddedModelRunnerOptions &setFeedPrefix(StringRef Value) {
41 FeedPrefix = Value;
42 return *this;
43 }
setFetchPrefixEmbeddedModelRunnerOptions44 EmbeddedModelRunnerOptions &setFetchPrefix(StringRef Value) {
45 FetchPrefix = Value;
46 return *this;
47 }
setModelSelectorEmbeddedModelRunnerOptions48 EmbeddedModelRunnerOptions &setModelSelector(StringRef Value) {
49 ModelSelector = Value;
50 return *this;
51 }
52 };
53
54 template <class TGen>
55 class ReleaseModeModelRunner final : public MLModelRunner {
56 public:
57 /// FeatureNames' type should be an indexed collection of std::string, like
58 /// std::array or std::vector, that has a size() method.
59 template <class FType>
60 ReleaseModeModelRunner(LLVMContext &Ctx, const FType &InputSpec,
61 StringRef DecisionName,
62 const EmbeddedModelRunnerOptions &Options = {})
63 : MLModelRunner(Ctx, MLModelRunner::Kind::Release, InputSpec.size() + 1),
64 CompiledModel(std::make_unique<TGen>()) {
65 assert(CompiledModel && "The CompiledModel should be valid");
66 // Set up the model_selector past all the InputSpecs in all cases.
67 // - if the model doesn't have such a feature, but the user requested it,
68 // we report error. Same if the model supports it but the user didn't
69 // specify it
70 // - finally, we compute the MD5 hash of the user input and set the value
71 // of the model selector to {high, low}
72 bool InputIsPresent = true;
73 populateTensor(InputSpec.size(),
74 TensorSpec::createSpec<uint64_t>("model_selector", {2}),
75 Options.FeedPrefix, InputIsPresent);
76
77 // If we hit the "report an error" cases outlined above, continue with the
78 // set up in case there's some custom diagnostics handler installed and it
79 // doesn't promptly exit.
80 if (Options.ModelSelector.empty() && InputIsPresent)
81 Ctx.emitError(
82 "A model selector was not specified but the underlying model "
83 "requires selecting one because it exposes a model_selector input");
84 uint64_t High = 0;
85 uint64_t Low = 0;
86 if (!Options.ModelSelector.empty()) {
87 if (!InputIsPresent)
88 Ctx.emitError("A model selector was specified but the underlying model "
89 "does not expose a model_selector input");
90 const auto Hash = MD5::hash(arrayRefFromStringRef(Options.ModelSelector));
91 High = Hash.high();
92 Low = Hash.low();
93 }
94 getTensor<uint64_t>(InputSpec.size())[0] = High;
95 getTensor<uint64_t>(InputSpec.size())[1] = Low;
96 // At this point, the model selector is set up. If the user didn't provide
97 // one, but the model has a model_selector, it'll be set to (0, 0) which
98 // the composite model should treat as error as part of its implementation
99 // (but that should only matter if there is a custom handler that doesn't
100 // exit on error)
101 for (size_t I = 0; I < InputSpec.size(); ++I)
102 populateTensor(I, InputSpec[I], Options.FeedPrefix, InputIsPresent);
103
104 ResultIndex = CompiledModel->LookupResultIndex(Options.FetchPrefix.str() +
105 DecisionName.str());
106 assert(ResultIndex >= 0 && "Cannot find DecisionName in inlining model");
107 }
108
109 virtual ~ReleaseModeModelRunner() = default;
110
classof(const MLModelRunner * R)111 static bool classof(const MLModelRunner *R) {
112 return R->getKind() == MLModelRunner::Kind::Release;
113 }
114
115 private:
116 // fetch the model-provided buffer for the given Spec, or let MLModelRunner
117 // create a scratch buffer. Indicate back to the caller if the model had that
118 // input in the first place.
populateTensor(size_t Pos,const TensorSpec & Spec,StringRef Prefix,bool & InputIsPresent)119 void populateTensor(size_t Pos, const TensorSpec &Spec, StringRef Prefix,
120 bool &InputIsPresent) {
121 const int Index =
122 CompiledModel->LookupArgIndex((Prefix + Spec.name()).str());
123 void *Buffer = nullptr;
124 InputIsPresent = Index >= 0;
125 if (InputIsPresent)
126 Buffer = CompiledModel->arg_data(Index);
127 setUpBufferForTensor(Pos, Spec, Buffer);
128 }
129
evaluateUntyped()130 void *evaluateUntyped() override {
131 CompiledModel->Run();
132 return CompiledModel->result_data(ResultIndex);
133 }
134
135 int32_t ResultIndex = -1;
136 std::unique_ptr<TGen> CompiledModel;
137 };
138
/// A mock class satisfying the interface expected by ReleaseModeModelRunner for
/// its `TGen` parameter. Useful to avoid conditional compilation complexity, as
/// a compile-time replacement for a real AOT-ed model.
142 class NoopSavedModelImpl final {
143 #define NOOP_MODEL_ERRMSG \
144 "The mock AOT-ed saved model is a compile-time stub and should not be " \
145 "called."
146
147 public:
148 NoopSavedModelImpl() = default;
LookupArgIndex(const std::string &)149 int LookupArgIndex(const std::string &) { llvm_unreachable(NOOP_MODEL_ERRMSG); }
LookupResultIndex(const std::string &)150 int LookupResultIndex(const std::string &) { llvm_unreachable(NOOP_MODEL_ERRMSG); }
Run()151 void Run() { llvm_unreachable(NOOP_MODEL_ERRMSG); }
result_data(int)152 void *result_data(int) { llvm_unreachable(NOOP_MODEL_ERRMSG); }
arg_data(int)153 void *arg_data(int) { llvm_unreachable(NOOP_MODEL_ERRMSG); }
154 #undef NOOP_MODEL_ERRMSG
155 };
156
isEmbeddedModelEvaluatorValid()157 template <class T> bool isEmbeddedModelEvaluatorValid() { return true; }
158
159 template <> inline bool isEmbeddedModelEvaluatorValid<NoopSavedModelImpl>() {
160 return false;
161 }
} // namespace llvm

#endif // LLVM_ANALYSIS_RELEASEMODEMODELRUNNER_H
165