xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/MLRegAllocEvictAdvisor.cpp (revision 357378bbdedf24ce2b90e9bd831af4a9db3ec70a)
1 //===- MLRegAllocEvictAdvisor.cpp - ML eviction advisor -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implementation of the ML eviction advisor and reward injection pass
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "AllocationOrder.h"
14 #include "RegAllocEvictionAdvisor.h"
15 #include "RegAllocGreedy.h"
16 #include "llvm/Analysis/InteractiveModelRunner.h"
17 #include "llvm/Analysis/MLModelRunner.h"
18 #include "llvm/Analysis/TensorSpec.h"
19 #if defined(LLVM_HAVE_TF_AOT_REGALLOCEVICTMODEL) || defined(LLVM_HAVE_TFLITE)
20 #include "llvm/Analysis/ModelUnderTrainingRunner.h"
21 #include "llvm/Analysis/NoInferenceModelRunner.h"
22 #include "llvm/Analysis/Utils/TrainingLogger.h"
23 #endif
24 #include "MLRegAllocEvictAdvisor.h"
25 #include "llvm/Analysis/ReleaseModeModelRunner.h"
26 #include "llvm/CodeGen/CalcSpillWeights.h"
27 #include "llvm/CodeGen/LiveRegMatrix.h"
28 #include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
29 #include "llvm/CodeGen/MachineFunction.h"
30 #include "llvm/CodeGen/MachineLoopInfo.h"
31 #include "llvm/CodeGen/MachineRegisterInfo.h"
32 #include "llvm/CodeGen/Passes.h"
33 #include "llvm/CodeGen/RegisterClassInfo.h"
34 #include "llvm/CodeGen/VirtRegMap.h"
35 #include "llvm/InitializePasses.h"
36 #include "llvm/Pass.h"
37 #include "llvm/PassRegistry.h"
38 #include "llvm/Support/CommandLine.h"
39 #include "llvm/Support/ErrorHandling.h"
40 
41 #include <array>
42 #include <bitset>
43 #include <memory>
44 
45 using namespace llvm;
46 
// Debug category name for this file, as used by LLVM's debug-output macros.
#define DEBUG_TYPE "ml-regalloc"

// Generated header in release (AOT) mode.
// When the build provides an AOT-compiled eviction model, use its generated
// class as the model type; otherwise fall back to NoopSavedModelImpl so that
// ReleaseModeModelRunner<CompiledModelType> (used by the release-mode analysis
// below) still has a type to instantiate.
#if defined(LLVM_HAVE_TF_AOT_REGALLOCEVICTMODEL)
#include "RegAllocEvictModel.h"
using CompiledModelType = RegAllocEvictModel;
#else
using CompiledModelType = NoopSavedModelImpl;
#endif
56 
// Base path for the interactive model runner. When this option is non-empty,
// ReleaseModeEvictionAdvisorAnalysis::getAdvisor creates an
// InteractiveModelRunner that uses "<base>.out" and "<base>.in" as its
// outgoing/incoming channels instead of the compiled-in model.
static cl::opt<std::string> InteractiveChannelBaseName(
    "regalloc-evict-interactive-channel-base", cl::Hidden,
    cl::desc(
        "Base file path for the interactive mode. The incoming filename should "
        "have the name <regalloc-evict-interactive-channel-base>.in, while the "
        "outgoing name should be "
        "<regalloc-evict-interactive-channel-base>.out"));
64 
// Options that only make sense in development mode, i.e. in builds with
// TFLite support (LLVM_HAVE_TFLITE).
#ifdef LLVM_HAVE_TFLITE
#include "RegAllocScore.h"
#include "llvm/Analysis/Utils/TFUtils.h"

// Where to write the training log. Development mode requires this and/or
// ModelUnderTraining to be set; otherwise doInitialization reports an error.
static cl::opt<std::string> TrainingLog(
    "regalloc-training-log", cl::Hidden,
    cl::desc("Training log for the register allocator eviction model"));

// Path to the model under training. When empty, development mode uses a
// NoInferenceModelRunner (log-only).
static cl::opt<std::string> ModelUnderTraining(
    "regalloc-model", cl::Hidden,
    cl::desc("The model being trained for register allocation eviction"));

// Gates the development-only features (instructions, mappings, MBB
// frequencies): they are only added to the input specs, and only reset in
// resetInputs, when this is set.
static cl::opt<bool> EnableDevelopmentFeatures(
    "regalloc-enable-development-features", cl::Hidden,
    cl::desc("Whether or not to enable features under development for the ML "
             "regalloc advisor"));

#else
// Without TFLite the development features can never be enabled; a constant
// keeps the `if (EnableDevelopmentFeatures)` checks compiling (and dead).
static const bool EnableDevelopmentFeatures = false;
#endif // #ifdef LLVM_HAVE_TFLITE
86 
namespace llvm {
// Cap on the number of interfering live ranges queried per register unit.
// Declared extern here and defined elsewhere; read by
// MLEvictAdvisor::loadInterferenceFeatures in this file.
extern cl::opt<unsigned> EvictInterferenceCutoff;

/// The score injection pass.
/// This pass calculates the score for a function and inserts it in the log, but
/// this happens only in development mode. It's a no-op otherwise.
class RegAllocScoring : public MachineFunctionPass {
public:
  static char ID;

  RegAllocScoring() : MachineFunctionPass(ID) {
    initializeRegAllocScoringPass(*PassRegistry::getPassRegistry());
  }

  ~RegAllocScoring() override = default;

  StringRef getPassName() const override {
    return "Register Allocation Pass Scoring";
  }

  /// Requires the eviction and priority advisor analyses plus block
  /// frequencies; the pass preserves all analyses.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesAll();
    AU.addRequired<RegAllocEvictionAdvisorAnalysis>();
    AU.addRequired<RegAllocPriorityAdvisorAnalysis>();
    AU.addRequired<MachineBlockFrequencyInfo>();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  /// Runs the scoring on the given machine function; defined out of line.
  bool runOnMachineFunction(MachineFunction &) override;
};

char RegAllocScoring::ID = 0;
// Factory returning a newly allocated RegAllocScoring pass.
FunctionPass *createRegAllocScoringPass() { return new RegAllocScoring(); }

} // namespace llvm

// Registers the pass under the command-line name "regallocscoringpass"; it is
// neither CFG-only nor an analysis (the two trailing flags).
INITIALIZE_PASS(RegAllocScoring, "regallocscoringpass",
                "Register Allocation Scoring Pass", false, false)
