1 //===- TFUtils.h - utilities for TFLite -------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 #ifndef LLVM_ANALYSIS_UTILS_TFUTILS_H 10 #define LLVM_ANALYSIS_UTILS_TFUTILS_H 11 12 #include "llvm/Config/llvm-config.h" 13 14 #ifdef LLVM_HAVE_TFLITE 15 #include "llvm/ADT/StringMap.h" 16 #include "llvm/Analysis/TensorSpec.h" 17 #include "llvm/IR/LLVMContext.h" 18 #include "llvm/Support/JSON.h" 19 20 #include <memory> 21 #include <vector> 22 23 namespace llvm { 24 25 /// Load a SavedModel, find the given inputs and outputs, and setup storage 26 /// for input tensors. The user is responsible for correctly dimensioning the 27 /// input tensors and setting their values before calling evaluate(). 28 /// To initialize: 29 /// - construct the object 30 /// - initialize the input tensors using initInput. Indices must correspond to 31 /// indices in the InputNames used at construction. 32 /// To use: 33 /// - set input values by using getInput to get each input tensor, and then 34 /// setting internal scalars, for all dimensions (tensors are row-major: 35 /// https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/c/c_api.h#L205) 36 /// - call evaluate. The input tensors' values are not consumed after this, and 37 /// may still be read. 38 /// - use the outputs in the output vector 39 class TFModelEvaluatorImpl; 40 class EvaluationResultImpl; 41 42 class TFModelEvaluator final { 43 public: 44 /// The result of a model evaluation. Handles the lifetime of the output 45 /// tensors, which means that their values need to be used before 46 /// the EvaluationResult's dtor is called. 47 class EvaluationResult { 48 public: 49 EvaluationResult(const EvaluationResult &) = delete; 50 EvaluationResult &operator=(const EvaluationResult &Other) = delete; 51 52 EvaluationResult(EvaluationResult &&Other); 53 EvaluationResult &operator=(EvaluationResult &&Other); 54 55 ~EvaluationResult(); 56 57 /// Get a (const) pointer to the first element of the tensor at Index. 58 template <typename T> T *getTensorValue(size_t Index) { 59 return static_cast<T *>(getUntypedTensorValue(Index)); 60 } 61 62 template <typename T> const T *getTensorValue(size_t Index) const { 63 return static_cast<T *>(getUntypedTensorValue(Index)); 64 } 65 66 /// Get a (const) pointer to the untyped data of the tensor. 67 void *getUntypedTensorValue(size_t Index); 68 const void *getUntypedTensorValue(size_t Index) const; 69 70 private: 71 friend class TFModelEvaluator; 72 EvaluationResult(std::unique_ptr<EvaluationResultImpl> Impl); 73 std::unique_ptr<EvaluationResultImpl> Impl; 74 }; 75 76 TFModelEvaluator(StringRef SavedModelPath, 77 const std::vector<TensorSpec> &InputSpecs, 78 const std::vector<TensorSpec> &OutputSpecs, 79 const char *Tags = "serve"); 80 81 ~TFModelEvaluator(); 82 TFModelEvaluator(const TFModelEvaluator &) = delete; 83 TFModelEvaluator(TFModelEvaluator &&) = delete; 84 85 /// Evaluate the model, assuming it is valid. Returns std::nullopt if the 86 /// evaluation fails or the model is invalid, or an EvaluationResult 87 /// otherwise. The inputs are assumed to have been already provided via 88 /// getInput(). When returning std::nullopt, it also invalidates this object. 89 std::optional<EvaluationResult> evaluate(); 90 91 /// Provides access to the input vector. 92 template <typename T> T *getInput(size_t Index) { 93 return static_cast<T *>(getUntypedInput(Index)); 94 } 95 96 /// Returns true if the model was loaded successfully, false 97 /// otherwise. 98 bool isValid() const { return !!Impl; } 99 100 /// Untyped access to input. 101 void *getUntypedInput(size_t Index); 102 103 private: 104 std::unique_ptr<TFModelEvaluatorImpl> Impl; 105 }; 106 107 } // namespace llvm 108 109 #endif // LLVM_HAVE_TFLITE 110 #endif // LLVM_ANALYSIS_UTILS_TFUTILS_H 111