xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp (revision 5f757f3ff9144b609b3c433dfd370cc6bdc191ad)
//===- TrainingLogger.cpp - mlgo feature/reward logging -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logging infrastructure for extracting features and
// rewards for mlgo policy training.
//
//===----------------------------------------------------------------------===//
13bdd1243dSDimitry Andric #include "llvm/Analysis/TensorSpec.h"
14bdd1243dSDimitry Andric #include "llvm/Config/config.h"
15bdd1243dSDimitry Andric 
16bdd1243dSDimitry Andric #include "llvm/ADT/Twine.h"
17bdd1243dSDimitry Andric #include "llvm/Analysis/Utils/TrainingLogger.h"
18bdd1243dSDimitry Andric #include "llvm/Support/CommandLine.h"
19bdd1243dSDimitry Andric #include "llvm/Support/Debug.h"
20bdd1243dSDimitry Andric #include "llvm/Support/JSON.h"
21bdd1243dSDimitry Andric #include "llvm/Support/MemoryBuffer.h"
22bdd1243dSDimitry Andric #include "llvm/Support/Path.h"
23bdd1243dSDimitry Andric #include "llvm/Support/raw_ostream.h"
24bdd1243dSDimitry Andric 
25bdd1243dSDimitry Andric #include <cassert>
26bdd1243dSDimitry Andric #include <numeric>
27bdd1243dSDimitry Andric 
28bdd1243dSDimitry Andric using namespace llvm;
29bdd1243dSDimitry Andric 
30*06c3fb27SDimitry Andric void Logger::writeHeader(std::optional<TensorSpec> AdviceSpec) {
31bdd1243dSDimitry Andric   json::OStream JOS(*OS);
32bdd1243dSDimitry Andric   JOS.object([&]() {
33bdd1243dSDimitry Andric     JOS.attributeArray("features", [&]() {
34bdd1243dSDimitry Andric       for (const auto &TS : FeatureSpecs)
35bdd1243dSDimitry Andric         TS.toJSON(JOS);
36bdd1243dSDimitry Andric     });
37bdd1243dSDimitry Andric     if (IncludeReward) {
38bdd1243dSDimitry Andric       JOS.attributeBegin("score");
39bdd1243dSDimitry Andric       RewardSpec.toJSON(JOS);
40bdd1243dSDimitry Andric       JOS.attributeEnd();
41bdd1243dSDimitry Andric     }
42*06c3fb27SDimitry Andric     if (AdviceSpec.has_value()) {
43*06c3fb27SDimitry Andric       JOS.attributeBegin("advice");
44*06c3fb27SDimitry Andric       AdviceSpec->toJSON(JOS);
45*06c3fb27SDimitry Andric       JOS.attributeEnd();
46*06c3fb27SDimitry Andric     }
47bdd1243dSDimitry Andric   });
48bdd1243dSDimitry Andric   *OS << "\n";
49bdd1243dSDimitry Andric }
50bdd1243dSDimitry Andric 
51bdd1243dSDimitry Andric void Logger::switchContext(StringRef Name) {
52bdd1243dSDimitry Andric   CurrentContext = Name.str();
53bdd1243dSDimitry Andric   json::OStream JOS(*OS);
54bdd1243dSDimitry Andric   JOS.object([&]() { JOS.attribute("context", Name); });
55bdd1243dSDimitry Andric   *OS << "\n";
56bdd1243dSDimitry Andric }
57bdd1243dSDimitry Andric 
58bdd1243dSDimitry Andric void Logger::startObservation() {
59bdd1243dSDimitry Andric   auto I = ObservationIDs.insert({CurrentContext, 0});
60bdd1243dSDimitry Andric   size_t NewObservationID = I.second ? 0 : ++I.first->second;
61bdd1243dSDimitry Andric   json::OStream JOS(*OS);
62bdd1243dSDimitry Andric   JOS.object([&]() {
63bdd1243dSDimitry Andric     JOS.attribute("observation", static_cast<int64_t>(NewObservationID));
64bdd1243dSDimitry Andric   });
65bdd1243dSDimitry Andric   *OS << "\n";
66bdd1243dSDimitry Andric }
67bdd1243dSDimitry Andric 
68bdd1243dSDimitry Andric void Logger::endObservation() { *OS << "\n"; }
69bdd1243dSDimitry Andric 
70bdd1243dSDimitry Andric void Logger::logRewardImpl(const char *RawData) {
71bdd1243dSDimitry Andric   assert(IncludeReward);
72bdd1243dSDimitry Andric   json::OStream JOS(*OS);
73bdd1243dSDimitry Andric   JOS.object([&]() {
74bdd1243dSDimitry Andric     JOS.attribute("outcome", static_cast<int64_t>(
75bdd1243dSDimitry Andric                                  ObservationIDs.find(CurrentContext)->second));
76bdd1243dSDimitry Andric   });
77bdd1243dSDimitry Andric   *OS << "\n";
78bdd1243dSDimitry Andric   writeTensor(RewardSpec, RawData);
79bdd1243dSDimitry Andric   *OS << "\n";
80bdd1243dSDimitry Andric }
81bdd1243dSDimitry Andric 
82bdd1243dSDimitry Andric Logger::Logger(std::unique_ptr<raw_ostream> OS,
83bdd1243dSDimitry Andric                const std::vector<TensorSpec> &FeatureSpecs,
84*06c3fb27SDimitry Andric                const TensorSpec &RewardSpec, bool IncludeReward,
85*06c3fb27SDimitry Andric                std::optional<TensorSpec> AdviceSpec)
86bdd1243dSDimitry Andric     : OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
87bdd1243dSDimitry Andric       IncludeReward(IncludeReward) {
88*06c3fb27SDimitry Andric   writeHeader(AdviceSpec);
89bdd1243dSDimitry Andric }
90