1 //===- MLModelRunner.h ---- ML model runner interface -----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 10 #ifndef LLVM_ANALYSIS_MLMODELRUNNER_H 11 #define LLVM_ANALYSIS_MLMODELRUNNER_H 12 13 #include "llvm/Analysis/TensorSpec.h" 14 #include "llvm/IR/PassManager.h" 15 16 namespace llvm { 17 class LLVMContext; 18 19 /// MLModelRunner interface: abstraction of a mechanism for evaluating a 20 /// ML model. More abstractly, evaluating a function that has as tensors as 21 /// arguments, described via TensorSpecs, and returns a tensor. Currently, the 22 /// latter is assumed to be a scalar, in absence of more elaborate scenarios. 23 /// NOTE: feature indices are expected to be consistent all accross 24 /// MLModelRunners (pertaining to the same model), and also Loggers (see 25 /// TFUtils.h) 26 class MLModelRunner { 27 public: 28 // Disallows copy and assign. 29 MLModelRunner(const MLModelRunner &) = delete; 30 MLModelRunner &operator=(const MLModelRunner &) = delete; 31 virtual ~MLModelRunner() = default; 32 33 template <typename T> T evaluate() { 34 return *reinterpret_cast<T *>(evaluateUntyped()); 35 } 36 37 template <typename T, typename I> T *getTensor(I FeatureID) { 38 return reinterpret_cast<T *>( 39 getTensorUntyped(static_cast<size_t>(FeatureID))); 40 } 41 42 template <typename T, typename I> const T *getTensor(I FeatureID) const { 43 return reinterpret_cast<const T *>( 44 getTensorUntyped(static_cast<size_t>(FeatureID))); 45 } 46 47 void *getTensorUntyped(size_t Index) { return InputBuffers[Index]; } 48 const void *getTensorUntyped(size_t Index) const { 49 return (const_cast<MLModelRunner *>(this))->getTensorUntyped(Index); 50 } 51 52 enum class Kind : int { Unknown, Release, Development, NoOp, Interactive }; 53 Kind getKind() const { return Type; } 54 virtual void switchContext(StringRef Name) {} 55 56 protected: 57 MLModelRunner(LLVMContext &Ctx, Kind Type, size_t NrInputs) 58 : Ctx(Ctx), Type(Type), InputBuffers(NrInputs) { 59 assert(Type != Kind::Unknown); 60 } 61 virtual void *evaluateUntyped() = 0; 62 63 void setUpBufferForTensor(size_t Index, const TensorSpec &Spec, 64 void *Buffer) { 65 if (!Buffer) { 66 OwnedBuffers.emplace_back(Spec.getTotalTensorBufferSize()); 67 Buffer = OwnedBuffers.back().data(); 68 } 69 InputBuffers[Index] = Buffer; 70 } 71 72 LLVMContext &Ctx; 73 const Kind Type; 74 75 private: 76 std::vector<void *> InputBuffers; 77 std::vector<std::vector<char *>> OwnedBuffers; 78 }; 79 } // namespace llvm 80 81 #endif // LLVM_ANALYSIS_MLMODELRUNNER_H 82