1 //===- ModelUnderTrainingRunner.h -- 'development' mode runner --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 10 #ifndef LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H 11 #define LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H 12 13 #include "llvm/ADT/STLExtras.h" 14 #include "llvm/ADT/iterator_range.h" 15 #include "llvm/Analysis/TensorSpec.h" 16 #include "llvm/Config/llvm-config.h" 17 18 #ifdef LLVM_HAVE_TFLITE 19 #include "llvm/Analysis/MLModelRunner.h" 20 #include "llvm/Analysis/Utils/TFUtils.h" 21 #include "llvm/IR/LLVMContext.h" 22 #include "llvm/IR/PassManager.h" 23 24 namespace llvm { 25 26 /// ModelUnderTrainingRunner - training mode implementation. It uses TFLite 27 /// to dynamically load and evaluate a TF SavedModel 28 /// (https://www.tensorflow.org/guide/saved_model) converted to TFLite. see 29 /// lib/Analysis/models/saved-model-to-tflite.py. Runtime performance is 30 /// sacrificed for ease of use while training. 31 class ModelUnderTrainingRunner final : public MLModelRunner { 32 public: 33 // Disallows copy and assign. 34 ModelUnderTrainingRunner(const ModelUnderTrainingRunner &) = delete; 35 ModelUnderTrainingRunner & 36 operator=(const ModelUnderTrainingRunner &) = delete; 37 38 const std::vector<TensorSpec> &extraOutputsForLoggingSpecs() const { 39 return ExtraOutputsForLogging; 40 } 41 42 const void *getUntypedExtraOutputValue(size_t ExtraOutputIndex) const { 43 return lastEvaluationResult()->getUntypedTensorValue(ExtraOutputIndex + 1); 44 } 45 46 const std::optional<TFModelEvaluator::EvaluationResult> & 47 lastEvaluationResult() const { 48 return LastEvaluationResult; 49 } 50 static bool classof(const MLModelRunner *R) { 51 return R->getKind() == MLModelRunner::Kind::Development; 52 } 53 54 static std::unique_ptr<ModelUnderTrainingRunner> 55 createAndEnsureValid(LLVMContext &Ctx, const std::string &ModelPath, 56 StringRef DecisionName, 57 const std::vector<TensorSpec> &InputSpecs, 58 StringRef OutputSpecsPathOverride = ""); 59 60 ModelUnderTrainingRunner( 61 LLVMContext &Ctx, const std::string &ModelPath, 62 const std::vector<TensorSpec> &InputSpecs, 63 const std::vector<TensorSpec> &OutputSpecs, 64 const std::vector<TensorSpec> &ExtraOutputsForLogging = {}); 65 66 bool isValid() const { return !!Evaluator; } 67 68 private: 69 std::unique_ptr<TFModelEvaluator> Evaluator; 70 const std::vector<TensorSpec> OutputSpecs; 71 const std::vector<TensorSpec> ExtraOutputsForLogging; 72 std::optional<TFModelEvaluator::EvaluationResult> LastEvaluationResult; 73 void *evaluateUntyped() override; 74 }; 75 76 } // namespace llvm 77 #endif // define(LLVM_HAVE_TFLITE) 78 #endif // LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H 79