xref: /freebsd/contrib/llvm-project/llvm/lib/ProfileData/DataAccessProf.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 #include "llvm/ProfileData/DataAccessProf.h"
2 #include "llvm/ADT/STLExtras.h"
3 #include "llvm/ProfileData/InstrProf.h"
4 #include "llvm/Support/Compression.h"
5 #include "llvm/Support/Endian.h"
6 #include "llvm/Support/Errc.h"
7 #include "llvm/Support/Error.h"
8 #include "llvm/Support/StringSaver.h"
9 #include "llvm/Support/raw_ostream.h"
10 
11 namespace llvm {
12 namespace memprof {
13 
14 // If `Map` has an entry keyed by `Str`, returns the entry iterator. Otherwise,
15 // creates an owned copy of `Str`, adds a map entry for it and returns the
16 // iterator.
17 static std::pair<StringRef, uint64_t>
18 saveStringToMap(DataAccessProfData::StringToIndexMap &Map,
19                 llvm::UniqueStringSaver &Saver, StringRef Str) {
20   auto [Iter, Inserted] = Map.try_emplace(Saver.save(Str), Map.size());
21   return *Iter;
22 }
23 
24 // Returns the canonical name or error.
25 static Expected<StringRef> getCanonicalName(StringRef Name) {
26   if (Name.empty())
27     return make_error<StringError>("Empty symbol name",
28                                    llvm::errc::invalid_argument);
29   return InstrProfSymtab::getCanonicalName(Name);
30 }
31 
32 std::optional<DataAccessProfRecord>
33 DataAccessProfData::getProfileRecord(const SymbolHandleRef SymbolID) const {
34   auto Key = SymbolID;
35   if (std::holds_alternative<StringRef>(SymbolID)) {
36     auto NameOrErr = getCanonicalName(std::get<StringRef>(SymbolID));
37     // If name canonicalization fails, suppress the error inside.
38     if (!NameOrErr) {
39       assert(
40           std::get<StringRef>(SymbolID).empty() &&
41           "Name canonicalization only fails when stringified string is empty.");
42       return std::nullopt;
43     }
44     Key = *NameOrErr;
45   }
46 
47   auto It = Records.find(Key);
48   if (It != Records.end()) {
49     return DataAccessProfRecord(Key, It->second.AccessCount,
50                                 It->second.Locations);
51   }
52 
53   return std::nullopt;
54 }
55 
56 bool DataAccessProfData::isKnownColdSymbol(const SymbolHandleRef SymID) const {
57   if (std::holds_alternative<uint64_t>(SymID))
58     return KnownColdHashes.contains(std::get<uint64_t>(SymID));
59   return KnownColdSymbols.contains(std::get<StringRef>(SymID));
60 }
61 
62 Error DataAccessProfData::setDataAccessProfile(SymbolHandleRef Symbol,
63                                                uint64_t AccessCount) {
64   uint64_t RecordID = -1;
65   const bool IsStringLiteral = std::holds_alternative<uint64_t>(Symbol);
66   SymbolHandleRef Key;
67   if (IsStringLiteral) {
68     RecordID = std::get<uint64_t>(Symbol);
69     Key = RecordID;
70   } else {
71     auto CanonicalName = getCanonicalName(std::get<StringRef>(Symbol));
72     if (!CanonicalName)
73       return CanonicalName.takeError();
74     std::tie(Key, RecordID) =
75         saveStringToMap(StrToIndexMap, Saver, *CanonicalName);
76   }
77 
78   auto [Iter, Inserted] =
79       Records.try_emplace(Key, RecordID, AccessCount, IsStringLiteral);
80   if (!Inserted)
81     return make_error<StringError>("Duplicate symbol or string literal added. "
82                                    "User of DataAccessProfData should "
83                                    "aggregate count for the same symbol. ",
84                                    llvm::errc::invalid_argument);
85 
86   return Error::success();
87 }
88 
89 Error DataAccessProfData::setDataAccessProfile(
90     SymbolHandleRef SymbolID, uint64_t AccessCount,
91     ArrayRef<SourceLocation> Locations) {
92   if (Error E = setDataAccessProfile(SymbolID, AccessCount))
93     return E;
94 
95   auto &Record = Records.back().second;
96   for (const auto &Location : Locations)
97     Record.Locations.push_back(
98         {saveStringToMap(StrToIndexMap, Saver, Location.FileName).first,
99          Location.Line});
100 
101   return Error::success();
102 }
103 
104 Error DataAccessProfData::addKnownSymbolWithoutSamples(
105     SymbolHandleRef SymbolID) {
106   if (std::holds_alternative<uint64_t>(SymbolID)) {
107     KnownColdHashes.insert(std::get<uint64_t>(SymbolID));
108     return Error::success();
109   }
110   auto CanonicalName = getCanonicalName(std::get<StringRef>(SymbolID));
111   if (!CanonicalName)
112     return CanonicalName.takeError();
113   KnownColdSymbols.insert(
114       saveStringToMap(StrToIndexMap, Saver, *CanonicalName).first);
115   return Error::success();
116 }
117 
118 Error DataAccessProfData::deserialize(const unsigned char *&Ptr) {
119   uint64_t NumSampledSymbols =
120       support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
121   uint64_t NumColdKnownSymbols =
122       support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
123   if (Error E = deserializeSymbolsAndFilenames(Ptr, NumSampledSymbols,
124                                                NumColdKnownSymbols))
125     return E;
126 
127   uint64_t Num =
128       support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
129   for (uint64_t I = 0; I < Num; ++I)
130     KnownColdHashes.insert(
131         support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr));
132 
133   return deserializeRecords(Ptr);
134 }
135 
136 Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const {
137   OS.write(StrToIndexMap.size());
138   OS.write(KnownColdSymbols.size());
139 
140   std::vector<std::string> Strs;
141   Strs.reserve(StrToIndexMap.size() + KnownColdSymbols.size());
142   for (const auto &Str : StrToIndexMap)
143     Strs.push_back(Str.first.str());
144   for (const auto &Str : KnownColdSymbols)
145     Strs.push_back(Str.str());
146 
147   std::string CompressedStrings;
148   if (!Strs.empty())
149     if (Error E = collectGlobalObjectNameStrings(
150             Strs, compression::zlib::isAvailable(), CompressedStrings))
151       return E;
152   const uint64_t CompressedStringLen = CompressedStrings.length();
153   // Record the length of compressed string.
154   OS.write(CompressedStringLen);
155   // Write the chars in compressed strings.
156   for (char C : CompressedStrings)
157     OS.writeByte(static_cast<uint8_t>(C));
158   // Pad up to a multiple of 8.
159   // InstrProfReader could read bytes according to 'CompressedStringLen'.
160   const uint64_t PaddedLength = alignTo(CompressedStringLen, 8);
161   for (uint64_t K = CompressedStringLen; K < PaddedLength; K++)
162     OS.writeByte(0);
163   return Error::success();
164 }
165 
166 uint64_t
167 DataAccessProfData::getEncodedIndex(const SymbolHandleRef SymbolID) const {
168   if (std::holds_alternative<uint64_t>(SymbolID))
169     return std::get<uint64_t>(SymbolID);
170 
171   auto Iter = StrToIndexMap.find(std::get<StringRef>(SymbolID));
172   assert(Iter != StrToIndexMap.end() &&
173          "String literals not found in StrToIndexMap");
174   return Iter->second;
175 }
176 
177 Error DataAccessProfData::serialize(ProfOStream &OS) const {
178   if (Error E = serializeSymbolsAndFilenames(OS))
179     return E;
180   OS.write(KnownColdHashes.size());
181   for (const auto &Hash : KnownColdHashes)
182     OS.write(Hash);
183   OS.write((uint64_t)(Records.size()));
184   for (const auto &[Key, Rec] : Records) {
185     OS.write(getEncodedIndex(Rec.SymbolID));
186     OS.writeByte(Rec.IsStringLiteral);
187     OS.write(Rec.AccessCount);
188     OS.write(Rec.Locations.size());
189     for (const auto &Loc : Rec.Locations) {
190       OS.write(getEncodedIndex(Loc.FileName));
191       OS.write32(Loc.Line);
192     }
193   }
194   return Error::success();
195 }
196 
197 Error DataAccessProfData::deserializeSymbolsAndFilenames(
198     const unsigned char *&Ptr, const uint64_t NumSampledSymbols,
199     const uint64_t NumColdKnownSymbols) {
200   uint64_t Len =
201       support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
202 
203   // The first NumSampledSymbols strings are symbols with samples, and next
204   // NumColdKnownSymbols strings are known cold symbols.
205   uint64_t StringCnt = 0;
206   std::function<Error(StringRef)> addName = [&](StringRef Name) {
207     if (StringCnt < NumSampledSymbols)
208       saveStringToMap(StrToIndexMap, Saver, Name);
209     else
210       KnownColdSymbols.insert(Saver.save(Name));
211     ++StringCnt;
212     return Error::success();
213   };
214   if (Error E =
215           readAndDecodeStrings(StringRef((const char *)Ptr, Len), addName))
216     return E;
217 
218   Ptr += alignTo(Len, 8);
219   return Error::success();
220 }
221 
222 Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) {
223   SmallVector<StringRef> Strings =
224       llvm::to_vector(llvm::make_first_range(getStrToIndexMapRef()));
225 
226   uint64_t NumRecords =
227       support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
228 
229   for (uint64_t I = 0; I < NumRecords; ++I) {
230     uint64_t ID =
231         support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
232 
233     bool IsStringLiteral =
234         support::endian::readNext<uint8_t, llvm::endianness::little>(Ptr);
235 
236     uint64_t AccessCount =
237         support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
238 
239     SymbolHandleRef SymbolID;
240     if (IsStringLiteral)
241       SymbolID = ID;
242     else
243       SymbolID = Strings[ID];
244     if (Error E = setDataAccessProfile(SymbolID, AccessCount))
245       return E;
246 
247     auto &Record = Records.back().second;
248 
249     uint64_t NumLocations =
250         support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
251 
252     Record.Locations.reserve(NumLocations);
253     for (uint64_t J = 0; J < NumLocations; ++J) {
254       uint64_t FileNameIndex =
255           support::endian::readNext<uint64_t, llvm::endianness::little>(Ptr);
256       uint32_t Line =
257           support::endian::readNext<uint32_t, llvm::endianness::little>(Ptr);
258       Record.Locations.push_back({Strings[FileNameIndex], Line});
259     }
260   }
261   return Error::success();
262 }
263 } // namespace memprof
264 } // namespace llvm
265