127 
128 // ===================================
129 // Common ML Advisor declarations
130 // ===================================
131 namespace {
// The model can only accept a specified number of opcodes and will error if it
// is fed an opcode it hasn't seen before. This constant sets the current
// cutoff.
static const int OpcodeValueCutoff = 17716;

// Shape shared by most features: a single row with one value per interference
// position ({1, NumberOfInterferences}). It is defined once and reused in the
// features table below.
static const std::vector<int64_t> PerLiveRangeShape{1, NumberOfInterferences};
139 
// --------------
// Features table
// --------------
// For each interfering live range (incl. the candidate) we collect a number of
// features. However, because the features are of different types (and because
// of ML best practices), we organize the tensors per feature, not per
// candidate. Each such tensor has a scalar value corresponding to the
// interfering live range at that position, in the order in AllocationOrder.
// The last position corresponds to the virt reg seeking allocation.
// Exception to all that is the progression feature, which is just a scalar (see
// its documentation for details).
// Note on naming: the "_by_max" are normalized using the largest value of that
// tensor, as observed in the current decision making stage (i.e. for the
// current call to the advisor's tryFindEvictionCandidate)
//
// The feature list format: type, name, shape, documentation.
// Note: we can really just use int64 and float, hence the modeling of some
// bools as int64 values.
//
// RA_EVICT_FEATURES_LIST is an X-macro: callers pass a macro M(type, name,
// shape, doc) that is expanded once per entry, in this order. The same order
// defines the FeatureIDs enumerators and the order of the input TensorSpecs,
// so reordering entries changes both.
#define RA_EVICT_FEATURES_LIST(M)                                              \
  M(int64_t, mask, PerLiveRangeShape,                                          \
    "boolean values, 0 for unavailable candidates (i.e. if a position is 0, "  \
    "it "                                                                      \
    "can't be evicted)")                                                       \
  M(int64_t, is_free, PerLiveRangeShape,                                       \
    "boolean values, 1 if this phys reg is actually free (no interferences)")  \
  M(float, nr_urgent, PerLiveRangeShape,                                       \
    "number of 'urgent' intervals, normalized. Urgent are those that are OK "  \
    "to break cascades")                                                       \
  M(float, nr_broken_hints, PerLiveRangeShape,                                 \
    "if this position were evicted, how many broken hints would there be")     \
  M(int64_t, is_hint, PerLiveRangeShape,                                       \
    "is this a preferred phys reg for the candidate")                          \
  M(int64_t, is_local, PerLiveRangeShape,                                      \
    "is this live range local to a basic block")                               \
  M(float, nr_rematerializable, PerLiveRangeShape,                             \
    "nr rematerializable ranges")                                              \
  M(float, nr_defs_and_uses, PerLiveRangeShape,                                \
    "bb freq - weighed nr defs and uses")                                      \
  M(float, weighed_reads_by_max, PerLiveRangeShape,                            \
    "bb freq - weighed nr of reads, normalized")                               \
  M(float, weighed_writes_by_max, PerLiveRangeShape,                           \
    "bb feq - weighed nr of writes, normalized")                               \
  M(float, weighed_read_writes_by_max, PerLiveRangeShape,                      \
    "bb freq - weighed nr of uses that are both read and writes, normalized")  \
  M(float, weighed_indvars_by_max, PerLiveRangeShape,                          \
    "bb freq - weighed nr of uses that are indvars, normalized")               \
  M(float, hint_weights_by_max, PerLiveRangeShape,                             \
    "bb freq - weighed nr of uses that are hints, normalized")                 \
  M(float, start_bb_freq_by_max, PerLiveRangeShape,                            \
    "the freq in the start block, normalized")                                 \
  M(float, end_bb_freq_by_max, PerLiveRangeShape,                              \
    "freq of end block, normalized")                                           \
  M(float, hottest_bb_freq_by_max, PerLiveRangeShape,                          \
    "hottest BB freq, normalized")                                             \
  M(float, liverange_size, PerLiveRangeShape,                                  \
    "size (instr index diff) of the LR")                                       \
  M(float, use_def_density, PerLiveRangeShape,                                 \
    "the max weight, as computed by the manual heuristic")                     \
  M(int64_t, max_stage, PerLiveRangeShape,                                     \
    "largest stage of an interval in this LR")                                 \
  M(int64_t, min_stage, PerLiveRangeShape,                                     \
    "lowest stage of an interval in this LR")                                  \
  M(float, progress, {1}, "ratio of current queue size to initial size")
203 
// Development-only features, in the same X-macro format as
// RA_EVICT_FEATURES_LIST. The first one is split out so that the FeatureIDs
// enum can pin its value to FeatureCount. At runtime these are additionally
// gated by EnableDevelopmentFeatures.
#ifdef LLVM_HAVE_TFLITE
#define RA_EVICT_FIRST_DEVELOPMENT_FEATURE(M)                                  \
  M(int64_t, instructions, InstructionsShape,                                  \
    "Opcodes of the instructions covered by the eviction problem")

#define RA_EVICT_REST_DEVELOPMENT_FEATURES(M)                                  \
  M(int64_t, instructions_mapping, InstructionsMappingShape,                   \
    "A binary matrix mapping LRs to instruction opcodes")                      \
  M(float, mbb_frequencies, MBBFrequencyShape,                                 \
    "A vector of machine basic block frequencies")                             \
  M(int64_t, mbb_mapping, InstructionsShape,                                   \
    "A vector of indicies mapping instructions to MBBs")
#else
// Without TFLite there are no development features: both lists expand to
// nothing, so code expanding them compiles unchanged.
#define RA_EVICT_FIRST_DEVELOPMENT_FEATURE(M)
#define RA_EVICT_REST_DEVELOPMENT_FEATURES(M)
#endif
220 
// The model learns to pick one of the mask == 1 interferences. This is the
// name of the output tensor. The contract with the model is that the output
// is guaranteed to be a mask == 1 position. Using a macro here to avoid
// 'not used' warnings (and keep cond compilation to a minimum).
#define DecisionName "index_to_evict"
// Spec of the model's single int64 output: the chosen position (column).
static const TensorSpec DecisionSpec =
    TensorSpec::createSpec<int64_t>(DecisionName, {1});
228 
229 // Named features index.
230 enum FeatureIDs {
231 #define _FEATURE_IDX_SIMPLE(_, name, __, ___) name
232 #define _FEATURE_IDX(A, B, C, D) _FEATURE_IDX_SIMPLE(A, B, C, D),
233   RA_EVICT_FEATURES_LIST(_FEATURE_IDX) FeatureCount,
234 #ifdef LLVM_HAVE_TFLITE
235   RA_EVICT_FIRST_DEVELOPMENT_FEATURE(_FEATURE_IDX_SIMPLE) = FeatureCount,
236 #else
237   RA_EVICT_FIRST_DEVELOPMENT_FEATURE(_FEATURE_IDX)
238 #endif // #ifdef LLVM_HAVE_TFLITE
239   RA_EVICT_REST_DEVELOPMENT_FEATURES(_FEATURE_IDX) FeaturesWithDevelopmentCount
240 #undef _FEATURE_IDX
241 #undef _FEATURE_IDX_SIMPLE
242 };
243 
/// Number of bytes occupied by a tensor whose elements have type \p T and
/// whose dimensions are \p Shape: sizeof(T) times the product of all
/// dimensions (an empty shape yields sizeof(T)).
///
/// Why this exists: the ML advisor usually feeds the evaluator a sparse input,
/// because various phys regs won't be available. Rather than tracking which
/// cells were written last time, resetInputs bulk-zeroes every tensor, and it
/// needs each tensor's byte size to do so.
template <typename T> size_t getTotalSize(const std::vector<int64_t> &Shape) {
  size_t Bytes = sizeof(T);
  for (size_t Dim = 0, NumDims = Shape.size(); Dim != NumDims; ++Dim)
    Bytes = Bytes * Shape[Dim];
  return Bytes;
}
254 
255 void resetInputs(MLModelRunner &Runner) {
256 #define _RESET(TYPE, NAME, SHAPE, __)                                          \
257   std::memset(Runner.getTensorUntyped(FeatureIDs::NAME), 0,                    \
258               getTotalSize<TYPE>(SHAPE));
259   RA_EVICT_FEATURES_LIST(_RESET)
260   if (EnableDevelopmentFeatures) {
261     RA_EVICT_FIRST_DEVELOPMENT_FEATURE(_RESET)
262     RA_EVICT_REST_DEVELOPMENT_FEATURES(_RESET)
263 #undef _RESET
264   }
265 }
266 
/// Per-live-interval components that are later aggregated into the feature
/// values handed to the evaluator. Every member starts out as zero / false.
/// (Member meanings below are inferred from their names and the features
/// table; they are filled in by getLIFeatureComponents.)
struct LIFeatureComponents {
  // Read, write and read-write contributions.
  double R{0.0};
  double W{0.0};
  double RW{0.0};
  // Contributions of induction-variable updates and of hints.
  double IndVarUpdates{0.0};
  double HintWeights{0.0};
  // Count of defs and uses.
  int64_t NrDefsAndUses{0};
  // Frequency of the hottest block involved.
  float HottestBlockFreq{0.0f};
  // Whether the interval is rematerializable.
  bool IsRemat{false};
};
279 
// Per-position table used during one eviction decision: the physical register
// at each position of the allocation order, plus a mask bit that is true when
// that position may be chosen (see tryFindEvictionCandidate).
using CandidateRegList =
    std::array<std::pair<MCRegister, bool>, NumberOfInterferences>;
// Largest value seen per feature during one eviction decision, used to
// normalize the "_by_max" features; inline capacity of FeatureCount.
using FeaturesListNormalizer =
    llvm::SmallVector<float, FeatureIDs::FeatureCount>;
284 
/// The ML evictor (commonalities between release and development mode).
/// Feature extraction lives here; the eviction decision itself is obtained
/// from an MLModelRunner through tryFindEvictionCandidatePosition, which the
/// development-mode subclass overrides.
class MLEvictAdvisor : public RegAllocEvictionAdvisor {
public:
  /// \p Runner is borrowed, not owned: in this file it is owned by the advisor
  /// analysis that creates the advisor.
  MLEvictAdvisor(const MachineFunction &MF, const RAGreedy &RA,
                 MLModelRunner *Runner, const MachineBlockFrequencyInfo &MBFI,
                 const MachineLoopInfo &Loops);

protected:
  /// The heuristic (non-ML) advisor this one delegates to where needed.
  const RegAllocEvictionAdvisor &getDefaultAdvisor() const {
    return static_cast<const RegAllocEvictionAdvisor &>(DefaultAdvisor);
  }

  // The assumption is that if the Runner could not be constructed, an error
  // was already emitted, and we shouldn't be asking for it here.
  const MLModelRunner &getRunner() const { return *Runner; }

  /// This just calls Evaluate on the Runner, but in the development mode
  /// case, if we're just capturing the log of the default advisor, it needs
  /// to call the latter instead, so we need to pass all the necessary
  /// parameters for it. In the development case, it will also log.
  /// Returns the chosen position: a column in the allocation order, or the
  /// candidate's own (last) position.
  virtual int64_t
  tryFindEvictionCandidatePosition(const LiveInterval &VirtReg,
                                   const AllocationOrder &Order,
                                   unsigned OrderLimit, uint8_t CostPerUseLimit,
                                   const SmallVirtRegSet &FixedRegisters) const;

  /// Load, at column \p Pos, the features of the live ranges that interfere
  /// with \p VirtReg on \p PhysReg. If that interference can't be evicted,
  /// return false instead, leaving the column (and \p Largest, \p LRPosInfo)
  /// untouched.
  bool
  loadInterferenceFeatures(const LiveInterval &VirtReg, MCRegister PhysReg,
                           bool IsHint, const SmallVirtRegSet &FixedRegisters,
                           llvm::SmallVectorImpl<float> &Largest, size_t Pos,
                           SmallVectorImpl<LRStartEndInfo> &LRPosInfo) const;

private:
  /// Number of virtual registers in \p MF with at least one non-debug
  /// operand; used to initialize InitialQSize.
  static float getInitialQueueSize(const MachineFunction &MF);

  /// RegAllocEvictionAdvisor hook: picks the physical register whose
  /// interference to evict (or none).
  MCRegister tryFindEvictionCandidate(
      const LiveInterval &VirtReg, const AllocationOrder &Order,
      uint8_t CostPerUseLimit,
      const SmallVirtRegSet &FixedRegisters) const override;

  /// Writes the features of \p Intervals, plus the precomputed hint / local /
  /// urgent counts, into column \p Pos. Defined out of line (presumably also
  /// updates \p Largest and \p LRPosInfo — confirm at the definition).
  void extractFeatures(const SmallVectorImpl<const LiveInterval *> &Intervals,
                       llvm::SmallVectorImpl<float> &Largest, size_t Pos,
                       int64_t IsHint, int64_t LocalIntfsCount, float NrUrgent,
                       SmallVectorImpl<LRStartEndInfo> &LRPosInfo) const;

  // Point-in-time: we didn't learn this, so we always delegate to the
  // default.
  bool canEvictHintInterference(
      const LiveInterval &VirtReg, MCRegister PhysReg,
      const SmallVirtRegSet &FixedRegisters) const override {
    return getDefaultAdvisor().canEvictHintInterference(VirtReg, PhysReg,
                                                        FixedRegisters);
  }

  /// Per-interval feature components of \p LI (presumably memoized in
  /// CachedFeatures — confirm at the definition).
  const LIFeatureComponents &
  getLIFeatureComponents(const LiveInterval &LI) const;

  // Hold on to a default advisor for:
  // 1) the implementation of canEvictHintInterference, because we didn't
  // learn that nuance yet; 2) for bootstrapping (logging) in the development
  // mode case.
  const DefaultEvictionAdvisor DefaultAdvisor;
  // Borrowed; see the constructor.
  MLModelRunner *const Runner;
  const MachineBlockFrequencyInfo &MBFI;
  const MachineLoopInfo &Loops;

  // Indices of those features we don't want to normalize.
  // This could be static and shared, but its initialization is non-trivial.
  std::bitset<FeatureIDs::FeatureCount> DoNotNormalize;
  // Initial work-queue size (see getInitialQueueSize); the "progress" feature
  // is documented as a ratio to it.
  const float InitialQSize;

  using RegID = unsigned;
  // Cache of per-register feature components; mutable so that const query
  // methods can fill it.
  mutable DenseMap<RegID, LIFeatureComponents> CachedFeatures;
};
361 
// Expands one feature-list entry into a TensorSpec named after the feature,
// with the feature's element type and shape, followed by a comma, so that a
// feature list can be expanded directly inside a std::vector<TensorSpec>
// initializer list.
// NOTE(review): names starting with '_' + uppercase are reserved in C++;
// renaming requires updating every user of this macro together.
#define _DECL_FEATURES(type, name, shape, _)                                   \
  TensorSpec::createSpec<type>(#name, shape),
364 
365 // ===================================
366 // Release (AOT) - specifics
367 // ===================================
368 class ReleaseModeEvictionAdvisorAnalysis final
369     : public RegAllocEvictionAdvisorAnalysis {
370 public:
371   ReleaseModeEvictionAdvisorAnalysis()
372       : RegAllocEvictionAdvisorAnalysis(AdvisorMode::Release) {
373     if (EnableDevelopmentFeatures) {
374       InputFeatures = {RA_EVICT_FEATURES_LIST(
375           _DECL_FEATURES) RA_EVICT_FIRST_DEVELOPMENT_FEATURE(_DECL_FEATURES)
376                            RA_EVICT_REST_DEVELOPMENT_FEATURES(_DECL_FEATURES)};
377     } else {
378       InputFeatures = {RA_EVICT_FEATURES_LIST(_DECL_FEATURES)};
379     }
380   }
381   // support for isa<> and dyn_cast.
382   static bool classof(const RegAllocEvictionAdvisorAnalysis *R) {
383     return R->getAdvisorMode() == AdvisorMode::Release;
384   }
385 
386 private:
387   std::vector<TensorSpec> InputFeatures;
388 
389   void getAnalysisUsage(AnalysisUsage &AU) const override {
390     AU.addRequired<MachineBlockFrequencyInfo>();
391     AU.addRequired<MachineLoopInfo>();
392     RegAllocEvictionAdvisorAnalysis::getAnalysisUsage(AU);
393   }
394 
395   std::unique_ptr<RegAllocEvictionAdvisor>
396   getAdvisor(const MachineFunction &MF, const RAGreedy &RA) override {
397     if (!Runner) {
398       if (InteractiveChannelBaseName.empty())
399         Runner = std::make_unique<ReleaseModeModelRunner<CompiledModelType>>(
400             MF.getFunction().getContext(), InputFeatures, DecisionName);
401       else
402         Runner = std::make_unique<InteractiveModelRunner>(
403             MF.getFunction().getContext(), InputFeatures, DecisionSpec,
404             InteractiveChannelBaseName + ".out",
405             InteractiveChannelBaseName + ".in");
406     }
407     return std::make_unique<MLEvictAdvisor>(
408         MF, RA, Runner.get(), getAnalysis<MachineBlockFrequencyInfo>(),
409         getAnalysis<MachineLoopInfo>());
410   }
411   std::unique_ptr<MLModelRunner> Runner;
412 };
413 
414 // ===================================
415 // Development mode-specifics
416 // ===================================
417 //
418 // Features we log
419 #ifdef LLVM_HAVE_TFLITE
// Spec of the scalar reward recorded in the training log (passed to the
// Logger in doInitialization).
static const TensorSpec Reward = TensorSpec::createSpec<float>("reward", {1});

// Features we bind on the model. The tensor names have a prefix, and we also
// need to include some tensors that are expected to be present by the
// training algo.
// TODO: can we just get rid of these?
// Same as _DECL_FEATURES, except the tensor name is prefixed with "action_".
#define _DECL_TRAIN_FEATURES(type, name, shape, _)                             \
  TensorSpec::createSpec<type>(std::string("action_") + #name, shape),
428 
/// Development-mode evictor: same feature extraction as MLEvictAdvisor, but it
/// overrides tryFindEvictionCandidatePosition (defined out of line) so that the
/// decision can also be logged. \p Log is borrowed from the owning analysis and
/// is null when no training log was requested.
class DevelopmentModeEvictAdvisor : public MLEvictAdvisor {
public:
  DevelopmentModeEvictAdvisor(const MachineFunction &MF, const RAGreedy &RA,
                              MLModelRunner *Runner,
                              const MachineBlockFrequencyInfo &MBFI,
                              const MachineLoopInfo &Loops, Logger *Log)
      : MLEvictAdvisor(MF, RA, Runner, MBFI, Loops), Log(Log) {}

private:
  int64_t tryFindEvictionCandidatePosition(
      const LiveInterval &VirtReg, const AllocationOrder &Order,
      unsigned OrderLimit, uint8_t CostPerUseLimit,
      const SmallVirtRegSet &FixedRegisters) const override;

  // Training log (borrowed; may be null).
  Logger *const Log;
};
445 
446 class DevelopmentModeEvictionAdvisorAnalysis final
447     : public RegAllocEvictionAdvisorAnalysis {
448 public:
449   DevelopmentModeEvictionAdvisorAnalysis()
450       : RegAllocEvictionAdvisorAnalysis(AdvisorMode::Development) {
451     if (EnableDevelopmentFeatures) {
452       InputFeatures = {RA_EVICT_FEATURES_LIST(
453           _DECL_FEATURES) RA_EVICT_FIRST_DEVELOPMENT_FEATURE(_DECL_FEATURES)
454                            RA_EVICT_REST_DEVELOPMENT_FEATURES(_DECL_FEATURES)};
455       TrainingInputFeatures = {
456           RA_EVICT_FEATURES_LIST(_DECL_TRAIN_FEATURES)
457               RA_EVICT_FIRST_DEVELOPMENT_FEATURE(_DECL_TRAIN_FEATURES)
458                   RA_EVICT_REST_DEVELOPMENT_FEATURES(_DECL_TRAIN_FEATURES)
459                       TensorSpec::createSpec<float>("action_discount", {1}),
460           TensorSpec::createSpec<int32_t>("action_step_type", {1}),
461           TensorSpec::createSpec<float>("action_reward", {1})};
462     } else {
463       InputFeatures = {RA_EVICT_FEATURES_LIST(_DECL_FEATURES)};
464       TrainingInputFeatures = {
465           RA_EVICT_FEATURES_LIST(_DECL_TRAIN_FEATURES)
466               TensorSpec::createSpec<float>("action_discount", {1}),
467           TensorSpec::createSpec<int32_t>("action_step_type", {1}),
468           TensorSpec::createSpec<float>("action_reward", {1})};
469     }
470   }
471   // support for isa<> and dyn_cast.
472   static bool classof(const RegAllocEvictionAdvisorAnalysis *R) {
473     return R->getAdvisorMode() == AdvisorMode::Development;
474   }
475 
476   void logRewardIfNeeded(const MachineFunction &MF,
477                          llvm::function_ref<float()> GetReward) override {
478     if (!Log || !Log->hasAnyObservationForContext(MF.getName()))
479       return;
480     // The function pass manager would run all the function passes for a
481     // function, so we assume the last context belongs to this function. If
482     // this invariant ever changes, we can implement at that time switching
483     // contexts. At this point, it'd be an error
484     if (Log->currentContext() != MF.getName()) {
485       MF.getFunction().getContext().emitError(
486           "The training log context shouldn't have had changed.");
487     }
488     if (Log->hasObservationInProgress())
489       Log->logReward<float>(GetReward());
490   }
491 
492 private:
493   std::vector<TensorSpec> InputFeatures;
494   std::vector<TensorSpec> TrainingInputFeatures;
495 
496   void getAnalysisUsage(AnalysisUsage &AU) const override {
497     AU.addRequired<MachineBlockFrequencyInfo>();
498     AU.addRequired<MachineLoopInfo>();
499     RegAllocEvictionAdvisorAnalysis::getAnalysisUsage(AU);
500   }
501 
502   bool doInitialization(Module &M) override {
503     LLVMContext &Ctx = M.getContext();
504     if (ModelUnderTraining.empty() && TrainingLog.empty()) {
505       Ctx.emitError("Regalloc development mode should be requested with at "
506                     "least logging enabled and/or a training model");
507       return false;
508     }
509     if (ModelUnderTraining.empty())
510       Runner = std::make_unique<NoInferenceModelRunner>(Ctx, InputFeatures);
511     else
512       Runner = ModelUnderTrainingRunner::createAndEnsureValid(
513           Ctx, ModelUnderTraining, DecisionName, TrainingInputFeatures);
514     if (!Runner) {
515       Ctx.emitError("Regalloc: could not set up the model runner");
516       return false;
517     }
518     if (TrainingLog.empty())
519       return false;
520     std::error_code EC;
521     auto OS = std::make_unique<raw_fd_ostream>(TrainingLog, EC);
522     if (EC) {
523       M.getContext().emitError(EC.message() + ":" + TrainingLog);
524       return false;
525     }
526     std::vector<TensorSpec> LFS = InputFeatures;
527     if (auto *MUTR = dyn_cast<ModelUnderTrainingRunner>(Runner.get()))
528       append_range(LFS, MUTR->extraOutputsForLoggingSpecs());
529     // We always log the output; in particular, if we're not evaluating, we
530     // don't have an output spec json file. That's why we handle the
531     // 'normal' output separately.
532     LFS.push_back(DecisionSpec);
533 
534     Log = std::make_unique<Logger>(std::move(OS), LFS, Reward,
535                                    /*IncludeReward*/ true);
536     return false;
537   }
538 
539   std::unique_ptr<RegAllocEvictionAdvisor>
540   getAdvisor(const MachineFunction &MF, const RAGreedy &RA) override {
541     if (!Runner)
542       return nullptr;
543     if (Log)
544       Log->switchContext(MF.getName());
545     return std::make_unique<DevelopmentModeEvictAdvisor>(
546         MF, RA, Runner.get(), getAnalysis<MachineBlockFrequencyInfo>(),
547         getAnalysis<MachineLoopInfo>(), Log.get());
548   }
549 
550   std::unique_ptr<MLModelRunner> Runner;
551   std::unique_ptr<Logger> Log;
552 };
553 
554 #endif //#ifdef LLVM_HAVE_TFLITE
555 } // namespace
556 
557 float MLEvictAdvisor::getInitialQueueSize(const MachineFunction &MF) {
558   auto &MRI = MF.getRegInfo();
559   float Ret = 0.0;
560   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
561     Register Reg = Register::index2VirtReg(I);
562     if (MRI.reg_nodbg_empty(Reg))
563       continue;
564     ++Ret;
565   }
566   return Ret;
567 }
568 
569 MLEvictAdvisor::MLEvictAdvisor(const MachineFunction &MF, const RAGreedy &RA,
570                                MLModelRunner *Runner,
571                                const MachineBlockFrequencyInfo &MBFI,
572                                const MachineLoopInfo &Loops)
573     : RegAllocEvictionAdvisor(MF, RA), DefaultAdvisor(MF, RA),
574       Runner(std::move(Runner)), MBFI(MBFI), Loops(Loops),
575       InitialQSize(MLEvictAdvisor::getInitialQueueSize(MF)) {
576   assert(this->Runner);
577   Runner->switchContext(MF.getName());
578   DoNotNormalize.set(FeatureIDs::mask);
579   DoNotNormalize.set(FeatureIDs::is_free);
580   DoNotNormalize.set(FeatureIDs::is_hint);
581   DoNotNormalize.set(FeatureIDs::is_local);
582   DoNotNormalize.set(FeatureIDs::min_stage);
583   DoNotNormalize.set(FeatureIDs::max_stage);
584   DoNotNormalize.set(FeatureIDs::progress);
585 }
586 
587 int64_t MLEvictAdvisor::tryFindEvictionCandidatePosition(
588     const LiveInterval &, const AllocationOrder &, unsigned, uint8_t,
589     const SmallVirtRegSet &) const {
590   int64_t Ret = Runner->evaluate<int64_t>();
591   assert(Ret >= 0);
592   assert(Ret <= CandidateVirtRegPos);
593   return Ret;
594 }
595 
/// Decides whether the interference of \p VirtReg on \p PhysReg is legally
/// evictable and, if so, loads its features into column \p Pos (via
/// extractFeatures) and returns true. Returns false without touching the
/// features, \p Largest or \p LRPosInfo when any interfering live range can't
/// be evicted.
bool MLEvictAdvisor::loadInterferenceFeatures(
    const LiveInterval &VirtReg, MCRegister PhysReg, bool IsHint,
    const SmallVirtRegSet &FixedRegisters,
    llvm::SmallVectorImpl<float> &Largest, size_t Pos,
    llvm::SmallVectorImpl<LRStartEndInfo> &LRPosInfo) const {
  // It is only possible to evict virtual register interference.
  // (Interference kinds ordered after IK_VirtReg are presumably the
  // non-virtual ones, e.g. fixed/regunit interference.)
  if (Matrix->checkInterference(VirtReg, PhysReg) > LiveRegMatrix::IK_VirtReg) {
    // leave unavailable
    return false;
  }

  const bool IsLocal = LIS->intervalIsInOneMBB(VirtReg);
  // Interferences that are local to one block and can't be reassigned.
  int64_t LocalIntfs = 0;
  // Interferences only evictable because the eviction is urgent.
  float NrUrgent = 0.0f;

  // The cascade tracking is the same as in the default advisor
  unsigned Cascade = RA.getExtraInfo().getCascadeOrCurrentNext(VirtReg.reg());

  SmallVector<const LiveInterval *, MaxInterferences> InterferingIntervals;
  for (MCRegUnit Unit : TRI->regunits(PhysReg)) {
    LiveIntervalUnion::Query &Q = Matrix->query(VirtReg, Unit);
    // Unlike the default heuristic, we attach no meaning to a large number of
    // results: the query is capped at EvictInterferenceCutoff, and reaching
    // the cap simply makes this PhysReg unavailable.
    const auto &IFIntervals = Q.interferingVRegs(EvictInterferenceCutoff);
    if (IFIntervals.empty() && InterferingIntervals.empty())
      continue;
    if (IFIntervals.size() >= EvictInterferenceCutoff)
      return false;
    InterferingIntervals.append(IFIntervals.begin(), IFIntervals.end());
    for (const LiveInterval *Intf : reverse(IFIntervals)) {
      assert(Intf->reg().isVirtual() &&
             "Only expecting virtual register interference from query");
      // This is the same set of legality checks as in the default case: don't
      // try to evict fixed regs or 'done' ones. Also don't break cascades,
      // except in the urgent case, with the same nuances used in the default
      // heuristic.
      // We could try sharing this between the advisors, but it may end up
      // more complex than it is right now.
      if (FixedRegisters.count(Intf->reg()))
        return false;
      if (RA.getExtraInfo().getStage(*Intf) == RS_Done)
        return false;
      // Urgent: VirtReg can't be spilled, and the interference either can be
      // spilled or belongs to a class with more allocatable registers.
      bool Urgent =
          !VirtReg.isSpillable() &&
          (Intf->isSpillable() ||
           RegClassInfo.getNumAllocatableRegs(MRI->getRegClass(VirtReg.reg())) <
               RegClassInfo.getNumAllocatableRegs(
                   MRI->getRegClass(Intf->reg())));
      // Only evict older cascades or live ranges without a cascade.
      unsigned IntfCascade = RA.getExtraInfo().getCascade(Intf->reg());
      if (Cascade <= IntfCascade) {
        if (!Urgent)
          return false;
        ++NrUrgent;
      }

      // Count block-local interferences that local reassignment can't fix.
      LocalIntfs += (IsLocal && LIS->intervalIsInOneMBB(*Intf) &&
                     (!EnableLocalReassign || !canReassign(*Intf, PhysReg)));
    }
  }
  // OK, so if we made it this far, this LR is an eviction candidate, load its
  // features.
  extractFeatures(InterferingIntervals, Largest, Pos, IsHint, LocalIntfs,
                  NrUrgent, LRPosInfo);
  return true;
}
662 
663 MCRegister MLEvictAdvisor::tryFindEvictionCandidate(
664     const LiveInterval &VirtReg, const AllocationOrder &Order,
665     uint8_t CostPerUseLimit, const SmallVirtRegSet &FixedRegisters) const {
666   auto MaybeOrderLimit = getOrderLimit(VirtReg, Order, CostPerUseLimit);
667   if (!MaybeOrderLimit)
668     return MCRegister::NoRegister;
669   unsigned OrderLimit = *MaybeOrderLimit;
670 
671   // The heuristic sets initial costs such as, if CostPerUseLimit is
672   // max<uint8_t>, then any of the costs of the legally-evictable intervals
673   // would be lower. When that happens, one of those will be selected.
674   // Therefore, we allow the candidate be selected, unless the candidate is
675   // unspillable, in which case it would be incorrect to not find a register
676   // for it.
677   const bool MustFindEviction =
678       (!VirtReg.isSpillable() && CostPerUseLimit == static_cast<uint8_t>(~0u));
679   // Number of available candidates - if 0, no need to continue.
680   size_t Available = 0;
681   // Make sure we don't have leftover partial state from an attempt where we
682   // had no available candidates and bailed out early.
683   resetInputs(*Runner);
684 
685   // Track the index->register mapping because AllocationOrder doesn't do that
686   // and we'd have to scan it.
687   // Also track their mask, to write asserts/debug.
688   CandidateRegList Regs;
689   Regs.fill({0, false});
690 
691   // Track the largest value of features seen during this eviction session. We
692   // only normalize (some of) the float features, but it's just simpler to
693   // dimension 'Largest' to all the features, especially since we have the
694   // 'DoNotNormalize' list.
695   FeaturesListNormalizer Largest(FeatureIDs::FeatureCount, 0.0);
696 
697   // Same overal idea as in the default eviction policy - we visit the values
698   // of AllocationOrder one at a time. If it's not legally available, we mask
699   // off the corresponding feature column (==do nothing because we already
700   // reset all the features to 0) Use Pos to capture the column we load
701   // features at - in AllocationOrder order.
702   size_t Pos = 0;
703   SmallVector<LRStartEndInfo, NumberOfInterferences> LRPosInfo;
704   for (auto I = Order.begin(), E = Order.getOrderLimitEnd(OrderLimit); I != E;
705        ++I, ++Pos) {
706     MCRegister PhysReg = *I;
707     assert(!Regs[Pos].second);
708     assert(PhysReg);
709     if (!canAllocatePhysReg(CostPerUseLimit, PhysReg)) {
710       continue;
711     }
712     if (loadInterferenceFeatures(VirtReg, PhysReg, I.isHint(), FixedRegisters,
713                                  Largest, Pos, LRPosInfo)) {
714       ++Available;
715       Regs[Pos] = std::make_pair(PhysReg, true);
716     }
717   }
718   if (Available == 0) {
719     // Nothing to decide, nothing to learn.
720     assert(!MustFindEviction);
721     return MCRegister::NoRegister;
722   }
723   const size_t ValidPosLimit = Pos;
724   // If we must find eviction, the candidate should be masked out of the
725   // decision making process.
726   Regs[CandidateVirtRegPos].second = !MustFindEviction;
727   if (!MustFindEviction)
728     extractFeatures(SmallVector<const LiveInterval *, 1>(1, &VirtReg), Largest,
729                     CandidateVirtRegPos, /*IsHint*/ 0,
730                     /*LocalIntfsCount*/ 0,
731                     /*NrUrgent*/ 0.0, LRPosInfo);
732   assert(InitialQSize > 0.0 && "We couldn't have gotten here if we had "
733                                "nothing to allocate initially.");
734 #ifdef LLVM_HAVE_TFLITE
735   if (EnableDevelopmentFeatures) {
736     extractInstructionFeatures(
737         LRPosInfo, Runner,
738         [this](SlotIndex InputIndex) -> int {
739           auto *CurrentMachineInstruction =
740               LIS->getInstructionFromIndex(InputIndex);
741           if (!CurrentMachineInstruction) {
742             return -1;
743           }
744           return CurrentMachineInstruction->getOpcode();
745         },
746         [this](SlotIndex InputIndex) -> float {
747           auto *CurrentMachineInstruction =
748               LIS->getInstructionFromIndex(InputIndex);
749           return MBFI.getBlockFreqRelativeToEntryBlock(
750               CurrentMachineInstruction->getParent());
751         },
752         [this](SlotIndex InputIndex) -> MachineBasicBlock * {
753           auto *CurrentMachineInstruction =
754               LIS->getInstructionFromIndex(InputIndex);
755           return CurrentMachineInstruction->getParent();
756         },
757         FeatureIDs::instructions, FeatureIDs::instructions_mapping,
758         FeatureIDs::mbb_frequencies, FeatureIDs::mbb_mapping,
759         LIS->getSlotIndexes()->getLastIndex());
760   }
761 #endif // #ifdef LLVM_HAVE_TFLITE
762   // Normalize the features.
763   for (auto &V : Largest)
764     V = V ? V : 1.0;
765   for (size_t FeatureIndex = 0; FeatureIndex < FeatureIDs::FeatureCount;
766        ++FeatureIndex) {
767     if (DoNotNormalize.test(FeatureIndex))
768       continue;
769     for (size_t Pos = 0; Pos < NumberOfInterferences; ++Pos) {
770       Runner->getTensor<float>(FeatureIndex)[Pos] /= Largest[FeatureIndex];
771     }
772   }
773   *Runner->getTensor<float>(FeatureIDs::progress) =
774       static_cast<float>(RA.getQueueSize()) / InitialQSize;
775 
776   // Get a decision.
777   size_t CandidatePos = tryFindEvictionCandidatePosition(
778       VirtReg, Order, OrderLimit, CostPerUseLimit, FixedRegisters);
779   // The contract with the ML side is that CandidatePos is mask == 1 (i.e.
780   // Regs[CandidatePos].second)
781   assert(Regs[CandidatePos].second);
782   if (CandidatePos == CandidateVirtRegPos) {
783     assert(!MustFindEviction);
784     return MCRegister::NoRegister;
785   }
786   assert(CandidatePos < ValidPosLimit);
787   (void)ValidPosLimit;
788   return Regs[CandidatePos].first;
789 }
790 
791 const LIFeatureComponents &
792 MLEvictAdvisor::getLIFeatureComponents(const LiveInterval &LI) const {
793   RegID ID = LI.reg().id();
794   LIFeatureComponents Empty;
795   auto I = CachedFeatures.insert(std::make_pair(ID, Empty));
796   LIFeatureComponents &Ret = I.first->getSecond();
797   if (!I.second)
798     return Ret;
799 
800   SmallPtrSet<MachineInstr *, 8> Visited;
801   const TargetRegisterInfo &TRI = *MF.getSubtarget().getRegisterInfo();
802 
803   for (MachineRegisterInfo::reg_instr_nodbg_iterator
804            I = MRI->reg_instr_nodbg_begin(LI.reg()),
805            E = MRI->reg_instr_nodbg_end();
806        I != E;) {
807     MachineInstr *MI = &*(I++);
808 
809     ++Ret.NrDefsAndUses;
810     if (!Visited.insert(MI).second)
811       continue;
812 
813     if (MI->isIdentityCopy() || MI->isImplicitDef())
814       continue;
815 
816     bool Reads, Writes;
817     std::tie(Reads, Writes) = MI->readsWritesVirtualRegister(LI.reg());
818 
819     float Freq = MBFI.getBlockFreqRelativeToEntryBlock(MI->getParent());
820     Ret.HottestBlockFreq = std::max(Freq, Ret.HottestBlockFreq);
821 
822     Ret.R += (Reads && !Writes) * Freq;
823     Ret.W += (!Reads && Writes) * Freq;
824     Ret.RW += (Reads && Writes) * Freq;
825 
826     auto *MBB = MI->getParent();
827     auto *Loop = Loops.getLoopFor(MBB);
828     bool IsExiting = Loop ? Loop->isLoopExiting(MBB) : false;
829 
830     if (Writes && IsExiting && LIS->isLiveOutOfMBB(LI, MBB))
831       Ret.IndVarUpdates += Freq;
832 
833     if (MI->isCopy() && VirtRegAuxInfo::copyHint(MI, LI.reg(), TRI, *MRI))
834       Ret.HintWeights += Freq;
835   }
836   Ret.IsRemat = VirtRegAuxInfo::isRematerializable(
837       LI, *LIS, *VRM, *MF.getSubtarget().getInstrInfo());
838   return Ret;
839 }
840 
841 // Overall, this currently mimics what we do for weight calculation, but instead
842 // of accummulating the various features, we keep them separate.
843 void MLEvictAdvisor::extractFeatures(
844     const SmallVectorImpl<const LiveInterval *> &Intervals,
845     llvm::SmallVectorImpl<float> &Largest, size_t Pos, int64_t IsHint,
846     int64_t LocalIntfsCount, float NrUrgent,
847     SmallVectorImpl<LRStartEndInfo> &LRPosInfo) const {
848   int64_t NrDefsAndUses = 0;
849   int64_t NrBrokenHints = 0;
850   double R = 0.0;
851   double W = 0.0;
852   double RW = 0.0;
853   double IndVarUpdates = 0.0;
854   double HintWeights = 0.0;
855   float StartBBFreq = 0.0;
856   float EndBBFreq = 0.0;
857   float HottestBlockFreq = 0.0;
858   int32_t NrRematerializable = 0;
859   float TotalWeight = 0.0;
860 
861   SlotIndex EndSI = LIS->getSlotIndexes()->getZeroIndex();
862   SlotIndex StartSI = LIS->getSlotIndexes()->getLastIndex();
863   int64_t MaxStage = 0;
864   int64_t MinStage =
865       Intervals.empty() ? 0 : std::numeric_limits<int64_t>::max();
866 
867   for (const auto *L : Intervals) {
868     const LiveInterval &LI = *L;
869     MaxStage = std::max<int64_t>(
870         MaxStage, static_cast<int64_t>(RA.getExtraInfo().getStage(LI)));
871     MinStage = std::min<int64_t>(
872         MinStage, static_cast<int64_t>(RA.getExtraInfo().getStage(LI)));
873 
874     TotalWeight = std::max(TotalWeight, LI.weight());
875 
876     if (LI.beginIndex() < StartSI)
877       StartSI = LI.beginIndex();
878 
879     if (LI.endIndex() > EndSI)
880       EndSI = LI.endIndex();
881     const LIFeatureComponents &LIFC = getLIFeatureComponents(LI);
882     NrBrokenHints += VRM->hasPreferredPhys(LI.reg());
883 
884     NrDefsAndUses += LIFC.NrDefsAndUses;
885     HottestBlockFreq = std::max(HottestBlockFreq, LIFC.HottestBlockFreq);
886     R += LIFC.R;
887     W += LIFC.W;
888     RW += LIFC.RW;
889 
890     IndVarUpdates += LIFC.IndVarUpdates;
891 
892     HintWeights += LIFC.HintWeights;
893     NrRematerializable += LIFC.IsRemat;
894 
895     if (EnableDevelopmentFeatures) {
896       for (auto CurrentSegment : LI) {
897         LRPosInfo.push_back(
898             LRStartEndInfo{CurrentSegment.start, CurrentSegment.end, Pos});
899       }
900     }
901   }
902   size_t Size = 0;
903   if (!Intervals.empty()) {
904     StartBBFreq =
905         MBFI.getBlockFreqRelativeToEntryBlock(LIS->getMBBFromIndex(StartSI));
906     if (EndSI >= LIS->getSlotIndexes()->getLastIndex())
907       EndSI = LIS->getSlotIndexes()->getLastIndex().getPrevIndex();
908     EndBBFreq =
909         MBFI.getBlockFreqRelativeToEntryBlock(LIS->getMBBFromIndex(EndSI));
910     Size = StartSI.distance(EndSI);
911   }
912   // Set the features at the column 'Pos'.
913 #define SET(ID, TYPE, VAL)                                                     \
914   do {                                                                         \
915     Runner->getTensor<TYPE>(FeatureIDs::ID)[Pos] = static_cast<TYPE>(VAL);     \
916     if (!DoNotNormalize.test(FeatureIDs::ID))                                  \
917       Largest[FeatureIDs::ID] =                                                \
918           std::max(Largest[FeatureIDs::ID], static_cast<float>(VAL));          \
919   } while (false)
920   SET(mask, int64_t, 1);
921   SET(is_free, int64_t, Intervals.empty());
922   SET(nr_urgent, float, NrUrgent);
923   SET(nr_broken_hints, float, NrBrokenHints);
924   SET(is_hint, int64_t, IsHint);
925   SET(is_local, int64_t, LocalIntfsCount);
926   SET(nr_rematerializable, float, NrRematerializable);
927   SET(nr_defs_and_uses, float, NrDefsAndUses);
928   SET(weighed_reads_by_max, float, R);
929   SET(weighed_writes_by_max, float, W);
930   SET(weighed_read_writes_by_max, float, RW);
931   SET(weighed_indvars_by_max, float, IndVarUpdates);
932   SET(hint_weights_by_max, float, HintWeights);
933   SET(start_bb_freq_by_max, float, StartBBFreq);
934   SET(end_bb_freq_by_max, float, EndBBFreq);
935   SET(hottest_bb_freq_by_max, float, HottestBlockFreq);
936   SET(liverange_size, float, Size);
937   SET(use_def_density, float, TotalWeight);
938   SET(max_stage, int64_t, MaxStage);
939   SET(min_stage, int64_t, MinStage);
940 #undef SET
941 }
942 
943 void extractInstructionFeatures(
944     SmallVectorImpl<LRStartEndInfo> &LRPosInfo, MLModelRunner *RegallocRunner,
945     function_ref<int(SlotIndex)> GetOpcode,
946     function_ref<float(SlotIndex)> GetMBBFreq,
947     function_ref<MachineBasicBlock *(SlotIndex)> GetMBBReference,
948     const int InstructionsIndex, const int InstructionsMappingIndex,
949     const int MBBFreqIndex, const int MBBMappingIndex,
950     const SlotIndex LastIndex) {
951   // This function extracts instruction based features relevant to the eviction
952   // problem currently being solved. This function ends up extracting two
953   // tensors.
954   // 1 - A vector of size max instruction count. It contains the opcodes of the
955   // instructions spanned by all the intervals in the current instance of the
956   // eviction problem.
957   // 2 - A binary mapping matrix of size (LR count * max
958   // instruction count) which maps where the LRs are live to the actual opcodes
959   // for which they are live.
960   // 3 - A vector of size max supported MBB count storing MBB frequencies,
961   // encompassing all of the MBBs covered by the eviction problem.
962   // 4 - A vector of size max instruction count of indices to members of the MBB
963   // frequency vector, mapping each instruction to its associated MBB.
964 
965   // Start off by sorting the segments based on the beginning slot index.
966   std::sort(
967       LRPosInfo.begin(), LRPosInfo.end(),
968       [](LRStartEndInfo A, LRStartEndInfo B) { return A.Begin < B.Begin; });
969   size_t InstructionIndex = 0;
970   size_t CurrentSegmentIndex = 0;
971   SlotIndex CurrentIndex = LRPosInfo[0].Begin;
972   std::map<MachineBasicBlock *, size_t> VisitedMBBs;
973   size_t CurrentMBBIndex = 0;
974   // This loop processes all the segments sequentially by starting at the
975   // beginning slot index of the first segment, iterating through all the slot
976   // indices before the end slot index of that segment (while checking for
977   // overlaps with segments that start at greater slot indices). After hitting
978   // that end index, the current segment being processed gets bumped until they
979   // are all processed or the max instruction count is hit, where everything is
980   // just truncated.
981   while (true) {
982     // If the index that we are currently at is within the current segment and
983     // we haven't hit the max instruction count, continue processing the current
984     // segment.
985     while (CurrentIndex <= LRPosInfo[CurrentSegmentIndex].End &&
986            InstructionIndex < ModelMaxSupportedInstructionCount) {
987       int CurrentOpcode = GetOpcode(CurrentIndex);
988       // If the current machine instruction is null, skip it
989       if (CurrentOpcode == -1) {
990         // If we're currently at the last index in the SlotIndex analysis,
991         // we can't go any further, so return from the function
992         if (CurrentIndex >= LastIndex) {
993           return;
994         }
995         CurrentIndex = CurrentIndex.getNextIndex();
996         continue;
997       }
998       MachineBasicBlock *CurrentMBBReference = GetMBBReference(CurrentIndex);
999       if (VisitedMBBs.count(CurrentMBBReference) == 0) {
1000         VisitedMBBs[CurrentMBBReference] = CurrentMBBIndex;
1001         ++CurrentMBBIndex;
1002       }
1003       extractMBBFrequency(CurrentIndex, InstructionIndex, VisitedMBBs,
1004                           GetMBBFreq, CurrentMBBReference, RegallocRunner,
1005                           MBBFreqIndex, MBBMappingIndex);
1006       // Current code assumes we're not going to get any disjointed segments
1007       assert(LRPosInfo[CurrentSegmentIndex].Begin <= CurrentIndex);
1008       RegallocRunner->getTensor<int64_t>(InstructionsIndex)[InstructionIndex] =
1009           CurrentOpcode < OpcodeValueCutoff ? CurrentOpcode : 0;
1010       // set value in the binary mapping matrix for the current instruction
1011       auto CurrentSegmentPosition = LRPosInfo[CurrentSegmentIndex].Pos;
1012       RegallocRunner->getTensor<int64_t>(
1013           InstructionsMappingIndex)[CurrentSegmentPosition *
1014                                         ModelMaxSupportedInstructionCount +
1015                                     InstructionIndex] = 1;
1016       // All of the segments are sorted based on the beginning slot index, but
1017       // this doesn't mean that the beginning slot index of the next segment is
1018       // after the end segment of the one being currently processed. This while
1019       // loop checks for overlapping segments and modifies the portion of the
1020       // column in the mapping matrix for the currently processed instruction
1021       // for the LR it is checking. Also make sure that the beginning of the
1022       // current segment we're checking for overlap in is less than the current
1023       // index, otherwise we're done checking overlaps.
1024       size_t OverlapCheckCurrentSegment = CurrentSegmentIndex + 1;
1025       while (OverlapCheckCurrentSegment < LRPosInfo.size() &&
1026              LRPosInfo[OverlapCheckCurrentSegment].Begin <= CurrentIndex) {
1027         auto OverlapCurrentSegmentPosition =
1028             LRPosInfo[OverlapCheckCurrentSegment].Pos;
1029         if (LRPosInfo[OverlapCheckCurrentSegment].End >= CurrentIndex) {
1030           RegallocRunner->getTensor<int64_t>(
1031               InstructionsMappingIndex)[OverlapCurrentSegmentPosition *
1032                                             ModelMaxSupportedInstructionCount +
1033                                         InstructionIndex] = 1;
1034         }
1035         ++OverlapCheckCurrentSegment;
1036       }
1037       ++InstructionIndex;
1038       if (CurrentIndex >= LastIndex) {
1039         return;
1040       }
1041       CurrentIndex = CurrentIndex.getNextIndex();
1042     }
1043     // if we've just finished processing through the last segment or if we've
1044     // hit the maximum number of instructions, break out of the loop.
1045     if (CurrentSegmentIndex == LRPosInfo.size() - 1 ||
1046         InstructionIndex >= ModelMaxSupportedInstructionCount) {
1047       break;
1048     }
1049     // If the segments are not overlapping, we need to move to the beginning
1050     // index of the next segment to avoid having instructions not attached to
1051     // any register.
1052     if (LRPosInfo[CurrentSegmentIndex + 1].Begin >
1053         LRPosInfo[CurrentSegmentIndex].End) {
1054       CurrentIndex = LRPosInfo[CurrentSegmentIndex + 1].Begin;
1055     }
1056     ++CurrentSegmentIndex;
1057   }
1058 }
1059 
1060 void extractMBBFrequency(const SlotIndex CurrentIndex,
1061                          const size_t CurrentInstructionIndex,
1062                          std::map<MachineBasicBlock *, size_t> &VisitedMBBs,
1063                          function_ref<float(SlotIndex)> GetMBBFreq,
1064                          MachineBasicBlock *CurrentMBBReference,
1065                          MLModelRunner *RegallocRunner, const int MBBFreqIndex,
1066                          const int MBBMappingIndex) {
1067   size_t CurrentMBBIndex = VisitedMBBs[CurrentMBBReference];
1068   float CurrentMBBFreq = GetMBBFreq(CurrentIndex);
1069   if (CurrentMBBIndex < ModelMaxSupportedMBBCount) {
1070     RegallocRunner->getTensor<float>(MBBFreqIndex)[CurrentMBBIndex] =
1071         CurrentMBBFreq;
1072     RegallocRunner->getTensor<int64_t>(
1073         MBBMappingIndex)[CurrentInstructionIndex] = CurrentMBBIndex;
1074   }
1075 }
1076 
1077 // Development mode-specific implementations
1078 #ifdef LLVM_HAVE_TFLITE
1079 
1080 RegAllocEvictionAdvisorAnalysis *llvm::createDevelopmentModeAdvisor() {
1081   return new DevelopmentModeEvictionAdvisorAnalysis();
1082 }
1083 
1084 int64_t DevelopmentModeEvictAdvisor::tryFindEvictionCandidatePosition(
1085     const LiveInterval &VirtReg, const AllocationOrder &Order,
1086     unsigned OrderLimit, uint8_t CostPerUseLimit,
1087     const SmallVirtRegSet &FixedRegisters) const {
1088   int64_t Ret = 0;
1089   if (isa<ModelUnderTrainingRunner>(getRunner())) {
1090     Ret = MLEvictAdvisor::tryFindEvictionCandidatePosition(
1091         VirtReg, Order, OrderLimit, CostPerUseLimit, FixedRegisters);
1092   } else {
1093     MCRegister PhysReg = getDefaultAdvisor().tryFindEvictionCandidate(
1094         VirtReg, Order, CostPerUseLimit, FixedRegisters);
1095     // Find the index of the selected PhysReg. We need it for logging,
1096     // otherwise this is wasted cycles (but so would starting development mode
1097     // without a model nor logging)
1098     if (!PhysReg)
1099       Ret = CandidateVirtRegPos;
1100     else
1101       for (auto I = Order.begin(), E = Order.getOrderLimitEnd(OrderLimit);
1102            I != E; ++I, ++Ret)
1103         if (*I == PhysReg)
1104           break;
1105   }
1106   if (TrainingLog.empty())
1107     return Ret;
1108   // TODO(mtrofin): when we support optional rewards, this can go away. In the
1109   // meantime, we log the "pretend" reward (0) for the previous observation
1110   // before starting a new one.
1111   if (Log->hasObservationInProgress())
1112     Log->logReward<float>(0.0);
1113 
1114   Log->startObservation();
1115   size_t CurrentFeature = 0;
1116   size_t FeatureCount = EnableDevelopmentFeatures
1117                             ? FeatureIDs::FeaturesWithDevelopmentCount
1118                             : FeatureIDs::FeatureCount;
1119   for (; CurrentFeature < FeatureCount; ++CurrentFeature) {
1120     Log->logTensorValue(CurrentFeature,
1121                         reinterpret_cast<const char *>(
1122                             getRunner().getTensorUntyped(CurrentFeature)));
1123   }
1124   if (auto *MUTR = dyn_cast<ModelUnderTrainingRunner>(&getRunner()))
1125     for (size_t I = 0; I < MUTR->extraOutputsForLoggingSpecs().size();
1126          ++I, ++CurrentFeature)
1127       Log->logTensorValue(
1128           CurrentFeature,
1129           reinterpret_cast<const char *>(MUTR->getUntypedExtraOutputValue(I)));
1130   // The output is right after the features and the extra outputs
1131   Log->logTensorValue(CurrentFeature, reinterpret_cast<const char *>(&Ret));
1132   Log->endObservation();
1133   return Ret;
1134 }
1135 
1136 bool RegAllocScoring::runOnMachineFunction(MachineFunction &MF) {
1137   std::optional<float> CachedReward;
1138   auto GetReward = [&]() {
1139     if (!CachedReward)
1140       CachedReward = static_cast<float>(
1141           calculateRegAllocScore(MF, getAnalysis<MachineBlockFrequencyInfo>())
1142               .getScore());
1143     return *CachedReward;
1144   };
1145 
1146   getAnalysis<RegAllocEvictionAdvisorAnalysis>().logRewardIfNeeded(MF,
1147                                                                    GetReward);
1148   getAnalysis<RegAllocPriorityAdvisorAnalysis>().logRewardIfNeeded(MF,
1149                                                                    GetReward);
1150   return false;
1151 }
1152 #endif // #ifdef LLVM_HAVE_TFLITE
1153 
1154 RegAllocEvictionAdvisorAnalysis *llvm::createReleaseModeAdvisor() {
1155   return llvm::isEmbeddedModelEvaluatorValid<CompiledModelType>() ||
1156                  !InteractiveChannelBaseName.empty()
1157              ? new ReleaseModeEvictionAdvisorAnalysis()
1158              : nullptr;
1159 }
1160 
1161 // In all cases except development mode, we don't need scoring.
1162 #if !defined(LLVM_HAVE_TFLITE)
1163 bool RegAllocScoring::runOnMachineFunction(MachineFunction &) { return false; }
1164 #endif
1165