15f757f3fSDimitry Andric //===- MLRegAllocPriorityAdvisor.cpp - ML priority advisor-----------------===// 25f757f3fSDimitry Andric // 35f757f3fSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 45f757f3fSDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 55f757f3fSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 65f757f3fSDimitry Andric // 75f757f3fSDimitry Andric //===----------------------------------------------------------------------===// 85f757f3fSDimitry Andric // 95f757f3fSDimitry Andric // Implementation of the ML priority advisor and reward injection pass 105f757f3fSDimitry Andric // 115f757f3fSDimitry Andric //===----------------------------------------------------------------------===// 125f757f3fSDimitry Andric 135f757f3fSDimitry Andric #include "AllocationOrder.h" 145f757f3fSDimitry Andric #include "RegAllocGreedy.h" 155f757f3fSDimitry Andric #include "RegAllocPriorityAdvisor.h" 165f757f3fSDimitry Andric #include "llvm/Analysis/AliasAnalysis.h" 175f757f3fSDimitry Andric #include "llvm/Analysis/InteractiveModelRunner.h" 185f757f3fSDimitry Andric #include "llvm/Analysis/MLModelRunner.h" 195f757f3fSDimitry Andric #include "llvm/Analysis/ReleaseModeModelRunner.h" 205f757f3fSDimitry Andric #include "llvm/Analysis/TensorSpec.h" 215f757f3fSDimitry Andric #include "llvm/CodeGen/CalcSpillWeights.h" 225f757f3fSDimitry Andric #include "llvm/CodeGen/LiveRegMatrix.h" 235f757f3fSDimitry Andric #include "llvm/CodeGen/MachineBlockFrequencyInfo.h" 245f757f3fSDimitry Andric #include "llvm/CodeGen/MachineFunction.h" 255f757f3fSDimitry Andric #include "llvm/CodeGen/MachineLoopInfo.h" 265f757f3fSDimitry Andric #include "llvm/CodeGen/MachineRegisterInfo.h" 275f757f3fSDimitry Andric #include "llvm/CodeGen/Passes.h" 285f757f3fSDimitry Andric #include "llvm/CodeGen/RegisterClassInfo.h" 295f757f3fSDimitry Andric #include "llvm/CodeGen/SlotIndexes.h" 305f757f3fSDimitry Andric 
#include "llvm/CodeGen/VirtRegMap.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include "llvm/Support/CommandLine.h"

#ifdef LLVM_HAVE_TFLITE
#include "llvm/Analysis/ModelUnderTrainingRunner.h"
#include "llvm/Analysis/NoInferenceModelRunner.h"
#include "llvm/Analysis/Utils/TrainingLogger.h"
#include "llvm/IR/Module.h"
#endif

using namespace llvm;

/// Base path of the file pair used by the release-mode advisor's interactive
/// runner. Left empty, the embedded model runner is used instead.
static cl::opt<std::string> InteractiveChannelBaseName(
    "regalloc-priority-interactive-channel-base",
    cl::desc("Base file path for the interactive mode. The incoming filename "
             "should have the name "
             "<regalloc-priority-interactive-channel-base>.in, while the "
             "outgoing name should be "
             "<regalloc-priority-interactive-channel-base>.out"),
    cl::Hidden);

/// Embedded-model type the release-mode (AOT) runner is instantiated with.
using CompiledModelType = NoopSavedModelImpl;

// Development-mode-only headers and options (require TFLite).
#ifdef LLVM_HAVE_TFLITE
#include "RegAllocScore.h"
#include "llvm/Analysis/Utils/TFUtils.h"

/// Path of the training log written in development mode.
static cl::opt<std::string> TrainingLog(
    "regalloc-priority-training-log",
    cl::desc("Training log for the register allocator priority model"),
    cl::Hidden);

/// Path of the model under training evaluated in development mode.
static cl::opt<std::string> ModelUnderTraining(
    "regalloc-priority-model",
    cl::desc("The model being trained for register allocation priority"),
    cl::Hidden);
#endif // LLVM_HAVE_TFLITE

