xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp (revision 47e073941f4e7ca6e9bde3fa65abbfcfed6bfa2b)
1  //===- TrainingLogger.cpp - mlgo feature/reward logging -------------------===//
2  //
3  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4  // See https://llvm.org/LICENSE.txt for license information.
5  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6  //
7  //===----------------------------------------------------------------------===//
8  //
9  // This file implements logging infrastructure for extracting features and
10  // rewards for mlgo policy training.
11  //
12  //===----------------------------------------------------------------------===//
13  #include "llvm/Analysis/TensorSpec.h"
14  #include "llvm/Config/config.h"
15  
16  #include "llvm/ADT/Twine.h"
17  #include "llvm/Analysis/Utils/TrainingLogger.h"
18  #include "llvm/Support/CommandLine.h"
19  #include "llvm/Support/Debug.h"
20  #include "llvm/Support/JSON.h"
21  #include "llvm/Support/MemoryBuffer.h"
22  #include "llvm/Support/Path.h"
23  #include "llvm/Support/raw_ostream.h"
24  
25  #include <cassert>
26  #include <numeric>
27  
28  using namespace llvm;
29  
30  // FIXME(mtrofin): remove the flag altogether
31  static cl::opt<bool>
32      UseSimpleLogger("tfutils-use-simplelogger", cl::init(true), cl::Hidden,
33                      cl::desc("Output simple (non-protobuf) log."));
34  
35  void Logger::writeHeader(std::optional<TensorSpec> AdviceSpec) {
36    json::OStream JOS(*OS);
37    JOS.object([&]() {
38      JOS.attributeArray("features", [&]() {
39        for (const auto &TS : FeatureSpecs)
40          TS.toJSON(JOS);
41      });
42      if (IncludeReward) {
43        JOS.attributeBegin("score");
44        RewardSpec.toJSON(JOS);
45        JOS.attributeEnd();
46      }
47      if (AdviceSpec.has_value()) {
48        JOS.attributeBegin("advice");
49        AdviceSpec->toJSON(JOS);
50        JOS.attributeEnd();
51      }
52    });
53    *OS << "\n";
54  }
55  
56  void Logger::switchContext(StringRef Name) {
57    CurrentContext = Name.str();
58    json::OStream JOS(*OS);
59    JOS.object([&]() { JOS.attribute("context", Name); });
60    *OS << "\n";
61  }
62  
63  void Logger::startObservation() {
64    auto I = ObservationIDs.insert({CurrentContext, 0});
65    size_t NewObservationID = I.second ? 0 : ++I.first->second;
66    json::OStream JOS(*OS);
67    JOS.object([&]() {
68      JOS.attribute("observation", static_cast<int64_t>(NewObservationID));
69    });
70    *OS << "\n";
71  }
72  
73  void Logger::endObservation() { *OS << "\n"; }
74  
75  void Logger::logRewardImpl(const char *RawData) {
76    assert(IncludeReward);
77    json::OStream JOS(*OS);
78    JOS.object([&]() {
79      JOS.attribute("outcome", static_cast<int64_t>(
80                                   ObservationIDs.find(CurrentContext)->second));
81    });
82    *OS << "\n";
83    writeTensor(RewardSpec, RawData);
84    *OS << "\n";
85  }
86  
87  Logger::Logger(std::unique_ptr<raw_ostream> OS,
88                 const std::vector<TensorSpec> &FeatureSpecs,
89                 const TensorSpec &RewardSpec, bool IncludeReward,
90                 std::optional<TensorSpec> AdviceSpec)
91      : OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
92        IncludeReward(IncludeReward) {
93    writeHeader(AdviceSpec);
94  }
95