//===- TrainingLogger.cpp - mlgo feature/reward logging -------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements logging infrastructure for extracting features and // rewards for mlgo policy training. // //===----------------------------------------------------------------------===// #include "llvm/Analysis/TensorSpec.h" #include "llvm/Config/config.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/Utils/TrainingLogger.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/JSON.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/Path.h" #include "llvm/Support/raw_ostream.h" #include #include using namespace llvm; // FIXME(mtrofin): remove the flag altogether static cl::opt UseSimpleLogger("tfutils-use-simplelogger", cl::init(true), cl::Hidden, cl::desc("Output simple (non-protobuf) log.")); void Logger::writeHeader() { json::OStream JOS(*OS); JOS.object([&]() { JOS.attributeArray("features", [&]() { for (const auto &TS : FeatureSpecs) TS.toJSON(JOS); }); if (IncludeReward) { JOS.attributeBegin("score"); RewardSpec.toJSON(JOS); JOS.attributeEnd(); } }); *OS << "\n"; } void Logger::switchContext(StringRef Name) { CurrentContext = Name.str(); json::OStream JOS(*OS); JOS.object([&]() { JOS.attribute("context", Name); }); *OS << "\n"; } void Logger::startObservation() { auto I = ObservationIDs.insert({CurrentContext, 0}); size_t NewObservationID = I.second ? 0 : ++I.first->second; json::OStream JOS(*OS); JOS.object([&]() { JOS.attribute("observation", static_cast(NewObservationID)); }); *OS << "\n"; } void Logger::endObservation() { *OS << "\n"; } void Logger::logRewardImpl(const char *RawData) { assert(IncludeReward); json::OStream JOS(*OS); JOS.object([&]() { JOS.attribute("outcome", static_cast( ObservationIDs.find(CurrentContext)->second)); }); *OS << "\n"; writeTensor(RewardSpec, RawData); *OS << "\n"; } Logger::Logger(std::unique_ptr OS, const std::vector &FeatureSpecs, const TensorSpec &RewardSpec, bool IncludeReward) : OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec), IncludeReward(IncludeReward) { writeHeader(); }