namespace llvm {

/// Shape shared by every feature: one scalar per live range.
static const std::vector<int64_t> PerLiveRangeShape = {1};

// X-macro enumerating the model's input features. Each entry reads
//   M(element type, feature name, tensor shape, short description)
// and is expanded with different M's to build both the feature-index enum
// and the tensor specs, so the two always agree on order.
#define RA_PRIORITY_FEATURES_LIST(M)                                           \
  M(int64_t, li_size, PerLiveRangeShape, "size")                               \
  M(int64_t, stage, PerLiveRangeShape, "stage")                                \
  M(float, weight, PerLiveRangeShape, "weight")

// The model's single output: a scalar float named "priority".
#define DecisionName "priority"
static const TensorSpec DecisionSpec(
    TensorSpec::createSpec<float>(DecisionName, {1}));

835f757f3fSDimitry Andric 845f757f3fSDimitry Andric // Named features index. 855f757f3fSDimitry Andric enum FeatureIDs { 865f757f3fSDimitry Andric #define _FEATURE_IDX(_, name, __, ___) name, 875f757f3fSDimitry Andric RA_PRIORITY_FEATURES_LIST(_FEATURE_IDX) 885f757f3fSDimitry Andric #undef _FEATURE_IDX 895f757f3fSDimitry Andric FeatureCount 905f757f3fSDimitry Andric }; 915f757f3fSDimitry Andric 925f757f3fSDimitry Andric class MLPriorityAdvisor : public RegAllocPriorityAdvisor { 935f757f3fSDimitry Andric public: 945f757f3fSDimitry Andric MLPriorityAdvisor(const MachineFunction &MF, const RAGreedy &RA, 955f757f3fSDimitry Andric SlotIndexes *const Indexes, MLModelRunner *Runner); 965f757f3fSDimitry Andric 975f757f3fSDimitry Andric protected: 985f757f3fSDimitry Andric const RegAllocPriorityAdvisor &getDefaultAdvisor() const { 995f757f3fSDimitry Andric return static_cast<const RegAllocPriorityAdvisor &>(DefaultAdvisor); 1005f757f3fSDimitry Andric } 1015f757f3fSDimitry Andric 1025f757f3fSDimitry Andric // The assumption is that if the Runner could not be constructed, we emit-ed 1035f757f3fSDimitry Andric // error, and we shouldn't be asking for it here. 
1045f757f3fSDimitry Andric const MLModelRunner &getRunner() const { return *Runner; } 1055f757f3fSDimitry Andric float getPriorityImpl(const LiveInterval &LI) const; 1065f757f3fSDimitry Andric unsigned getPriority(const LiveInterval &LI) const override; 1075f757f3fSDimitry Andric 1085f757f3fSDimitry Andric private: 1095f757f3fSDimitry Andric const DefaultPriorityAdvisor DefaultAdvisor; 1105f757f3fSDimitry Andric MLModelRunner *const Runner; 1115f757f3fSDimitry Andric }; 1125f757f3fSDimitry Andric 1135f757f3fSDimitry Andric #define _DECL_FEATURES(type, name, shape, _) \ 1145f757f3fSDimitry Andric TensorSpec::createSpec<type>(#name, shape), 1155f757f3fSDimitry Andric 1165f757f3fSDimitry Andric static const std::vector<TensorSpec> InputFeatures{ 1175f757f3fSDimitry Andric {RA_PRIORITY_FEATURES_LIST(_DECL_FEATURES)}, 1185f757f3fSDimitry Andric }; 1195f757f3fSDimitry Andric #undef _DECL_FEATURES 1205f757f3fSDimitry Andric 1215f757f3fSDimitry Andric // =================================== 1225f757f3fSDimitry Andric // Release (AOT) - specifics 1235f757f3fSDimitry Andric // =================================== 1245f757f3fSDimitry Andric class ReleaseModePriorityAdvisorAnalysis final 1255f757f3fSDimitry Andric : public RegAllocPriorityAdvisorAnalysis { 1265f757f3fSDimitry Andric public: 1275f757f3fSDimitry Andric ReleaseModePriorityAdvisorAnalysis() 1285f757f3fSDimitry Andric : RegAllocPriorityAdvisorAnalysis(AdvisorMode::Release) {} 1295f757f3fSDimitry Andric // support for isa<> and dyn_cast. 
1305f757f3fSDimitry Andric static bool classof(const RegAllocPriorityAdvisorAnalysis *R) { 1315f757f3fSDimitry Andric return R->getAdvisorMode() == AdvisorMode::Release; 1325f757f3fSDimitry Andric } 1335f757f3fSDimitry Andric 1345f757f3fSDimitry Andric private: 1355f757f3fSDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 1365f757f3fSDimitry Andric AU.setPreservesAll(); 137*0fca6ea1SDimitry Andric AU.addRequired<SlotIndexesWrapperPass>(); 1385f757f3fSDimitry Andric RegAllocPriorityAdvisorAnalysis::getAnalysisUsage(AU); 1395f757f3fSDimitry Andric } 1405f757f3fSDimitry Andric 1415f757f3fSDimitry Andric std::unique_ptr<RegAllocPriorityAdvisor> 1425f757f3fSDimitry Andric getAdvisor(const MachineFunction &MF, const RAGreedy &RA) override { 1435f757f3fSDimitry Andric if (!Runner) { 1445f757f3fSDimitry Andric if (InteractiveChannelBaseName.empty()) 1455f757f3fSDimitry Andric Runner = std::make_unique<ReleaseModeModelRunner<CompiledModelType>>( 1465f757f3fSDimitry Andric MF.getFunction().getContext(), InputFeatures, DecisionName); 1475f757f3fSDimitry Andric else 1485f757f3fSDimitry Andric Runner = std::make_unique<InteractiveModelRunner>( 1495f757f3fSDimitry Andric MF.getFunction().getContext(), InputFeatures, DecisionSpec, 1505f757f3fSDimitry Andric InteractiveChannelBaseName + ".out", 1515f757f3fSDimitry Andric InteractiveChannelBaseName + ".in"); 1525f757f3fSDimitry Andric } 1535f757f3fSDimitry Andric return std::make_unique<MLPriorityAdvisor>( 154*0fca6ea1SDimitry Andric MF, RA, &getAnalysis<SlotIndexesWrapperPass>().getSI(), Runner.get()); 1555f757f3fSDimitry Andric } 1565f757f3fSDimitry Andric std::unique_ptr<MLModelRunner> Runner; 1575f757f3fSDimitry Andric }; 1585f757f3fSDimitry Andric 1595f757f3fSDimitry Andric // =================================== 1605f757f3fSDimitry Andric // Development mode-specifics 1615f757f3fSDimitry Andric // =================================== 1625f757f3fSDimitry Andric // 1635f757f3fSDimitry Andric // Features we 
log 1645f757f3fSDimitry Andric #ifdef LLVM_HAVE_TFLITE 1655f757f3fSDimitry Andric static const TensorSpec Reward = TensorSpec::createSpec<float>("reward", {1}); 1665f757f3fSDimitry Andric 1675f757f3fSDimitry Andric #define _DECL_TRAIN_FEATURES(type, name, shape, _) \ 1685f757f3fSDimitry Andric TensorSpec::createSpec<type>(std::string("action_") + #name, shape), 1695f757f3fSDimitry Andric 1705f757f3fSDimitry Andric static const std::vector<TensorSpec> TrainingInputFeatures{ 1715f757f3fSDimitry Andric {RA_PRIORITY_FEATURES_LIST(_DECL_TRAIN_FEATURES) 1725f757f3fSDimitry Andric TensorSpec::createSpec<float>("action_discount", {1}), 1735f757f3fSDimitry Andric TensorSpec::createSpec<int32_t>("action_step_type", {1}), 1745f757f3fSDimitry Andric TensorSpec::createSpec<float>("action_reward", {1})}}; 1755f757f3fSDimitry Andric #undef _DECL_TRAIN_FEATURES 1765f757f3fSDimitry Andric 1775f757f3fSDimitry Andric class DevelopmentModePriorityAdvisor : public MLPriorityAdvisor { 1785f757f3fSDimitry Andric public: 1795f757f3fSDimitry Andric DevelopmentModePriorityAdvisor(const MachineFunction &MF, const RAGreedy &RA, 1805f757f3fSDimitry Andric SlotIndexes *const Indexes, 1815f757f3fSDimitry Andric MLModelRunner *Runner, Logger *Log) 1825f757f3fSDimitry Andric : MLPriorityAdvisor(MF, RA, Indexes, Runner), Log(Log) {} 1835f757f3fSDimitry Andric 1845f757f3fSDimitry Andric private: 1855f757f3fSDimitry Andric unsigned getPriority(const LiveInterval &LI) const override; 1865f757f3fSDimitry Andric Logger *const Log; 1875f757f3fSDimitry Andric }; 1885f757f3fSDimitry Andric 1895f757f3fSDimitry Andric class DevelopmentModePriorityAdvisorAnalysis final 1905f757f3fSDimitry Andric : public RegAllocPriorityAdvisorAnalysis { 1915f757f3fSDimitry Andric public: 1925f757f3fSDimitry Andric DevelopmentModePriorityAdvisorAnalysis() 1935f757f3fSDimitry Andric : RegAllocPriorityAdvisorAnalysis(AdvisorMode::Development) {} 1945f757f3fSDimitry Andric // support for isa<> and dyn_cast. 
1955f757f3fSDimitry Andric static bool classof(const RegAllocPriorityAdvisorAnalysis *R) { 1965f757f3fSDimitry Andric return R->getAdvisorMode() == AdvisorMode::Development; 1975f757f3fSDimitry Andric } 1985f757f3fSDimitry Andric 1995f757f3fSDimitry Andric void logRewardIfNeeded(const MachineFunction &MF, 2005f757f3fSDimitry Andric llvm::function_ref<float()> GetReward) override { 2015f757f3fSDimitry Andric if (!Log || !Log->hasAnyObservationForContext(MF.getName())) 2025f757f3fSDimitry Andric return; 2035f757f3fSDimitry Andric // The function pass manager would run all the function passes for a 2045f757f3fSDimitry Andric // function, so we assume the last context belongs to this function. If 2055f757f3fSDimitry Andric // this invariant ever changes, we can implement at that time switching 2065f757f3fSDimitry Andric // contexts. At this point, it'd be an error 2075f757f3fSDimitry Andric if (Log->currentContext() != MF.getName()) { 2085f757f3fSDimitry Andric MF.getFunction().getContext().emitError( 2095f757f3fSDimitry Andric "The training log context shouldn't have had changed."); 2105f757f3fSDimitry Andric } 2115f757f3fSDimitry Andric if (Log->hasObservationInProgress()) 2125f757f3fSDimitry Andric Log->logReward<float>(GetReward()); 2135f757f3fSDimitry Andric } 2145f757f3fSDimitry Andric 2155f757f3fSDimitry Andric private: 2165f757f3fSDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 2175f757f3fSDimitry Andric AU.setPreservesAll(); 218*0fca6ea1SDimitry Andric AU.addRequired<SlotIndexesWrapperPass>(); 2195f757f3fSDimitry Andric RegAllocPriorityAdvisorAnalysis::getAnalysisUsage(AU); 2205f757f3fSDimitry Andric } 2215f757f3fSDimitry Andric 2225f757f3fSDimitry Andric // Save all the logs (when requested). 
2235f757f3fSDimitry Andric bool doInitialization(Module &M) override { 2245f757f3fSDimitry Andric LLVMContext &Ctx = M.getContext(); 2255f757f3fSDimitry Andric if (ModelUnderTraining.empty() && TrainingLog.empty()) { 2265f757f3fSDimitry Andric Ctx.emitError("Regalloc development mode should be requested with at " 2275f757f3fSDimitry Andric "least logging enabled and/or a training model"); 2285f757f3fSDimitry Andric return false; 2295f757f3fSDimitry Andric } 2305f757f3fSDimitry Andric if (ModelUnderTraining.empty()) 2315f757f3fSDimitry Andric Runner = std::make_unique<NoInferenceModelRunner>(Ctx, InputFeatures); 2325f757f3fSDimitry Andric else 2335f757f3fSDimitry Andric Runner = ModelUnderTrainingRunner::createAndEnsureValid( 2345f757f3fSDimitry Andric Ctx, ModelUnderTraining, DecisionName, TrainingInputFeatures); 2355f757f3fSDimitry Andric if (!Runner) { 2365f757f3fSDimitry Andric Ctx.emitError("Regalloc: could not set up the model runner"); 2375f757f3fSDimitry Andric return false; 2385f757f3fSDimitry Andric } 2395f757f3fSDimitry Andric if (TrainingLog.empty()) 2405f757f3fSDimitry Andric return false; 2415f757f3fSDimitry Andric std::error_code EC; 2425f757f3fSDimitry Andric auto OS = std::make_unique<raw_fd_ostream>(TrainingLog, EC); 2435f757f3fSDimitry Andric if (EC) { 2445f757f3fSDimitry Andric M.getContext().emitError(EC.message() + ":" + TrainingLog); 2455f757f3fSDimitry Andric return false; 2465f757f3fSDimitry Andric } 2475f757f3fSDimitry Andric std::vector<TensorSpec> LFS = InputFeatures; 2485f757f3fSDimitry Andric if (auto *MUTR = dyn_cast<ModelUnderTrainingRunner>(Runner.get())) 2495f757f3fSDimitry Andric append_range(LFS, MUTR->extraOutputsForLoggingSpecs()); 2505f757f3fSDimitry Andric // We always log the output; in particular, if we're not evaluating, we 2515f757f3fSDimitry Andric // don't have an output spec json file. That's why we handle the 2525f757f3fSDimitry Andric // 'normal' output separately. 
2535f757f3fSDimitry Andric LFS.push_back(DecisionSpec); 2545f757f3fSDimitry Andric 2555f757f3fSDimitry Andric Log = std::make_unique<Logger>(std::move(OS), LFS, Reward, 2565f757f3fSDimitry Andric /*IncludeReward*/ true); 2575f757f3fSDimitry Andric return false; 2585f757f3fSDimitry Andric } 2595f757f3fSDimitry Andric 2605f757f3fSDimitry Andric std::unique_ptr<RegAllocPriorityAdvisor> 2615f757f3fSDimitry Andric getAdvisor(const MachineFunction &MF, const RAGreedy &RA) override { 2625f757f3fSDimitry Andric if (!Runner) 2635f757f3fSDimitry Andric return nullptr; 2645f757f3fSDimitry Andric if (Log) { 2655f757f3fSDimitry Andric Log->switchContext(MF.getName()); 2665f757f3fSDimitry Andric } 2675f757f3fSDimitry Andric 2685f757f3fSDimitry Andric return std::make_unique<DevelopmentModePriorityAdvisor>( 269*0fca6ea1SDimitry Andric MF, RA, &getAnalysis<SlotIndexesWrapperPass>().getSI(), Runner.get(), 270*0fca6ea1SDimitry Andric Log.get()); 2715f757f3fSDimitry Andric } 2725f757f3fSDimitry Andric 2735f757f3fSDimitry Andric std::unique_ptr<MLModelRunner> Runner; 2745f757f3fSDimitry Andric std::unique_ptr<Logger> Log; 2755f757f3fSDimitry Andric }; 2765f757f3fSDimitry Andric #endif //#ifdef LLVM_HAVE_TFLITE 2775f757f3fSDimitry Andric 2785f757f3fSDimitry Andric } // namespace llvm 2795f757f3fSDimitry Andric 2805f757f3fSDimitry Andric RegAllocPriorityAdvisorAnalysis *llvm::createReleaseModePriorityAdvisor() { 2815f757f3fSDimitry Andric return llvm::isEmbeddedModelEvaluatorValid<CompiledModelType>() || 2825f757f3fSDimitry Andric !InteractiveChannelBaseName.empty() 2835f757f3fSDimitry Andric ? 
new ReleaseModePriorityAdvisorAnalysis() 2845f757f3fSDimitry Andric : nullptr; 2855f757f3fSDimitry Andric } 2865f757f3fSDimitry Andric 2875f757f3fSDimitry Andric MLPriorityAdvisor::MLPriorityAdvisor(const MachineFunction &MF, 2885f757f3fSDimitry Andric const RAGreedy &RA, 2895f757f3fSDimitry Andric SlotIndexes *const Indexes, 2905f757f3fSDimitry Andric MLModelRunner *Runner) 2915f757f3fSDimitry Andric : RegAllocPriorityAdvisor(MF, RA, Indexes), DefaultAdvisor(MF, RA, Indexes), 2925f757f3fSDimitry Andric Runner(std::move(Runner)) { 2935f757f3fSDimitry Andric assert(this->Runner); 2945f757f3fSDimitry Andric Runner->switchContext(MF.getName()); 2955f757f3fSDimitry Andric } 2965f757f3fSDimitry Andric 2975f757f3fSDimitry Andric float MLPriorityAdvisor::getPriorityImpl(const LiveInterval &LI) const { 2985f757f3fSDimitry Andric const unsigned Size = LI.getSize(); 2995f757f3fSDimitry Andric LiveRangeStage Stage = RA.getExtraInfo().getStage(LI); 3005f757f3fSDimitry Andric 3015f757f3fSDimitry Andric *Runner->getTensor<int64_t>(0) = static_cast<int64_t>(Size); 3025f757f3fSDimitry Andric *Runner->getTensor<int64_t>(1) = static_cast<int64_t>(Stage); 3035f757f3fSDimitry Andric *Runner->getTensor<float>(2) = static_cast<float>(LI.weight()); 3045f757f3fSDimitry Andric 3055f757f3fSDimitry Andric return Runner->evaluate<float>(); 3065f757f3fSDimitry Andric } 3075f757f3fSDimitry Andric 3085f757f3fSDimitry Andric unsigned MLPriorityAdvisor::getPriority(const LiveInterval &LI) const { 3095f757f3fSDimitry Andric return static_cast<unsigned>(getPriorityImpl(LI)); 3105f757f3fSDimitry Andric } 3115f757f3fSDimitry Andric 3125f757f3fSDimitry Andric #ifdef LLVM_HAVE_TFLITE 3135f757f3fSDimitry Andric RegAllocPriorityAdvisorAnalysis *llvm::createDevelopmentModePriorityAdvisor() { 3145f757f3fSDimitry Andric return new DevelopmentModePriorityAdvisorAnalysis(); 3155f757f3fSDimitry Andric } 3165f757f3fSDimitry Andric 3175f757f3fSDimitry Andric unsigned 3185f757f3fSDimitry Andric 
DevelopmentModePriorityAdvisor::getPriority(const LiveInterval &LI) const { 3195f757f3fSDimitry Andric double Prio = 0; 3205f757f3fSDimitry Andric 3215f757f3fSDimitry Andric if (isa<ModelUnderTrainingRunner>(getRunner())) { 3225f757f3fSDimitry Andric Prio = MLPriorityAdvisor::getPriorityImpl(LI); 3235f757f3fSDimitry Andric } else { 3245f757f3fSDimitry Andric Prio = getDefaultAdvisor().getPriority(LI); 3255f757f3fSDimitry Andric } 3265f757f3fSDimitry Andric 3275f757f3fSDimitry Andric if (TrainingLog.empty()) 3285f757f3fSDimitry Andric return Prio; 3295f757f3fSDimitry Andric 3305f757f3fSDimitry Andric // TODO(mtrofin): when we support optional rewards, this can go away. In the 3315f757f3fSDimitry Andric // meantime, we log the "pretend" reward (0) for the previous observation 3325f757f3fSDimitry Andric // before starting a new one. 3335f757f3fSDimitry Andric if (Log->hasObservationInProgress()) 3345f757f3fSDimitry Andric Log->logReward<float>(0.0); 3355f757f3fSDimitry Andric 3365f757f3fSDimitry Andric Log->startObservation(); 3375f757f3fSDimitry Andric size_t CurrentFeature = 0; 3385f757f3fSDimitry Andric for (; CurrentFeature < InputFeatures.size(); ++CurrentFeature) { 3395f757f3fSDimitry Andric Log->logTensorValue(CurrentFeature, 3405f757f3fSDimitry Andric reinterpret_cast<const char *>( 3415f757f3fSDimitry Andric getRunner().getTensorUntyped(CurrentFeature))); 3425f757f3fSDimitry Andric } 3435f757f3fSDimitry Andric 3445f757f3fSDimitry Andric if (auto *MUTR = dyn_cast<ModelUnderTrainingRunner>(&getRunner())) { 3455f757f3fSDimitry Andric for (size_t I = 0; I < MUTR->extraOutputsForLoggingSpecs().size(); 3465f757f3fSDimitry Andric ++I, ++CurrentFeature) 3475f757f3fSDimitry Andric Log->logTensorValue( 3485f757f3fSDimitry Andric CurrentFeature, 3495f757f3fSDimitry Andric reinterpret_cast<const char *>(MUTR->getUntypedExtraOutputValue(I))); 3505f757f3fSDimitry Andric } 3515f757f3fSDimitry Andric 3525f757f3fSDimitry Andric float Ret = static_cast<float>(Prio); 
3535f757f3fSDimitry Andric Log->logTensorValue(CurrentFeature, reinterpret_cast<const char *>(&Ret)); 3545f757f3fSDimitry Andric Log->endObservation(); 3555f757f3fSDimitry Andric 3565f757f3fSDimitry Andric return static_cast<unsigned>(Prio); 3575f757f3fSDimitry Andric } 3585f757f3fSDimitry Andric 3595f757f3fSDimitry Andric #endif // #ifdef LLVM_HAVE_TFLITE 360