//===- TrainingLogger.h - mlgo feature/reward logging ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The design goals of the logger are:
// - no dependencies that llvm doesn't already have.
// - support streaming, so that we don't need to buffer data during compilation
// - 0-decoding tensor values. Tensor values are potentially very large buffers
//   of scalars. Because of their potentially large size, avoiding
//   serialization/deserialization overhead is preferred.
//
// The simple logger produces an output of the form (each line item on its line)
// - header: a json object describing the data that will follow.
// - context: e.g. function name, for regalloc, or "default" for module-wide
//   optimizations like the inliner. This is the context to which the subsequent
//   data corresponds.
// - observation number.
// - tensor values - raw bytes of the tensors, in the order given in the header.
//   The values are in succession, i.e. no separator is found between successive
//   tensor values. At the end, there is a new line character.
// - [score] - this is optional, and is present if it was present in the header.
//   Currently, for final rewards, we output "0" scores after each observation,
//   except for the last one.
// <repeat>
// The file should be read as binary, but the reason we use newlines is mostly
// ease of debugging: the log can be opened in a text editor and, while tensor
// values are inscrutable, at least the sequence of data can be easily observed.
// Of course, the buffer of tensor values could contain '\n' bytes.
// A reader
// should use the header information to know how much data to read for the
// tensor values, and not use line information for that.
//
// An example reader, used for test, is available at
// Analysis/models/log_reader.py
//
// Example:
// {"features":[list of TensorSpecs], "score":<a tensor spec>}
// {"context": "aFunction"}
// {"observation": 0}
// <bytes>
// {"outcome": 0}
// <bytes for the tensor corresponding to the "score" spec in the header>
// {"observation": 1}
// ...
// {"context": "anotherFunction"}
// {"observation": 0}
// ...
//

#ifndef LLVM_ANALYSIS_UTILS_TRAININGLOGGER_H
#define LLVM_ANALYSIS_UTILS_TRAININGLOGGER_H

#include "llvm/Config/llvm-config.h"
#include "llvm/Support/Compiler.h"

#include "llvm/ADT/StringMap.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/JSON.h"

#include <memory>
#include <optional>
#include <vector>

namespace llvm {

/// Logging utility - given an ordered specification of features, and assuming
/// a scalar reward, allow logging feature values and rewards.
/// The assumption is that, for an event to be logged (i.e. a set of feature
/// values and a reward), the user calls the log* API for each feature exactly
/// once, providing the index matching the position in the feature spec list
/// provided at construction. The example assumes the first feature's element
/// type is float, the second is int64, and the reward is float:
///
/// event 0:
///   logFloatValue(0, ...)
///   logInt64Value(1, ...)
///   ...
///   logFloatReward(...)
/// event 1:
///   logFloatValue(0, ...)
///   logInt64Value(1, ...)
///   ...
///   logFloatReward(...)
///
/// At the end, call print to generate the log.
/// Alternatively, don't call logReward at the end of each event, just
/// log{Float|Int32|Int64}FinalReward at the end.
92 class Logger final { 93 std::unique_ptr<raw_ostream> OS; 94 const std::vector<TensorSpec> FeatureSpecs; 95 const TensorSpec RewardSpec; 96 const bool IncludeReward; 97 StringMap<size_t> ObservationIDs; 98 std::string CurrentContext; 99 100 void writeHeader(std::optional<TensorSpec> AdviceSpec); writeTensor(const TensorSpec & Spec,const char * RawData)101 void writeTensor(const TensorSpec &Spec, const char *RawData) { 102 OS->write(RawData, Spec.getTotalTensorBufferSize()); 103 } 104 LLVM_ABI void logRewardImpl(const char *RawData); 105 106 public: 107 /// Construct a Logger. If IncludeReward is false, then logReward or 108 /// logFinalReward shouldn't be called, and the reward feature won't be 109 /// printed out. 110 /// NOTE: the FeatureSpecs are expected to be in the same order (i.e. have 111 /// corresponding indices) with any MLModelRunner implementations 112 /// corresponding to the model being trained/logged. 113 LLVM_ABI Logger(std::unique_ptr<raw_ostream> OS, 114 const std::vector<TensorSpec> &FeatureSpecs, 115 const TensorSpec &RewardSpec, bool IncludeReward, 116 std::optional<TensorSpec> AdviceSpec = std::nullopt); 117 118 LLVM_ABI void switchContext(StringRef Name); 119 LLVM_ABI void startObservation(); 120 LLVM_ABI void endObservation(); flush()121 void flush() { OS->flush(); } 122 currentContext()123 const std::string ¤tContext() const { return CurrentContext; } 124 125 /// Check if there is at least an observation for `currentContext()`. hasObservationInProgress()126 bool hasObservationInProgress() const { 127 return hasAnyObservationForContext(CurrentContext); 128 } 129 130 /// Check if there is at least an observation for the context `Ctx`. 
hasAnyObservationForContext(StringRef Ctx)131 bool hasAnyObservationForContext(StringRef Ctx) const { 132 return ObservationIDs.contains(Ctx); 133 } 134 logReward(T Value)135 template <typename T> void logReward(T Value) { 136 logRewardImpl(reinterpret_cast<const char *>(&Value)); 137 } 138 logTensorValue(size_t FeatureID,const char * RawData)139 void logTensorValue(size_t FeatureID, const char *RawData) { 140 writeTensor(FeatureSpecs[FeatureID], RawData); 141 } 142 }; 143 144 } // namespace llvm 145 #endif // LLVM_ANALYSIS_UTILS_TRAININGLOGGER_H 146