1bdd1243dSDimitry Andric //===- TrainingLogger.cpp - mlgo feature/reward logging -------------------===// 2bdd1243dSDimitry Andric // 3bdd1243dSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4bdd1243dSDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 5bdd1243dSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6bdd1243dSDimitry Andric // 7bdd1243dSDimitry Andric //===----------------------------------------------------------------------===// 8bdd1243dSDimitry Andric // 9bdd1243dSDimitry Andric // This file implements logging infrastructure for extracting features and 10bdd1243dSDimitry Andric // rewards for mlgo policy training. 11bdd1243dSDimitry Andric // 12bdd1243dSDimitry Andric //===----------------------------------------------------------------------===// 13bdd1243dSDimitry Andric #include "llvm/Analysis/TensorSpec.h" 14bdd1243dSDimitry Andric #include "llvm/Config/config.h" 15bdd1243dSDimitry Andric 16bdd1243dSDimitry Andric #include "llvm/ADT/Twine.h" 17bdd1243dSDimitry Andric #include "llvm/Analysis/Utils/TrainingLogger.h" 18bdd1243dSDimitry Andric #include "llvm/Support/CommandLine.h" 19bdd1243dSDimitry Andric #include "llvm/Support/Debug.h" 20bdd1243dSDimitry Andric #include "llvm/Support/JSON.h" 21bdd1243dSDimitry Andric #include "llvm/Support/MemoryBuffer.h" 22bdd1243dSDimitry Andric #include "llvm/Support/Path.h" 23bdd1243dSDimitry Andric #include "llvm/Support/raw_ostream.h" 24bdd1243dSDimitry Andric 25bdd1243dSDimitry Andric #include <cassert> 26bdd1243dSDimitry Andric #include <numeric> 27bdd1243dSDimitry Andric 28bdd1243dSDimitry Andric using namespace llvm; 29bdd1243dSDimitry Andric 30*06c3fb27SDimitry Andric void Logger::writeHeader(std::optional<TensorSpec> AdviceSpec) { 31bdd1243dSDimitry Andric json::OStream JOS(*OS); 32bdd1243dSDimitry Andric JOS.object([&]() { 33bdd1243dSDimitry Andric JOS.attributeArray("features", [&]() { 34bdd1243dSDimitry Andric for (const auto &TS : FeatureSpecs) 35bdd1243dSDimitry Andric TS.toJSON(JOS); 36bdd1243dSDimitry Andric }); 37bdd1243dSDimitry Andric if (IncludeReward) { 38bdd1243dSDimitry Andric JOS.attributeBegin("score"); 39bdd1243dSDimitry Andric RewardSpec.toJSON(JOS); 40bdd1243dSDimitry Andric JOS.attributeEnd(); 41bdd1243dSDimitry Andric } 42*06c3fb27SDimitry Andric if (AdviceSpec.has_value()) { 43*06c3fb27SDimitry Andric JOS.attributeBegin("advice"); 44*06c3fb27SDimitry Andric AdviceSpec->toJSON(JOS); 45*06c3fb27SDimitry Andric JOS.attributeEnd(); 46*06c3fb27SDimitry Andric } 47bdd1243dSDimitry Andric }); 48bdd1243dSDimitry Andric *OS << "\n"; 49bdd1243dSDimitry Andric } 50bdd1243dSDimitry Andric 51bdd1243dSDimitry Andric void Logger::switchContext(StringRef Name) { 52bdd1243dSDimitry Andric CurrentContext = Name.str(); 53bdd1243dSDimitry Andric json::OStream JOS(*OS); 54bdd1243dSDimitry Andric JOS.object([&]() { JOS.attribute("context", Name); }); 55bdd1243dSDimitry Andric *OS << "\n"; 56bdd1243dSDimitry Andric } 57bdd1243dSDimitry Andric 58bdd1243dSDimitry Andric void Logger::startObservation() { 59bdd1243dSDimitry Andric auto I = ObservationIDs.insert({CurrentContext, 0}); 60bdd1243dSDimitry Andric size_t NewObservationID = I.second ? 0 : ++I.first->second; 61bdd1243dSDimitry Andric json::OStream JOS(*OS); 62bdd1243dSDimitry Andric JOS.object([&]() { 63bdd1243dSDimitry Andric JOS.attribute("observation", static_cast<int64_t>(NewObservationID)); 64bdd1243dSDimitry Andric }); 65bdd1243dSDimitry Andric *OS << "\n"; 66bdd1243dSDimitry Andric } 67bdd1243dSDimitry Andric 68bdd1243dSDimitry Andric void Logger::endObservation() { *OS << "\n"; } 69bdd1243dSDimitry Andric 70bdd1243dSDimitry Andric void Logger::logRewardImpl(const char *RawData) { 71bdd1243dSDimitry Andric assert(IncludeReward); 72bdd1243dSDimitry Andric json::OStream JOS(*OS); 73bdd1243dSDimitry Andric JOS.object([&]() { 74bdd1243dSDimitry Andric JOS.attribute("outcome", static_cast<int64_t>( 75bdd1243dSDimitry Andric ObservationIDs.find(CurrentContext)->second)); 76bdd1243dSDimitry Andric }); 77bdd1243dSDimitry Andric *OS << "\n"; 78bdd1243dSDimitry Andric writeTensor(RewardSpec, RawData); 79bdd1243dSDimitry Andric *OS << "\n"; 80bdd1243dSDimitry Andric } 81bdd1243dSDimitry Andric 82bdd1243dSDimitry Andric Logger::Logger(std::unique_ptr<raw_ostream> OS, 83bdd1243dSDimitry Andric const std::vector<TensorSpec> &FeatureSpecs, 84*06c3fb27SDimitry Andric const TensorSpec &RewardSpec, bool IncludeReward, 85*06c3fb27SDimitry Andric std::optional<TensorSpec> AdviceSpec) 86bdd1243dSDimitry Andric : OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec), 87bdd1243dSDimitry Andric IncludeReward(IncludeReward) { 88*06c3fb27SDimitry Andric writeHeader(AdviceSpec); 89bdd1243dSDimitry Andric } 90