xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/TFLiteUtils.cpp (revision 5f757f3ff9144b609b3c433dfd370cc6bdc191ad)
//===- TFLiteUtils.cpp - TFLite-based evaluation utilities ----------------===//
2bdd1243dSDimitry Andric //
3bdd1243dSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4bdd1243dSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5bdd1243dSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6bdd1243dSDimitry Andric //
7bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
8bdd1243dSDimitry Andric //
9*5f757f3fSDimitry Andric // This file implements utilities for interfacing with TFLite.
10bdd1243dSDimitry Andric //
11bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
12bdd1243dSDimitry Andric #include "llvm/Config/config.h"
13bdd1243dSDimitry Andric #if defined(LLVM_HAVE_TFLITE)
14bdd1243dSDimitry Andric 
15bdd1243dSDimitry Andric #include "llvm/ADT/Twine.h"
16bdd1243dSDimitry Andric #include "llvm/Analysis/Utils/TFUtils.h"
17bdd1243dSDimitry Andric #include "llvm/Support/Base64.h"
18bdd1243dSDimitry Andric #include "llvm/Support/CommandLine.h"
19bdd1243dSDimitry Andric #include "llvm/Support/Debug.h"
20bdd1243dSDimitry Andric #include "llvm/Support/JSON.h"
21bdd1243dSDimitry Andric #include "llvm/Support/MemoryBuffer.h"
22bdd1243dSDimitry Andric #include "llvm/Support/Path.h"
23bdd1243dSDimitry Andric #include "llvm/Support/raw_ostream.h"
24bdd1243dSDimitry Andric 
25bdd1243dSDimitry Andric #include "tensorflow/lite/interpreter.h"
26bdd1243dSDimitry Andric #include "tensorflow/lite/kernels/register.h"
27bdd1243dSDimitry Andric #include "tensorflow/lite/model.h"
28bdd1243dSDimitry Andric #include "tensorflow/lite/model_builder.h"
29bdd1243dSDimitry Andric #include "tensorflow/lite/op_resolver.h"
30bdd1243dSDimitry Andric #include "tensorflow/lite/logger.h"
31bdd1243dSDimitry Andric 
32bdd1243dSDimitry Andric #include <cassert>
33bdd1243dSDimitry Andric #include <numeric>
34bdd1243dSDimitry Andric #include <optional>
35bdd1243dSDimitry Andric 
36bdd1243dSDimitry Andric using namespace llvm;
37bdd1243dSDimitry Andric 
38bdd1243dSDimitry Andric namespace llvm {
39bdd1243dSDimitry Andric class EvaluationResultImpl {
40bdd1243dSDimitry Andric public:
41bdd1243dSDimitry Andric   EvaluationResultImpl(const std::vector<const TfLiteTensor *> &Outputs)
42bdd1243dSDimitry Andric       : Outputs(Outputs){};
43bdd1243dSDimitry Andric 
44bdd1243dSDimitry Andric   const TfLiteTensor *getOutput(size_t I) { return Outputs[I]; }
45bdd1243dSDimitry Andric 
46bdd1243dSDimitry Andric   EvaluationResultImpl(const EvaluationResultImpl &) = delete;
47bdd1243dSDimitry Andric   EvaluationResultImpl(EvaluationResultImpl &&Other) = delete;
48bdd1243dSDimitry Andric 
49bdd1243dSDimitry Andric private:
50bdd1243dSDimitry Andric   const std::vector<const TfLiteTensor *> Outputs;
51bdd1243dSDimitry Andric };
52bdd1243dSDimitry Andric 
53bdd1243dSDimitry Andric class TFModelEvaluatorImpl {
54bdd1243dSDimitry Andric public:
55bdd1243dSDimitry Andric   TFModelEvaluatorImpl(StringRef SavedModelPath,
56bdd1243dSDimitry Andric                        const std::vector<TensorSpec> &InputSpecs,
57bdd1243dSDimitry Andric                        const std::vector<TensorSpec> &OutputSpecs,
58bdd1243dSDimitry Andric                        const char *Tags);
59bdd1243dSDimitry Andric 
60bdd1243dSDimitry Andric   bool isValid() const { return IsValid; }
61bdd1243dSDimitry Andric   size_t outputSize() const { return Output.size(); }
62bdd1243dSDimitry Andric 
63bdd1243dSDimitry Andric   std::unique_ptr<EvaluationResultImpl> evaluate() {
64bdd1243dSDimitry Andric     Interpreter->Invoke();
65bdd1243dSDimitry Andric     return std::make_unique<EvaluationResultImpl>(Output);
66bdd1243dSDimitry Andric   }
67bdd1243dSDimitry Andric 
68bdd1243dSDimitry Andric   const std::vector<TfLiteTensor *> &getInput() const { return Input; }
69bdd1243dSDimitry Andric 
70bdd1243dSDimitry Andric   ~TFModelEvaluatorImpl();
71bdd1243dSDimitry Andric 
72bdd1243dSDimitry Andric private:
73bdd1243dSDimitry Andric   std::unique_ptr<tflite::FlatBufferModel> Model;
74bdd1243dSDimitry Andric 
75bdd1243dSDimitry Andric   /// The objects necessary for carrying out an evaluation of the SavedModel.
76bdd1243dSDimitry Andric   /// They are expensive to set up, and we maintain them accross all the
77bdd1243dSDimitry Andric   /// evaluations of the model.
78bdd1243dSDimitry Andric   std::unique_ptr<tflite::Interpreter> Interpreter;
79bdd1243dSDimitry Andric 
80bdd1243dSDimitry Andric   /// The input tensors. We set up the tensors once and just mutate theirs
81bdd1243dSDimitry Andric   /// scalars before each evaluation. The input tensors keep their value after
82bdd1243dSDimitry Andric   /// an evaluation.
83bdd1243dSDimitry Andric   std::vector<TfLiteTensor *> Input;
84bdd1243dSDimitry Andric 
85bdd1243dSDimitry Andric   /// The output nodes.
86bdd1243dSDimitry Andric   std::vector<const TfLiteTensor *> Output;
87bdd1243dSDimitry Andric 
88bdd1243dSDimitry Andric   void invalidate() { IsValid = false; }
89bdd1243dSDimitry Andric 
90bdd1243dSDimitry Andric   bool IsValid = true;
91bdd1243dSDimitry Andric 
92bdd1243dSDimitry Andric   /// Reusable utility for ensuring we can bind the requested Name to a node in
93bdd1243dSDimitry Andric   /// the SavedModel Graph.
94bdd1243dSDimitry Andric   bool checkReportAndInvalidate(const TfLiteTensor *Tensor,
95bdd1243dSDimitry Andric                                 const TensorSpec &Spec);
96bdd1243dSDimitry Andric };
97bdd1243dSDimitry Andric 
98bdd1243dSDimitry Andric } // namespace llvm
99bdd1243dSDimitry Andric 
100bdd1243dSDimitry Andric TFModelEvaluatorImpl::TFModelEvaluatorImpl(
101bdd1243dSDimitry Andric     StringRef SavedModelPath, const std::vector<TensorSpec> &InputSpecs,
102bdd1243dSDimitry Andric     const std::vector<TensorSpec> &OutputSpecs, const char *Tags = "serve")
103bdd1243dSDimitry Andric     : Input(InputSpecs.size()), Output(OutputSpecs.size()) {
104bdd1243dSDimitry Andric   // INFO and DEBUG messages could be numerous and not particularly interesting
105bdd1243dSDimitry Andric   tflite::LoggerOptions::SetMinimumLogSeverity(tflite::TFLITE_LOG_WARNING);
106bdd1243dSDimitry Andric   // FIXME: make ErrorReporter a member (may also need subclassing
107bdd1243dSDimitry Andric   // StatefulErrorReporter) to easily get the latest error status, for
108bdd1243dSDimitry Andric   // debugging.
109bdd1243dSDimitry Andric   tflite::StderrReporter ErrorReporter;
110bdd1243dSDimitry Andric   SmallVector<char, 128> TFLitePathBuff;
111bdd1243dSDimitry Andric   llvm::sys::path::append(TFLitePathBuff, SavedModelPath, "model.tflite");
112bdd1243dSDimitry Andric   StringRef TFLitePath(TFLitePathBuff.data(), TFLitePathBuff.size());
113bdd1243dSDimitry Andric   Model = tflite::FlatBufferModel::BuildFromFile(TFLitePath.str().c_str(),
114bdd1243dSDimitry Andric                                                  &ErrorReporter);
115bdd1243dSDimitry Andric   if (!Model) {
116bdd1243dSDimitry Andric     invalidate();
117bdd1243dSDimitry Andric     return;
118bdd1243dSDimitry Andric   }
119bdd1243dSDimitry Andric 
120bdd1243dSDimitry Andric   tflite::ops::builtin::BuiltinOpResolver Resolver;
121bdd1243dSDimitry Andric   tflite::InterpreterBuilder Builder(*Model, Resolver);
122bdd1243dSDimitry Andric   Builder(&Interpreter);
123bdd1243dSDimitry Andric 
124bdd1243dSDimitry Andric   if (!Interpreter) {
125bdd1243dSDimitry Andric     invalidate();
126bdd1243dSDimitry Andric     return;
127bdd1243dSDimitry Andric   }
128bdd1243dSDimitry Andric 
129bdd1243dSDimitry Andric   // We assume the input buffers are valid for the lifetime of the interpreter.
130bdd1243dSDimitry Andric   // By default, tflite allocates memory in an arena and will periodically take
131bdd1243dSDimitry Andric   // away memory and reallocate it in a different location after evaluations in
132bdd1243dSDimitry Andric   // order to improve utilization of the buffers owned in the arena. So, we
133bdd1243dSDimitry Andric   // explicitly mark our input buffers as persistent to avoid this behavior.
134bdd1243dSDimitry Andric   for (size_t I = 0; I < Interpreter->inputs().size(); ++I)
135bdd1243dSDimitry Andric     Interpreter->tensor(I)->allocation_type =
136bdd1243dSDimitry Andric         TfLiteAllocationType::kTfLiteArenaRwPersistent;
137bdd1243dSDimitry Andric 
138bdd1243dSDimitry Andric   if (Interpreter->AllocateTensors() != TfLiteStatus::kTfLiteOk) {
139bdd1243dSDimitry Andric     invalidate();
140bdd1243dSDimitry Andric     return;
141bdd1243dSDimitry Andric   }
142bdd1243dSDimitry Andric   // Known inputs and outputs
143bdd1243dSDimitry Andric   StringMap<int> InputsMap;
144bdd1243dSDimitry Andric   StringMap<int> OutputsMap;
145bdd1243dSDimitry Andric   for (size_t I = 0; I < Interpreter->inputs().size(); ++I)
146bdd1243dSDimitry Andric     InputsMap[Interpreter->GetInputName(I)] = I;
147bdd1243dSDimitry Andric   for (size_t I = 0; I < Interpreter->outputs().size(); ++I)
148bdd1243dSDimitry Andric     OutputsMap[Interpreter->GetOutputName(I)] = I;
149bdd1243dSDimitry Andric 
150bdd1243dSDimitry Andric   size_t NumberFeaturesPassed = 0;
151bdd1243dSDimitry Andric   for (size_t I = 0; I < InputSpecs.size(); ++I) {
152bdd1243dSDimitry Andric     auto &InputSpec = InputSpecs[I];
153bdd1243dSDimitry Andric     auto MapI = InputsMap.find(InputSpec.name() + ":" +
154bdd1243dSDimitry Andric                                std::to_string(InputSpec.port()));
155bdd1243dSDimitry Andric     if (MapI == InputsMap.end()) {
156bdd1243dSDimitry Andric       Input[I] = nullptr;
157bdd1243dSDimitry Andric       continue;
158bdd1243dSDimitry Andric     }
159bdd1243dSDimitry Andric     Input[I] = Interpreter->tensor(MapI->second);
160bdd1243dSDimitry Andric     if (!checkReportAndInvalidate(Input[I], InputSpec))
161bdd1243dSDimitry Andric       return;
162bdd1243dSDimitry Andric     std::memset(Input[I]->data.data, 0,
163bdd1243dSDimitry Andric                 InputSpecs[I].getTotalTensorBufferSize());
164bdd1243dSDimitry Andric     ++NumberFeaturesPassed;
165bdd1243dSDimitry Andric   }
166bdd1243dSDimitry Andric 
167bdd1243dSDimitry Andric   if (NumberFeaturesPassed < Interpreter->inputs().size()) {
168bdd1243dSDimitry Andric     // we haven't passed all the required features to the model, throw an error.
169bdd1243dSDimitry Andric     errs() << "Required feature(s) have not been passed to the ML model";
170bdd1243dSDimitry Andric     invalidate();
171bdd1243dSDimitry Andric     return;
172bdd1243dSDimitry Andric   }
173bdd1243dSDimitry Andric 
174bdd1243dSDimitry Andric   for (size_t I = 0; I < OutputSpecs.size(); ++I) {
175bdd1243dSDimitry Andric     const auto &OutputSpec = OutputSpecs[I];
176bdd1243dSDimitry Andric     Output[I] = Interpreter->output_tensor(
177bdd1243dSDimitry Andric         OutputsMap[OutputSpec.name() + ":" +
178bdd1243dSDimitry Andric                    std::to_string(OutputSpec.port())]);
179bdd1243dSDimitry Andric     if (!checkReportAndInvalidate(Output[I], OutputSpec))
180bdd1243dSDimitry Andric       return;
181bdd1243dSDimitry Andric   }
182bdd1243dSDimitry Andric }
183bdd1243dSDimitry Andric 
184bdd1243dSDimitry Andric TFModelEvaluator::TFModelEvaluator(StringRef SavedModelPath,
185bdd1243dSDimitry Andric                                    const std::vector<TensorSpec> &InputSpecs,
186bdd1243dSDimitry Andric                                    const std::vector<TensorSpec> &OutputSpecs,
187bdd1243dSDimitry Andric                                    const char *Tags)
188bdd1243dSDimitry Andric     : Impl(new TFModelEvaluatorImpl(SavedModelPath, InputSpecs, OutputSpecs,
189bdd1243dSDimitry Andric                                     Tags)) {
190bdd1243dSDimitry Andric   if (!Impl->isValid())
191bdd1243dSDimitry Andric     Impl.reset();
192bdd1243dSDimitry Andric }
193bdd1243dSDimitry Andric 
194bdd1243dSDimitry Andric TFModelEvaluatorImpl::~TFModelEvaluatorImpl() {}
195bdd1243dSDimitry Andric 
196bdd1243dSDimitry Andric bool TFModelEvaluatorImpl::checkReportAndInvalidate(const TfLiteTensor *Tensor,
197bdd1243dSDimitry Andric                                                     const TensorSpec &Spec) {
198bdd1243dSDimitry Andric   if (!Tensor) {
199bdd1243dSDimitry Andric     errs() << "Could not find TF_Output named: " + Spec.name();
200bdd1243dSDimitry Andric     IsValid = false;
201bdd1243dSDimitry Andric   }
202bdd1243dSDimitry Andric   if (Spec.getTotalTensorBufferSize() != Tensor->bytes)
203bdd1243dSDimitry Andric     IsValid = false;
204bdd1243dSDimitry Andric 
205bdd1243dSDimitry Andric   // If the total sizes match, there could still be a mismatch in the shape.
206bdd1243dSDimitry Andric   // We ignore that for now.
207bdd1243dSDimitry Andric 
208bdd1243dSDimitry Andric   return IsValid;
209bdd1243dSDimitry Andric }
210bdd1243dSDimitry Andric 
211bdd1243dSDimitry Andric std::optional<TFModelEvaluator::EvaluationResult> TFModelEvaluator::evaluate() {
212bdd1243dSDimitry Andric   if (!isValid())
213bdd1243dSDimitry Andric     return std::nullopt;
214bdd1243dSDimitry Andric   return EvaluationResult(Impl->evaluate());
215bdd1243dSDimitry Andric }
216bdd1243dSDimitry Andric 
217bdd1243dSDimitry Andric void *TFModelEvaluator::getUntypedInput(size_t Index) {
218bdd1243dSDimitry Andric   TfLiteTensor *T = Impl->getInput()[Index];
219bdd1243dSDimitry Andric   if (!T)
220bdd1243dSDimitry Andric     return nullptr;
221bdd1243dSDimitry Andric   return T->data.data;
222bdd1243dSDimitry Andric }
223bdd1243dSDimitry Andric 
224bdd1243dSDimitry Andric TFModelEvaluator::EvaluationResult::EvaluationResult(
225bdd1243dSDimitry Andric     std::unique_ptr<EvaluationResultImpl> Impl)
226bdd1243dSDimitry Andric     : Impl(std::move(Impl)) {}
227bdd1243dSDimitry Andric 
228bdd1243dSDimitry Andric TFModelEvaluator::EvaluationResult::EvaluationResult(EvaluationResult &&Other)
229bdd1243dSDimitry Andric     : Impl(std::move(Other.Impl)) {}
230bdd1243dSDimitry Andric 
231bdd1243dSDimitry Andric TFModelEvaluator::EvaluationResult &
232bdd1243dSDimitry Andric TFModelEvaluator::EvaluationResult::operator=(EvaluationResult &&Other) {
233bdd1243dSDimitry Andric   Impl = std::move(Other.Impl);
234bdd1243dSDimitry Andric   return *this;
235bdd1243dSDimitry Andric }
236bdd1243dSDimitry Andric 
237bdd1243dSDimitry Andric void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) {
238bdd1243dSDimitry Andric   return Impl->getOutput(Index)->data.data;
239bdd1243dSDimitry Andric }
240bdd1243dSDimitry Andric 
241bdd1243dSDimitry Andric const void *
242bdd1243dSDimitry Andric TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) const {
243bdd1243dSDimitry Andric   return Impl->getOutput(Index)->data.data;
244bdd1243dSDimitry Andric }
245bdd1243dSDimitry Andric 
246bdd1243dSDimitry Andric TFModelEvaluator::EvaluationResult::~EvaluationResult() {}
247bdd1243dSDimitry Andric TFModelEvaluator::~TFModelEvaluator() {}
248bdd1243dSDimitry Andric 
249bdd1243dSDimitry Andric #endif // defined(LLVM_HAVE_TFLITE)
250