xref: /freebsd/contrib/llvm-project/llvm/lib/IR/ProfDataUtils.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for working with Profiling Metadata.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/IR/ProfDataUtils.h"
14 
15 #include "llvm/ADT/SmallVector.h"
16 #include "llvm/IR/Constants.h"
17 #include "llvm/IR/Function.h"
18 #include "llvm/IR/Instructions.h"
19 #include "llvm/IR/LLVMContext.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/IR/Metadata.h"
22 
23 using namespace llvm;
24 
25 namespace {
26 
27 // MD_prof nodes have the following layout
28 //
29 // In general:
30 // { String name,         Array of i32   }
31 //
32 // In terms of Types:
33 // { MDString,            [i32, i32, ...]}
34 //
35 // Concretely for Branch Weights
36 // { "branch_weights",    [i32 1, i32 10000]}
37 //
38 // We maintain some constants here to ensure that we access the branch weights
39 // correctly, and can change the behavior in the future if the layout changes
40 
// Smallest operand count accepted for a "branch_weights" MD_prof node: the
// name string followed by two weights, as in the layout sketched above.
constexpr unsigned MinBWOps = 1 + 2;

// Smallest operand count accepted for a "VP" (value profile) MD_prof node.
// NOTE(review): presumably the name, a kind, a total count and one
// (value, count) pair -- confirm against the value profile documentation.
constexpr unsigned MinVPOps = 5;
46 
47 // We may want to add support for other MD_prof types, so provide an abstraction
48 // for checking the metadata type.
49 bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
50   // TODO: This routine may be simplified if MD_prof used an enum instead of a
51   // string to differentiate the types of MD_prof nodes.
52   if (!ProfData || !Name || MinOps < 2)
53     return false;
54 
55   unsigned NOps = ProfData->getNumOperands();
56   if (NOps < MinOps)
57     return false;
58 
59   auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0));
60   if (!ProfDataName)
61     return false;
62 
63   return ProfDataName->getString() == Name;
64 }
65 
66 template <typename T,
67           typename = typename std::enable_if<std::is_arithmetic_v<T>>>
68 static void extractFromBranchWeightMD(const MDNode *ProfileData,
69                                       SmallVectorImpl<T> &Weights) {
70   assert(isBranchWeightMD(ProfileData) && "wrong metadata");
71 
72   unsigned NOps = ProfileData->getNumOperands();
73   unsigned WeightsIdx = getBranchWeightOffset(ProfileData);
74   assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
75   Weights.resize(NOps - WeightsIdx);
76 
77   for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
78     ConstantInt *Weight =
79         mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
80     assert(Weight && "Malformed branch_weight in MD_prof node");
81     assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) &&
82            "Too many bits for MD_prof branch_weight");
83     Weights[Idx - WeightsIdx] = Weight->getZExtValue();
84   }
85 }
86 
87 } // namespace
88 
89 namespace llvm {
90 
91 const char *MDProfLabels::BranchWeights = "branch_weights";
92 const char *MDProfLabels::ExpectedBranchWeights = "expected";
93 const char *MDProfLabels::ValueProfile = "VP";
94 const char *MDProfLabels::FunctionEntryCount = "function_entry_count";
95 const char *MDProfLabels::SyntheticFunctionEntryCount =
96     "synthetic_function_entry_count";
97 const char *MDProfLabels::UnknownBranchWeightsMarker = "unknown";
98 
99 bool hasProfMD(const Instruction &I) {
100   return I.hasMetadata(LLVMContext::MD_prof);
101 }
102 
103 bool isBranchWeightMD(const MDNode *ProfileData) {
104   return isTargetMD(ProfileData, MDProfLabels::BranchWeights, MinBWOps);
105 }
106 
107 bool isValueProfileMD(const MDNode *ProfileData) {
108   return isTargetMD(ProfileData, MDProfLabels::ValueProfile, MinVPOps);
109 }
110 
111 bool hasBranchWeightMD(const Instruction &I) {
112   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
113   return isBranchWeightMD(ProfileData);
114 }
115 
116 static bool hasCountTypeMD(const Instruction &I) {
117   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
118   // Value profiles record count-type information.
119   if (isValueProfileMD(ProfileData))
120     return true;
121   // Conservatively assume non CallBase instruction only get taken/not-taken
122   // branch probability, so not interpret them as count.
123   return isa<CallBase>(I) && !isBranchWeightMD(ProfileData);
124 }
125 
126 bool hasValidBranchWeightMD(const Instruction &I) {
127   return getValidBranchWeightMDNode(I);
128 }
129 
130 bool hasBranchWeightOrigin(const Instruction &I) {
131   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
132   return hasBranchWeightOrigin(ProfileData);
133 }
134 
135 bool hasBranchWeightOrigin(const MDNode *ProfileData) {
136   if (!isBranchWeightMD(ProfileData))
137     return false;
138   auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1));
139   // NOTE: if we ever have more types of branch weight provenance,
140   // we need to check the string value is "expected". For now, we
141   // supply a more generic API, and avoid the spurious comparisons.
142   assert(ProfDataName == nullptr ||
143          ProfDataName->getString() == MDProfLabels::ExpectedBranchWeights);
144   return ProfDataName != nullptr;
145 }
146 
147 unsigned getBranchWeightOffset(const MDNode *ProfileData) {
148   return hasBranchWeightOrigin(ProfileData) ? 2 : 1;
149 }
150 
151 unsigned getNumBranchWeights(const MDNode &ProfileData) {
152   return ProfileData.getNumOperands() - getBranchWeightOffset(&ProfileData);
153 }
154 
155 MDNode *getBranchWeightMDNode(const Instruction &I) {
156   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
157   if (!isBranchWeightMD(ProfileData))
158     return nullptr;
159   return ProfileData;
160 }
161 
162 MDNode *getValidBranchWeightMDNode(const Instruction &I) {
163   auto *ProfileData = getBranchWeightMDNode(I);
164   if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors())
165     return ProfileData;
166   return nullptr;
167 }
168 
169 void extractFromBranchWeightMD32(const MDNode *ProfileData,
170                                  SmallVectorImpl<uint32_t> &Weights) {
171   extractFromBranchWeightMD(ProfileData, Weights);
172 }
173 
174 void extractFromBranchWeightMD64(const MDNode *ProfileData,
175                                  SmallVectorImpl<uint64_t> &Weights) {
176   extractFromBranchWeightMD(ProfileData, Weights);
177 }
178 
179 bool extractBranchWeights(const MDNode *ProfileData,
180                           SmallVectorImpl<uint32_t> &Weights) {
181   if (!isBranchWeightMD(ProfileData))
182     return false;
183   extractFromBranchWeightMD(ProfileData, Weights);
184   return true;
185 }
186 
187 bool extractBranchWeights(const Instruction &I,
188                           SmallVectorImpl<uint32_t> &Weights) {
189   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
190   return extractBranchWeights(ProfileData, Weights);
191 }
192 
193 bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
194                           uint64_t &FalseVal) {
195   assert((I.getOpcode() == Instruction::Br ||
196           I.getOpcode() == Instruction::Select) &&
197          "Looking for branch weights on something besides branch, select, or "
198          "switch");
199 
200   SmallVector<uint32_t, 2> Weights;
201   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
202   if (!extractBranchWeights(ProfileData, Weights))
203     return false;
204 
205   if (Weights.size() > 2)
206     return false;
207 
208   TrueVal = Weights[0];
209   FalseVal = Weights[1];
210   return true;
211 }
212 
213 bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
214   TotalVal = 0;
215   if (!ProfileData)
216     return false;
217 
218   auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
219   if (!ProfDataName)
220     return false;
221 
222   if (ProfDataName->getString() == MDProfLabels::BranchWeights) {
223     unsigned Offset = getBranchWeightOffset(ProfileData);
224     for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) {
225       auto *V = mdconst::extract<ConstantInt>(ProfileData->getOperand(Idx));
226       TotalVal += V->getValue().getZExtValue();
227     }
228     return true;
229   }
230 
231   if (ProfDataName->getString() == MDProfLabels::ValueProfile &&
232       ProfileData->getNumOperands() > 3) {
233     TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2))
234                    ->getValue()
235                    .getZExtValue();
236     return true;
237   }
238   return false;
239 }
240 
241 bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
242   return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
243 }
244 
245 void setExplicitlyUnknownBranchWeights(Instruction &I) {
246   MDBuilder MDB(I.getContext());
247   I.setMetadata(
248       LLVMContext::MD_prof,
249       MDNode::get(I.getContext(),
250                   MDB.createString(MDProfLabels::UnknownBranchWeightsMarker)));
251 }
252 
253 bool isExplicitlyUnknownBranchWeightsMetadata(const MDNode &MD) {
254   if (MD.getNumOperands() != 1)
255     return false;
256   return MD.getOperand(0).equalsStr(MDProfLabels::UnknownBranchWeightsMarker);
257 }
258 
259 bool hasExplicitlyUnknownBranchWeights(const Instruction &I) {
260   auto *MD = I.getMetadata(LLVMContext::MD_prof);
261   if (!MD)
262     return false;
263   return isExplicitlyUnknownBranchWeightsMetadata(*MD);
264 }
265 
266 void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
267                       bool IsExpected) {
268   MDBuilder MDB(I.getContext());
269   MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected);
270   I.setMetadata(LLVMContext::MD_prof, BranchWeights);
271 }
272 
273 void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
274   assert(T != 0 && "Caller should guarantee");
275   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
276   if (ProfileData == nullptr)
277     return;
278 
279   auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
280   if (!ProfDataName ||
281       (ProfDataName->getString() != MDProfLabels::BranchWeights &&
282        ProfDataName->getString() != MDProfLabels::ValueProfile))
283     return;
284 
285   if (!hasCountTypeMD(I))
286     return;
287 
288   LLVMContext &C = I.getContext();
289 
290   MDBuilder MDB(C);
291   SmallVector<Metadata *, 3> Vals;
292   Vals.push_back(ProfileData->getOperand(0));
293   APInt APS(128, S), APT(128, T);
294   if (ProfDataName->getString() == MDProfLabels::BranchWeights &&
295       ProfileData->getNumOperands() > 0) {
296     // Using APInt::div may be expensive, but most cases should fit 64 bits.
297     APInt Val(128,
298               mdconst::dyn_extract<ConstantInt>(
299                   ProfileData->getOperand(getBranchWeightOffset(ProfileData)))
300                   ->getValue()
301                   .getZExtValue());
302     Val *= APS;
303     Vals.push_back(MDB.createConstant(ConstantInt::get(
304         Type::getInt32Ty(C), Val.udiv(APT).getLimitedValue(UINT32_MAX))));
305   } else if (ProfDataName->getString() == MDProfLabels::ValueProfile)
306     for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx += 2) {
307       // The first value is the key of the value profile, which will not change.
308       Vals.push_back(ProfileData->getOperand(Idx));
309       uint64_t Count =
310           mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx + 1))
311               ->getValue()
312               .getZExtValue();
313       // Don't scale the magic number.
314       if (Count == NOMORE_ICP_MAGICNUM) {
315         Vals.push_back(ProfileData->getOperand(Idx + 1));
316         continue;
317       }
318       // Using APInt::div may be expensive, but most cases should fit 64 bits.
319       APInt Val(128, Count);
320       Val *= APS;
321       Vals.push_back(MDB.createConstant(ConstantInt::get(
322           Type::getInt64Ty(C), Val.udiv(APT).getLimitedValue())));
323     }
324   I.setMetadata(LLVMContext::MD_prof, MDNode::get(C, Vals));
325 }
326 
327 } // namespace llvm
328