xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/TrainingLogger.cpp (revision 1fd880742ace94e11fa60ee0b074f0b18e54c54f)
1  //===- TrainingLogger.cpp - mlgo feature/reward logging -------------------===//
2  //
3  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4  // See https://llvm.org/LICENSE.txt for license information.
5  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6  //
7  //===----------------------------------------------------------------------===//
8  //
9  // This file implements logging infrastructure for extracting features and
10  // rewards for mlgo policy training.
11  //
12  //===----------------------------------------------------------------------===//
13  #include "llvm/Analysis/TensorSpec.h"
14  #include "llvm/Config/config.h"
15  
16  #include "llvm/ADT/Twine.h"
17  #include "llvm/Analysis/Utils/TrainingLogger.h"
18  #include "llvm/Support/CommandLine.h"
19  #include "llvm/Support/Debug.h"
20  #include "llvm/Support/JSON.h"
21  #include "llvm/Support/MemoryBuffer.h"
22  #include "llvm/Support/Path.h"
23  #include "llvm/Support/raw_ostream.h"
24  
25  #include <cassert>
26  #include <numeric>
27  
28  using namespace llvm;
29  
30  void Logger::writeHeader(std::optional<TensorSpec> AdviceSpec) {
31    json::OStream JOS(*OS);
32    JOS.object([&]() {
33      JOS.attributeArray("features", [&]() {
34        for (const auto &TS : FeatureSpecs)
35          TS.toJSON(JOS);
36      });
37      if (IncludeReward) {
38        JOS.attributeBegin("score");
39        RewardSpec.toJSON(JOS);
40        JOS.attributeEnd();
41      }
42      if (AdviceSpec.has_value()) {
43        JOS.attributeBegin("advice");
44        AdviceSpec->toJSON(JOS);
45        JOS.attributeEnd();
46      }
47    });
48    *OS << "\n";
49  }
50  
51  void Logger::switchContext(StringRef Name) {
52    CurrentContext = Name.str();
53    json::OStream JOS(*OS);
54    JOS.object([&]() { JOS.attribute("context", Name); });
55    *OS << "\n";
56  }
57  
58  void Logger::startObservation() {
59    auto I = ObservationIDs.insert({CurrentContext, 0});
60    size_t NewObservationID = I.second ? 0 : ++I.first->second;
61    json::OStream JOS(*OS);
62    JOS.object([&]() {
63      JOS.attribute("observation", static_cast<int64_t>(NewObservationID));
64    });
65    *OS << "\n";
66  }
67  
68  void Logger::endObservation() { *OS << "\n"; }
69  
70  void Logger::logRewardImpl(const char *RawData) {
71    assert(IncludeReward);
72    json::OStream JOS(*OS);
73    JOS.object([&]() {
74      JOS.attribute("outcome", static_cast<int64_t>(
75                                   ObservationIDs.find(CurrentContext)->second));
76    });
77    *OS << "\n";
78    writeTensor(RewardSpec, RawData);
79    *OS << "\n";
80  }
81  
82  Logger::Logger(std::unique_ptr<raw_ostream> OS,
83                 const std::vector<TensorSpec> &FeatureSpecs,
84                 const TensorSpec &RewardSpec, bool IncludeReward,
85                 std::optional<TensorSpec> AdviceSpec)
86      : OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
87        IncludeReward(IncludeReward) {
88    writeHeader(AdviceSpec);
89  }
90