xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/MLRegAllocPriorityAdvisor.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
15f757f3fSDimitry Andric //===- MLRegAllocPriorityAdvisor.cpp - ML priority advisor-----------------===//
25f757f3fSDimitry Andric //
35f757f3fSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
45f757f3fSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
55f757f3fSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65f757f3fSDimitry Andric //
75f757f3fSDimitry Andric //===----------------------------------------------------------------------===//
85f757f3fSDimitry Andric //
95f757f3fSDimitry Andric // Implementation of the ML priority advisor and reward injection pass
105f757f3fSDimitry Andric //
115f757f3fSDimitry Andric //===----------------------------------------------------------------------===//
125f757f3fSDimitry Andric 
135f757f3fSDimitry Andric #include "AllocationOrder.h"
145f757f3fSDimitry Andric #include "RegAllocGreedy.h"
155f757f3fSDimitry Andric #include "RegAllocPriorityAdvisor.h"
165f757f3fSDimitry Andric #include "llvm/Analysis/AliasAnalysis.h"
175f757f3fSDimitry Andric #include "llvm/Analysis/InteractiveModelRunner.h"
185f757f3fSDimitry Andric #include "llvm/Analysis/MLModelRunner.h"
195f757f3fSDimitry Andric #include "llvm/Analysis/ReleaseModeModelRunner.h"
205f757f3fSDimitry Andric #include "llvm/Analysis/TensorSpec.h"
215f757f3fSDimitry Andric #include "llvm/CodeGen/CalcSpillWeights.h"
225f757f3fSDimitry Andric #include "llvm/CodeGen/LiveRegMatrix.h"
235f757f3fSDimitry Andric #include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
245f757f3fSDimitry Andric #include "llvm/CodeGen/MachineFunction.h"
255f757f3fSDimitry Andric #include "llvm/CodeGen/MachineLoopInfo.h"
265f757f3fSDimitry Andric #include "llvm/CodeGen/MachineRegisterInfo.h"
275f757f3fSDimitry Andric #include "llvm/CodeGen/Passes.h"
285f757f3fSDimitry Andric #include "llvm/CodeGen/RegisterClassInfo.h"
295f757f3fSDimitry Andric #include "llvm/CodeGen/SlotIndexes.h"
305f757f3fSDimitry Andric #include "llvm/CodeGen/VirtRegMap.h"
315f757f3fSDimitry Andric #include "llvm/InitializePasses.h"
325f757f3fSDimitry Andric #include "llvm/Pass.h"
335f757f3fSDimitry Andric #include "llvm/PassRegistry.h"
345f757f3fSDimitry Andric #include "llvm/Support/CommandLine.h"
355f757f3fSDimitry Andric 
365f757f3fSDimitry Andric #if defined(LLVM_HAVE_TFLITE)
375f757f3fSDimitry Andric #include "llvm/Analysis/ModelUnderTrainingRunner.h"
385f757f3fSDimitry Andric #include "llvm/Analysis/NoInferenceModelRunner.h"
395f757f3fSDimitry Andric #include "llvm/Analysis/Utils/TrainingLogger.h"
40*0fca6ea1SDimitry Andric #include "llvm/IR/Module.h"
415f757f3fSDimitry Andric #endif
425f757f3fSDimitry Andric 
435f757f3fSDimitry Andric using namespace llvm;
445f757f3fSDimitry Andric 
455f757f3fSDimitry Andric static cl::opt<std::string> InteractiveChannelBaseName(
465f757f3fSDimitry Andric     "regalloc-priority-interactive-channel-base", cl::Hidden,
475f757f3fSDimitry Andric     cl::desc(
485f757f3fSDimitry Andric         "Base file path for the interactive mode. The incoming filename should "
495f757f3fSDimitry Andric         "have the name <regalloc-priority-interactive-channel-base>.in, while "
505f757f3fSDimitry Andric         "the outgoing name should be "
515f757f3fSDimitry Andric         "<regalloc-priority-interactive-channel-base>.out"));
525f757f3fSDimitry Andric 
535f757f3fSDimitry Andric using CompiledModelType = NoopSavedModelImpl;
545f757f3fSDimitry Andric 
// Options that only make sense in development mode (TFLite-enabled builds).
#ifdef LLVM_HAVE_TFLITE
#include "RegAllocScore.h"
#include "llvm/Analysis/Utils/TFUtils.h"

// Output path for the training log; logging is disabled when empty.
static cl::opt<std::string> TrainingLog(
    "regalloc-priority-training-log",
    cl::desc("Training log for the register allocator priority model"),
    cl::Hidden);

// Path to the model under training; when empty, no model is evaluated.
static cl::opt<std::string> ModelUnderTraining(
    "regalloc-priority-model",
    cl::desc("The model being trained for register allocation priority"),
    cl::Hidden);

#endif // #ifdef LLVM_HAVE_TFLITE
695f757f3fSDimitry Andric 
705f757f3fSDimitry Andric namespace llvm {
715f757f3fSDimitry Andric 
725f757f3fSDimitry Andric static const std::vector<int64_t> PerLiveRangeShape{1};
735f757f3fSDimitry Andric 
745f757f3fSDimitry Andric #define RA_PRIORITY_FEATURES_LIST(M)                                           \
755f757f3fSDimitry Andric   M(int64_t, li_size, PerLiveRangeShape, "size")                               \
765f757f3fSDimitry Andric   M(int64_t, stage, PerLiveRangeShape, "stage")                                \
775f757f3fSDimitry Andric   M(float, weight, PerLiveRangeShape, "weight")
785f757f3fSDimitry Andric 
795f757f3fSDimitry Andric #define DecisionName "priority"
805f757f3fSDimitry Andric static const TensorSpec DecisionSpec =
815f757f3fSDimitry Andric     TensorSpec::createSpec<float>(DecisionName, {1});
825f757f3fSDimitry Andric 
835f757f3fSDimitry Andric 
845f757f3fSDimitry Andric // Named features index.
855f757f3fSDimitry Andric enum FeatureIDs {
865f757f3fSDimitry Andric #define _FEATURE_IDX(_, name, __, ___) name,
875f757f3fSDimitry Andric   RA_PRIORITY_FEATURES_LIST(_FEATURE_IDX)
885f757f3fSDimitry Andric #undef _FEATURE_IDX
895f757f3fSDimitry Andric       FeatureCount
905f757f3fSDimitry Andric };
915f757f3fSDimitry Andric 
925f757f3fSDimitry Andric class MLPriorityAdvisor : public RegAllocPriorityAdvisor {
935f757f3fSDimitry Andric public:
945f757f3fSDimitry Andric   MLPriorityAdvisor(const MachineFunction &MF, const RAGreedy &RA,
955f757f3fSDimitry Andric                     SlotIndexes *const Indexes, MLModelRunner *Runner);
965f757f3fSDimitry Andric 
975f757f3fSDimitry Andric protected:
985f757f3fSDimitry Andric   const RegAllocPriorityAdvisor &getDefaultAdvisor() const {
995f757f3fSDimitry Andric     return static_cast<const RegAllocPriorityAdvisor &>(DefaultAdvisor);
1005f757f3fSDimitry Andric   }
1015f757f3fSDimitry Andric 
1025f757f3fSDimitry Andric   // The assumption is that if the Runner could not be constructed, we emit-ed
1035f757f3fSDimitry Andric   // error, and we shouldn't be asking for it here.
1045f757f3fSDimitry Andric   const MLModelRunner &getRunner() const { return *Runner; }
1055f757f3fSDimitry Andric   float getPriorityImpl(const LiveInterval &LI) const;
1065f757f3fSDimitry Andric   unsigned getPriority(const LiveInterval &LI) const override;
1075f757f3fSDimitry Andric 
1085f757f3fSDimitry Andric private:
1095f757f3fSDimitry Andric   const DefaultPriorityAdvisor DefaultAdvisor;
1105f757f3fSDimitry Andric   MLModelRunner *const Runner;
1115f757f3fSDimitry Andric };
1125f757f3fSDimitry Andric 
1135f757f3fSDimitry Andric #define _DECL_FEATURES(type, name, shape, _)                                   \
1145f757f3fSDimitry Andric   TensorSpec::createSpec<type>(#name, shape),
1155f757f3fSDimitry Andric 
1165f757f3fSDimitry Andric static const std::vector<TensorSpec> InputFeatures{
1175f757f3fSDimitry Andric     {RA_PRIORITY_FEATURES_LIST(_DECL_FEATURES)},
1185f757f3fSDimitry Andric };
1195f757f3fSDimitry Andric #undef _DECL_FEATURES
1205f757f3fSDimitry Andric 
1215f757f3fSDimitry Andric // ===================================
1225f757f3fSDimitry Andric // Release (AOT) - specifics
1235f757f3fSDimitry Andric // ===================================
1245f757f3fSDimitry Andric class ReleaseModePriorityAdvisorAnalysis final
1255f757f3fSDimitry Andric     : public RegAllocPriorityAdvisorAnalysis {
1265f757f3fSDimitry Andric public:
1275f757f3fSDimitry Andric   ReleaseModePriorityAdvisorAnalysis()
1285f757f3fSDimitry Andric       : RegAllocPriorityAdvisorAnalysis(AdvisorMode::Release) {}
1295f757f3fSDimitry Andric   // support for isa<> and dyn_cast.
1305f757f3fSDimitry Andric   static bool classof(const RegAllocPriorityAdvisorAnalysis *R) {
1315f757f3fSDimitry Andric     return R->getAdvisorMode() == AdvisorMode::Release;
1325f757f3fSDimitry Andric   }
1335f757f3fSDimitry Andric 
1345f757f3fSDimitry Andric private:
1355f757f3fSDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
1365f757f3fSDimitry Andric     AU.setPreservesAll();
137*0fca6ea1SDimitry Andric     AU.addRequired<SlotIndexesWrapperPass>();
1385f757f3fSDimitry Andric     RegAllocPriorityAdvisorAnalysis::getAnalysisUsage(AU);
1395f757f3fSDimitry Andric   }
1405f757f3fSDimitry Andric 
1415f757f3fSDimitry Andric   std::unique_ptr<RegAllocPriorityAdvisor>
1425f757f3fSDimitry Andric   getAdvisor(const MachineFunction &MF, const RAGreedy &RA) override {
1435f757f3fSDimitry Andric     if (!Runner) {
1445f757f3fSDimitry Andric       if (InteractiveChannelBaseName.empty())
1455f757f3fSDimitry Andric         Runner = std::make_unique<ReleaseModeModelRunner<CompiledModelType>>(
1465f757f3fSDimitry Andric             MF.getFunction().getContext(), InputFeatures, DecisionName);
1475f757f3fSDimitry Andric       else
1485f757f3fSDimitry Andric         Runner = std::make_unique<InteractiveModelRunner>(
1495f757f3fSDimitry Andric             MF.getFunction().getContext(), InputFeatures, DecisionSpec,
1505f757f3fSDimitry Andric             InteractiveChannelBaseName + ".out",
1515f757f3fSDimitry Andric             InteractiveChannelBaseName + ".in");
1525f757f3fSDimitry Andric     }
1535f757f3fSDimitry Andric     return std::make_unique<MLPriorityAdvisor>(
154*0fca6ea1SDimitry Andric         MF, RA, &getAnalysis<SlotIndexesWrapperPass>().getSI(), Runner.get());
1555f757f3fSDimitry Andric   }
1565f757f3fSDimitry Andric   std::unique_ptr<MLModelRunner> Runner;
1575f757f3fSDimitry Andric };
1585f757f3fSDimitry Andric 
1595f757f3fSDimitry Andric // ===================================
1605f757f3fSDimitry Andric // Development mode-specifics
1615f757f3fSDimitry Andric // ===================================
1625f757f3fSDimitry Andric //
1635f757f3fSDimitry Andric // Features we log
1645f757f3fSDimitry Andric #ifdef LLVM_HAVE_TFLITE
1655f757f3fSDimitry Andric static const TensorSpec Reward = TensorSpec::createSpec<float>("reward", {1});
1665f757f3fSDimitry Andric 
1675f757f3fSDimitry Andric #define _DECL_TRAIN_FEATURES(type, name, shape, _)                             \
1685f757f3fSDimitry Andric   TensorSpec::createSpec<type>(std::string("action_") + #name, shape),
1695f757f3fSDimitry Andric 
1705f757f3fSDimitry Andric static const std::vector<TensorSpec> TrainingInputFeatures{
1715f757f3fSDimitry Andric     {RA_PRIORITY_FEATURES_LIST(_DECL_TRAIN_FEATURES)
1725f757f3fSDimitry Andric          TensorSpec::createSpec<float>("action_discount", {1}),
1735f757f3fSDimitry Andric      TensorSpec::createSpec<int32_t>("action_step_type", {1}),
1745f757f3fSDimitry Andric      TensorSpec::createSpec<float>("action_reward", {1})}};
1755f757f3fSDimitry Andric #undef _DECL_TRAIN_FEATURES
1765f757f3fSDimitry Andric 
1775f757f3fSDimitry Andric class DevelopmentModePriorityAdvisor : public MLPriorityAdvisor {
1785f757f3fSDimitry Andric public:
1795f757f3fSDimitry Andric   DevelopmentModePriorityAdvisor(const MachineFunction &MF, const RAGreedy &RA,
1805f757f3fSDimitry Andric                                  SlotIndexes *const Indexes,
1815f757f3fSDimitry Andric                                  MLModelRunner *Runner, Logger *Log)
1825f757f3fSDimitry Andric       : MLPriorityAdvisor(MF, RA, Indexes, Runner), Log(Log) {}
1835f757f3fSDimitry Andric 
1845f757f3fSDimitry Andric private:
1855f757f3fSDimitry Andric   unsigned getPriority(const LiveInterval &LI) const override;
1865f757f3fSDimitry Andric   Logger *const Log;
1875f757f3fSDimitry Andric };
1885f757f3fSDimitry Andric 
1895f757f3fSDimitry Andric class DevelopmentModePriorityAdvisorAnalysis final
1905f757f3fSDimitry Andric     : public RegAllocPriorityAdvisorAnalysis {
1915f757f3fSDimitry Andric public:
1925f757f3fSDimitry Andric   DevelopmentModePriorityAdvisorAnalysis()
1935f757f3fSDimitry Andric       : RegAllocPriorityAdvisorAnalysis(AdvisorMode::Development) {}
1945f757f3fSDimitry Andric   // support for isa<> and dyn_cast.
1955f757f3fSDimitry Andric   static bool classof(const RegAllocPriorityAdvisorAnalysis *R) {
1965f757f3fSDimitry Andric     return R->getAdvisorMode() == AdvisorMode::Development;
1975f757f3fSDimitry Andric   }
1985f757f3fSDimitry Andric 
1995f757f3fSDimitry Andric   void logRewardIfNeeded(const MachineFunction &MF,
2005f757f3fSDimitry Andric                          llvm::function_ref<float()> GetReward) override {
2015f757f3fSDimitry Andric     if (!Log || !Log->hasAnyObservationForContext(MF.getName()))
2025f757f3fSDimitry Andric       return;
2035f757f3fSDimitry Andric     // The function pass manager would run all the function passes for a
2045f757f3fSDimitry Andric     // function, so we assume the last context belongs to this function. If
2055f757f3fSDimitry Andric     // this invariant ever changes, we can implement at that time switching
2065f757f3fSDimitry Andric     // contexts. At this point, it'd be an error
2075f757f3fSDimitry Andric     if (Log->currentContext() != MF.getName()) {
2085f757f3fSDimitry Andric       MF.getFunction().getContext().emitError(
2095f757f3fSDimitry Andric           "The training log context shouldn't have had changed.");
2105f757f3fSDimitry Andric     }
2115f757f3fSDimitry Andric     if (Log->hasObservationInProgress())
2125f757f3fSDimitry Andric       Log->logReward<float>(GetReward());
2135f757f3fSDimitry Andric   }
2145f757f3fSDimitry Andric 
2155f757f3fSDimitry Andric private:
2165f757f3fSDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
2175f757f3fSDimitry Andric     AU.setPreservesAll();
218*0fca6ea1SDimitry Andric     AU.addRequired<SlotIndexesWrapperPass>();
2195f757f3fSDimitry Andric     RegAllocPriorityAdvisorAnalysis::getAnalysisUsage(AU);
2205f757f3fSDimitry Andric   }
2215f757f3fSDimitry Andric 
2225f757f3fSDimitry Andric   // Save all the logs (when requested).
2235f757f3fSDimitry Andric   bool doInitialization(Module &M) override {
2245f757f3fSDimitry Andric     LLVMContext &Ctx = M.getContext();
2255f757f3fSDimitry Andric     if (ModelUnderTraining.empty() && TrainingLog.empty()) {
2265f757f3fSDimitry Andric       Ctx.emitError("Regalloc development mode should be requested with at "
2275f757f3fSDimitry Andric                     "least logging enabled and/or a training model");
2285f757f3fSDimitry Andric       return false;
2295f757f3fSDimitry Andric     }
2305f757f3fSDimitry Andric     if (ModelUnderTraining.empty())
2315f757f3fSDimitry Andric       Runner = std::make_unique<NoInferenceModelRunner>(Ctx, InputFeatures);
2325f757f3fSDimitry Andric     else
2335f757f3fSDimitry Andric       Runner = ModelUnderTrainingRunner::createAndEnsureValid(
2345f757f3fSDimitry Andric           Ctx, ModelUnderTraining, DecisionName, TrainingInputFeatures);
2355f757f3fSDimitry Andric     if (!Runner) {
2365f757f3fSDimitry Andric       Ctx.emitError("Regalloc: could not set up the model runner");
2375f757f3fSDimitry Andric       return false;
2385f757f3fSDimitry Andric     }
2395f757f3fSDimitry Andric     if (TrainingLog.empty())
2405f757f3fSDimitry Andric       return false;
2415f757f3fSDimitry Andric     std::error_code EC;
2425f757f3fSDimitry Andric     auto OS = std::make_unique<raw_fd_ostream>(TrainingLog, EC);
2435f757f3fSDimitry Andric     if (EC) {
2445f757f3fSDimitry Andric       M.getContext().emitError(EC.message() + ":" + TrainingLog);
2455f757f3fSDimitry Andric       return false;
2465f757f3fSDimitry Andric     }
2475f757f3fSDimitry Andric     std::vector<TensorSpec> LFS = InputFeatures;
2485f757f3fSDimitry Andric     if (auto *MUTR = dyn_cast<ModelUnderTrainingRunner>(Runner.get()))
2495f757f3fSDimitry Andric       append_range(LFS, MUTR->extraOutputsForLoggingSpecs());
2505f757f3fSDimitry Andric     // We always log the output; in particular, if we're not evaluating, we
2515f757f3fSDimitry Andric     // don't have an output spec json file. That's why we handle the
2525f757f3fSDimitry Andric     // 'normal' output separately.
2535f757f3fSDimitry Andric     LFS.push_back(DecisionSpec);
2545f757f3fSDimitry Andric 
2555f757f3fSDimitry Andric     Log = std::make_unique<Logger>(std::move(OS), LFS, Reward,
2565f757f3fSDimitry Andric                                    /*IncludeReward*/ true);
2575f757f3fSDimitry Andric     return false;
2585f757f3fSDimitry Andric   }
2595f757f3fSDimitry Andric 
2605f757f3fSDimitry Andric   std::unique_ptr<RegAllocPriorityAdvisor>
2615f757f3fSDimitry Andric   getAdvisor(const MachineFunction &MF, const RAGreedy &RA) override {
2625f757f3fSDimitry Andric     if (!Runner)
2635f757f3fSDimitry Andric       return nullptr;
2645f757f3fSDimitry Andric     if (Log) {
2655f757f3fSDimitry Andric       Log->switchContext(MF.getName());
2665f757f3fSDimitry Andric     }
2675f757f3fSDimitry Andric 
2685f757f3fSDimitry Andric     return std::make_unique<DevelopmentModePriorityAdvisor>(
269*0fca6ea1SDimitry Andric         MF, RA, &getAnalysis<SlotIndexesWrapperPass>().getSI(), Runner.get(),
270*0fca6ea1SDimitry Andric         Log.get());
2715f757f3fSDimitry Andric   }
2725f757f3fSDimitry Andric 
2735f757f3fSDimitry Andric   std::unique_ptr<MLModelRunner> Runner;
2745f757f3fSDimitry Andric   std::unique_ptr<Logger> Log;
2755f757f3fSDimitry Andric };
2765f757f3fSDimitry Andric #endif //#ifdef LLVM_HAVE_TFLITE
2775f757f3fSDimitry Andric 
2785f757f3fSDimitry Andric } // namespace llvm
2795f757f3fSDimitry Andric 
2805f757f3fSDimitry Andric RegAllocPriorityAdvisorAnalysis *llvm::createReleaseModePriorityAdvisor() {
2815f757f3fSDimitry Andric   return llvm::isEmbeddedModelEvaluatorValid<CompiledModelType>() ||
2825f757f3fSDimitry Andric                  !InteractiveChannelBaseName.empty()
2835f757f3fSDimitry Andric              ? new ReleaseModePriorityAdvisorAnalysis()
2845f757f3fSDimitry Andric              : nullptr;
2855f757f3fSDimitry Andric }
2865f757f3fSDimitry Andric 
2875f757f3fSDimitry Andric MLPriorityAdvisor::MLPriorityAdvisor(const MachineFunction &MF,
2885f757f3fSDimitry Andric                                      const RAGreedy &RA,
2895f757f3fSDimitry Andric                                      SlotIndexes *const Indexes,
2905f757f3fSDimitry Andric                                      MLModelRunner *Runner)
2915f757f3fSDimitry Andric     : RegAllocPriorityAdvisor(MF, RA, Indexes), DefaultAdvisor(MF, RA, Indexes),
2925f757f3fSDimitry Andric       Runner(std::move(Runner)) {
2935f757f3fSDimitry Andric   assert(this->Runner);
2945f757f3fSDimitry Andric   Runner->switchContext(MF.getName());
2955f757f3fSDimitry Andric }
2965f757f3fSDimitry Andric 
2975f757f3fSDimitry Andric float MLPriorityAdvisor::getPriorityImpl(const LiveInterval &LI) const {
2985f757f3fSDimitry Andric   const unsigned Size = LI.getSize();
2995f757f3fSDimitry Andric   LiveRangeStage Stage = RA.getExtraInfo().getStage(LI);
3005f757f3fSDimitry Andric 
3015f757f3fSDimitry Andric   *Runner->getTensor<int64_t>(0) = static_cast<int64_t>(Size);
3025f757f3fSDimitry Andric   *Runner->getTensor<int64_t>(1) = static_cast<int64_t>(Stage);
3035f757f3fSDimitry Andric   *Runner->getTensor<float>(2) = static_cast<float>(LI.weight());
3045f757f3fSDimitry Andric 
3055f757f3fSDimitry Andric   return Runner->evaluate<float>();
3065f757f3fSDimitry Andric }
3075f757f3fSDimitry Andric 
3085f757f3fSDimitry Andric unsigned MLPriorityAdvisor::getPriority(const LiveInterval &LI) const {
3095f757f3fSDimitry Andric   return static_cast<unsigned>(getPriorityImpl(LI));
3105f757f3fSDimitry Andric }
3115f757f3fSDimitry Andric 
3125f757f3fSDimitry Andric #ifdef LLVM_HAVE_TFLITE
3135f757f3fSDimitry Andric RegAllocPriorityAdvisorAnalysis *llvm::createDevelopmentModePriorityAdvisor() {
3145f757f3fSDimitry Andric   return new DevelopmentModePriorityAdvisorAnalysis();
3155f757f3fSDimitry Andric }
3165f757f3fSDimitry Andric 
3175f757f3fSDimitry Andric unsigned
3185f757f3fSDimitry Andric DevelopmentModePriorityAdvisor::getPriority(const LiveInterval &LI) const {
3195f757f3fSDimitry Andric   double Prio = 0;
3205f757f3fSDimitry Andric 
3215f757f3fSDimitry Andric   if (isa<ModelUnderTrainingRunner>(getRunner())) {
3225f757f3fSDimitry Andric     Prio = MLPriorityAdvisor::getPriorityImpl(LI);
3235f757f3fSDimitry Andric   } else {
3245f757f3fSDimitry Andric     Prio = getDefaultAdvisor().getPriority(LI);
3255f757f3fSDimitry Andric   }
3265f757f3fSDimitry Andric 
3275f757f3fSDimitry Andric   if (TrainingLog.empty())
3285f757f3fSDimitry Andric     return Prio;
3295f757f3fSDimitry Andric 
3305f757f3fSDimitry Andric   // TODO(mtrofin): when we support optional rewards, this can go away. In the
3315f757f3fSDimitry Andric   // meantime, we log the "pretend" reward (0) for the previous observation
3325f757f3fSDimitry Andric   // before starting a new one.
3335f757f3fSDimitry Andric   if (Log->hasObservationInProgress())
3345f757f3fSDimitry Andric     Log->logReward<float>(0.0);
3355f757f3fSDimitry Andric 
3365f757f3fSDimitry Andric   Log->startObservation();
3375f757f3fSDimitry Andric   size_t CurrentFeature = 0;
3385f757f3fSDimitry Andric   for (; CurrentFeature < InputFeatures.size(); ++CurrentFeature) {
3395f757f3fSDimitry Andric     Log->logTensorValue(CurrentFeature,
3405f757f3fSDimitry Andric                         reinterpret_cast<const char *>(
3415f757f3fSDimitry Andric                             getRunner().getTensorUntyped(CurrentFeature)));
3425f757f3fSDimitry Andric   }
3435f757f3fSDimitry Andric 
3445f757f3fSDimitry Andric   if (auto *MUTR = dyn_cast<ModelUnderTrainingRunner>(&getRunner())) {
3455f757f3fSDimitry Andric     for (size_t I = 0; I < MUTR->extraOutputsForLoggingSpecs().size();
3465f757f3fSDimitry Andric          ++I, ++CurrentFeature)
3475f757f3fSDimitry Andric       Log->logTensorValue(
3485f757f3fSDimitry Andric           CurrentFeature,
3495f757f3fSDimitry Andric           reinterpret_cast<const char *>(MUTR->getUntypedExtraOutputValue(I)));
3505f757f3fSDimitry Andric   }
3515f757f3fSDimitry Andric 
3525f757f3fSDimitry Andric   float Ret = static_cast<float>(Prio);
3535f757f3fSDimitry Andric   Log->logTensorValue(CurrentFeature, reinterpret_cast<const char *>(&Ret));
3545f757f3fSDimitry Andric   Log->endObservation();
3555f757f3fSDimitry Andric 
3565f757f3fSDimitry Andric   return static_cast<unsigned>(Prio);
3575f757f3fSDimitry Andric }
3585f757f3fSDimitry Andric 
3595f757f3fSDimitry Andric #endif // #ifdef LLVM_HAVE_TFLITE
360