1bdd1243dSDimitry Andric //===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===//
2bdd1243dSDimitry Andric //
3bdd1243dSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4bdd1243dSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5bdd1243dSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6bdd1243dSDimitry Andric //
7bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
8bdd1243dSDimitry Andric //
9bdd1243dSDimitry Andric // This file implements utilities for working with Profiling Metadata.
10bdd1243dSDimitry Andric //
11bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
12bdd1243dSDimitry Andric
13bdd1243dSDimitry Andric #include "llvm/IR/ProfDataUtils.h"
14bdd1243dSDimitry Andric #include "llvm/ADT/SmallVector.h"
15bdd1243dSDimitry Andric #include "llvm/ADT/Twine.h"
16bdd1243dSDimitry Andric #include "llvm/IR/Constants.h"
17bdd1243dSDimitry Andric #include "llvm/IR/Function.h"
18bdd1243dSDimitry Andric #include "llvm/IR/Instructions.h"
19bdd1243dSDimitry Andric #include "llvm/IR/LLVMContext.h"
205f757f3fSDimitry Andric #include "llvm/IR/MDBuilder.h"
21bdd1243dSDimitry Andric #include "llvm/IR/Metadata.h"
22*0fca6ea1SDimitry Andric #include "llvm/IR/ProfDataUtils.h"
23bdd1243dSDimitry Andric #include "llvm/Support/BranchProbability.h"
24bdd1243dSDimitry Andric #include "llvm/Support/CommandLine.h"
25bdd1243dSDimitry Andric
26bdd1243dSDimitry Andric using namespace llvm;
27bdd1243dSDimitry Andric
28bdd1243dSDimitry Andric namespace {
29bdd1243dSDimitry Andric
// MD_prof nodes have the following layout
//
// In general:
// { String name, Array of integers }
//
// In terms of Types:
// { MDString, [i32, i32, ...]}
//
// Concretely for Branch Weights
// { "branch_weights", [i32 1, i32 10000]}
// Branch weights may also carry an origin string between the name and the
// weights, e.g. { "branch_weights", "expected", [i32 1, i32 10000]}.
//
// Value profiles ("VP") store a key, a total count, and then (value, count)
// pairs; their counts are rewritten as i64 when scaled.
//
// We maintain some constants here to ensure that we access the branch weights
// correctly, and can change the behavior in the future if the layout changes.
43bdd1243dSDimitry Andric
// Minimum operand count for an MD_prof node to be accepted as branch weights
// (enforced through isBranchWeightMD): the "branch_weights" name plus two more.
constexpr unsigned MinBWOps{3};

// Minimum operand count for an MD_prof node to be accepted as a value profile
// (enforced through isValueProfileMD): the "VP" name plus four more.
constexpr unsigned MinVPOps{5};
49*0fca6ea1SDimitry Andric
50bdd1243dSDimitry Andric // We may want to add support for other MD_prof types, so provide an abstraction
51bdd1243dSDimitry Andric // for checking the metadata type.
isTargetMD(const MDNode * ProfData,const char * Name,unsigned MinOps)52bdd1243dSDimitry Andric bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
53bdd1243dSDimitry Andric // TODO: This routine may be simplified if MD_prof used an enum instead of a
54bdd1243dSDimitry Andric // string to differentiate the types of MD_prof nodes.
55bdd1243dSDimitry Andric if (!ProfData || !Name || MinOps < 2)
56bdd1243dSDimitry Andric return false;
57bdd1243dSDimitry Andric
58bdd1243dSDimitry Andric unsigned NOps = ProfData->getNumOperands();
59bdd1243dSDimitry Andric if (NOps < MinOps)
60bdd1243dSDimitry Andric return false;
61bdd1243dSDimitry Andric
62bdd1243dSDimitry Andric auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0));
63bdd1243dSDimitry Andric if (!ProfDataName)
64bdd1243dSDimitry Andric return false;
65bdd1243dSDimitry Andric
66*0fca6ea1SDimitry Andric return ProfDataName->getString() == Name;
67*0fca6ea1SDimitry Andric }
68*0fca6ea1SDimitry Andric
69*0fca6ea1SDimitry Andric template <typename T,
70*0fca6ea1SDimitry Andric typename = typename std::enable_if<std::is_arithmetic_v<T>>>
extractFromBranchWeightMD(const MDNode * ProfileData,SmallVectorImpl<T> & Weights)71*0fca6ea1SDimitry Andric static void extractFromBranchWeightMD(const MDNode *ProfileData,
72*0fca6ea1SDimitry Andric SmallVectorImpl<T> &Weights) {
73*0fca6ea1SDimitry Andric assert(isBranchWeightMD(ProfileData) && "wrong metadata");
74*0fca6ea1SDimitry Andric
75*0fca6ea1SDimitry Andric unsigned NOps = ProfileData->getNumOperands();
76*0fca6ea1SDimitry Andric unsigned WeightsIdx = getBranchWeightOffset(ProfileData);
77*0fca6ea1SDimitry Andric assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
78*0fca6ea1SDimitry Andric Weights.resize(NOps - WeightsIdx);
79*0fca6ea1SDimitry Andric
80*0fca6ea1SDimitry Andric for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
81*0fca6ea1SDimitry Andric ConstantInt *Weight =
82*0fca6ea1SDimitry Andric mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
83*0fca6ea1SDimitry Andric assert(Weight && "Malformed branch_weight in MD_prof node");
84*0fca6ea1SDimitry Andric assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) &&
85*0fca6ea1SDimitry Andric "Too many bits for MD_prof branch_weight");
86*0fca6ea1SDimitry Andric Weights[Idx - WeightsIdx] = Weight->getZExtValue();
87*0fca6ea1SDimitry Andric }
88bdd1243dSDimitry Andric }
89bdd1243dSDimitry Andric
90bdd1243dSDimitry Andric } // namespace
91bdd1243dSDimitry Andric
92bdd1243dSDimitry Andric namespace llvm {
93bdd1243dSDimitry Andric
hasProfMD(const Instruction & I)94bdd1243dSDimitry Andric bool hasProfMD(const Instruction &I) {
95*0fca6ea1SDimitry Andric return I.hasMetadata(LLVMContext::MD_prof);
96bdd1243dSDimitry Andric }
97bdd1243dSDimitry Andric
isBranchWeightMD(const MDNode * ProfileData)98bdd1243dSDimitry Andric bool isBranchWeightMD(const MDNode *ProfileData) {
99bdd1243dSDimitry Andric return isTargetMD(ProfileData, "branch_weights", MinBWOps);
100bdd1243dSDimitry Andric }
101bdd1243dSDimitry Andric
isValueProfileMD(const MDNode * ProfileData)102*0fca6ea1SDimitry Andric bool isValueProfileMD(const MDNode *ProfileData) {
103*0fca6ea1SDimitry Andric return isTargetMD(ProfileData, "VP", MinVPOps);
104*0fca6ea1SDimitry Andric }
105*0fca6ea1SDimitry Andric
hasBranchWeightMD(const Instruction & I)106bdd1243dSDimitry Andric bool hasBranchWeightMD(const Instruction &I) {
107bdd1243dSDimitry Andric auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
108bdd1243dSDimitry Andric return isBranchWeightMD(ProfileData);
109bdd1243dSDimitry Andric }
110bdd1243dSDimitry Andric
hasCountTypeMD(const Instruction & I)111*0fca6ea1SDimitry Andric bool hasCountTypeMD(const Instruction &I) {
112*0fca6ea1SDimitry Andric auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
113*0fca6ea1SDimitry Andric // Value profiles record count-type information.
114*0fca6ea1SDimitry Andric if (isValueProfileMD(ProfileData))
115*0fca6ea1SDimitry Andric return true;
116*0fca6ea1SDimitry Andric // Conservatively assume non CallBase instruction only get taken/not-taken
117*0fca6ea1SDimitry Andric // branch probability, so not interpret them as count.
118*0fca6ea1SDimitry Andric return isa<CallBase>(I) && !isBranchWeightMD(ProfileData);
119*0fca6ea1SDimitry Andric }
120*0fca6ea1SDimitry Andric
hasValidBranchWeightMD(const Instruction & I)121bdd1243dSDimitry Andric bool hasValidBranchWeightMD(const Instruction &I) {
122bdd1243dSDimitry Andric return getValidBranchWeightMDNode(I);
123bdd1243dSDimitry Andric }
124bdd1243dSDimitry Andric
hasBranchWeightOrigin(const Instruction & I)125*0fca6ea1SDimitry Andric bool hasBranchWeightOrigin(const Instruction &I) {
126*0fca6ea1SDimitry Andric auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
127*0fca6ea1SDimitry Andric return hasBranchWeightOrigin(ProfileData);
128*0fca6ea1SDimitry Andric }
129*0fca6ea1SDimitry Andric
hasBranchWeightOrigin(const MDNode * ProfileData)130*0fca6ea1SDimitry Andric bool hasBranchWeightOrigin(const MDNode *ProfileData) {
131*0fca6ea1SDimitry Andric if (!isBranchWeightMD(ProfileData))
132*0fca6ea1SDimitry Andric return false;
133*0fca6ea1SDimitry Andric auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1));
134*0fca6ea1SDimitry Andric // NOTE: if we ever have more types of branch weight provenance,
135*0fca6ea1SDimitry Andric // we need to check the string value is "expected". For now, we
136*0fca6ea1SDimitry Andric // supply a more generic API, and avoid the spurious comparisons.
137*0fca6ea1SDimitry Andric assert(ProfDataName == nullptr || ProfDataName->getString() == "expected");
138*0fca6ea1SDimitry Andric return ProfDataName != nullptr;
139*0fca6ea1SDimitry Andric }
140*0fca6ea1SDimitry Andric
getBranchWeightOffset(const MDNode * ProfileData)141*0fca6ea1SDimitry Andric unsigned getBranchWeightOffset(const MDNode *ProfileData) {
142*0fca6ea1SDimitry Andric return hasBranchWeightOrigin(ProfileData) ? 2 : 1;
143*0fca6ea1SDimitry Andric }
144*0fca6ea1SDimitry Andric
getNumBranchWeights(const MDNode & ProfileData)145*0fca6ea1SDimitry Andric unsigned getNumBranchWeights(const MDNode &ProfileData) {
146*0fca6ea1SDimitry Andric return ProfileData.getNumOperands() - getBranchWeightOffset(&ProfileData);
147*0fca6ea1SDimitry Andric }
148*0fca6ea1SDimitry Andric
getBranchWeightMDNode(const Instruction & I)149bdd1243dSDimitry Andric MDNode *getBranchWeightMDNode(const Instruction &I) {
150bdd1243dSDimitry Andric auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
151bdd1243dSDimitry Andric if (!isBranchWeightMD(ProfileData))
152bdd1243dSDimitry Andric return nullptr;
153bdd1243dSDimitry Andric return ProfileData;
154bdd1243dSDimitry Andric }
155bdd1243dSDimitry Andric
getValidBranchWeightMDNode(const Instruction & I)156bdd1243dSDimitry Andric MDNode *getValidBranchWeightMDNode(const Instruction &I) {
157bdd1243dSDimitry Andric auto *ProfileData = getBranchWeightMDNode(I);
158*0fca6ea1SDimitry Andric if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors())
159bdd1243dSDimitry Andric return ProfileData;
160bdd1243dSDimitry Andric return nullptr;
161bdd1243dSDimitry Andric }
162bdd1243dSDimitry Andric
extractFromBranchWeightMD32(const MDNode * ProfileData,SmallVectorImpl<uint32_t> & Weights)163*0fca6ea1SDimitry Andric void extractFromBranchWeightMD32(const MDNode *ProfileData,
1645f757f3fSDimitry Andric SmallVectorImpl<uint32_t> &Weights) {
165*0fca6ea1SDimitry Andric extractFromBranchWeightMD(ProfileData, Weights);
1665f757f3fSDimitry Andric }
167*0fca6ea1SDimitry Andric
extractFromBranchWeightMD64(const MDNode * ProfileData,SmallVectorImpl<uint64_t> & Weights)168*0fca6ea1SDimitry Andric void extractFromBranchWeightMD64(const MDNode *ProfileData,
169*0fca6ea1SDimitry Andric SmallVectorImpl<uint64_t> &Weights) {
170*0fca6ea1SDimitry Andric extractFromBranchWeightMD(ProfileData, Weights);
1715f757f3fSDimitry Andric }
1725f757f3fSDimitry Andric
extractBranchWeights(const MDNode * ProfileData,SmallVectorImpl<uint32_t> & Weights)173bdd1243dSDimitry Andric bool extractBranchWeights(const MDNode *ProfileData,
174bdd1243dSDimitry Andric SmallVectorImpl<uint32_t> &Weights) {
175bdd1243dSDimitry Andric if (!isBranchWeightMD(ProfileData))
176bdd1243dSDimitry Andric return false;
1775f757f3fSDimitry Andric extractFromBranchWeightMD(ProfileData, Weights);
1785f757f3fSDimitry Andric return true;
179bdd1243dSDimitry Andric }
180bdd1243dSDimitry Andric
extractBranchWeights(const Instruction & I,SmallVectorImpl<uint32_t> & Weights)181bdd1243dSDimitry Andric bool extractBranchWeights(const Instruction &I,
182bdd1243dSDimitry Andric SmallVectorImpl<uint32_t> &Weights) {
183bdd1243dSDimitry Andric auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
184bdd1243dSDimitry Andric return extractBranchWeights(ProfileData, Weights);
185bdd1243dSDimitry Andric }
186bdd1243dSDimitry Andric
extractBranchWeights(const Instruction & I,uint64_t & TrueVal,uint64_t & FalseVal)187bdd1243dSDimitry Andric bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
188bdd1243dSDimitry Andric uint64_t &FalseVal) {
189bdd1243dSDimitry Andric assert((I.getOpcode() == Instruction::Br ||
190bdd1243dSDimitry Andric I.getOpcode() == Instruction::Select) &&
191bdd1243dSDimitry Andric "Looking for branch weights on something besides branch, select, or "
192bdd1243dSDimitry Andric "switch");
193bdd1243dSDimitry Andric
194bdd1243dSDimitry Andric SmallVector<uint32_t, 2> Weights;
195bdd1243dSDimitry Andric auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
196bdd1243dSDimitry Andric if (!extractBranchWeights(ProfileData, Weights))
197bdd1243dSDimitry Andric return false;
198bdd1243dSDimitry Andric
199bdd1243dSDimitry Andric if (Weights.size() > 2)
200bdd1243dSDimitry Andric return false;
201bdd1243dSDimitry Andric
202bdd1243dSDimitry Andric TrueVal = Weights[0];
203bdd1243dSDimitry Andric FalseVal = Weights[1];
204bdd1243dSDimitry Andric return true;
205bdd1243dSDimitry Andric }
206bdd1243dSDimitry Andric
extractProfTotalWeight(const MDNode * ProfileData,uint64_t & TotalVal)207bdd1243dSDimitry Andric bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
208bdd1243dSDimitry Andric TotalVal = 0;
209bdd1243dSDimitry Andric if (!ProfileData)
210bdd1243dSDimitry Andric return false;
211bdd1243dSDimitry Andric
212bdd1243dSDimitry Andric auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
213bdd1243dSDimitry Andric if (!ProfDataName)
214bdd1243dSDimitry Andric return false;
215bdd1243dSDimitry Andric
216*0fca6ea1SDimitry Andric if (ProfDataName->getString() == "branch_weights") {
217*0fca6ea1SDimitry Andric unsigned Offset = getBranchWeightOffset(ProfileData);
218*0fca6ea1SDimitry Andric for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) {
219bdd1243dSDimitry Andric auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
220bdd1243dSDimitry Andric assert(V && "Malformed branch_weight in MD_prof node");
221bdd1243dSDimitry Andric TotalVal += V->getValue().getZExtValue();
222bdd1243dSDimitry Andric }
223bdd1243dSDimitry Andric return true;
224bdd1243dSDimitry Andric }
225bdd1243dSDimitry Andric
226*0fca6ea1SDimitry Andric if (ProfDataName->getString() == "VP" && ProfileData->getNumOperands() > 3) {
227bdd1243dSDimitry Andric TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2))
228bdd1243dSDimitry Andric ->getValue()
229bdd1243dSDimitry Andric .getZExtValue();
230bdd1243dSDimitry Andric return true;
231bdd1243dSDimitry Andric }
232bdd1243dSDimitry Andric return false;
233bdd1243dSDimitry Andric }
234bdd1243dSDimitry Andric
extractProfTotalWeight(const Instruction & I,uint64_t & TotalVal)235bdd1243dSDimitry Andric bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
236bdd1243dSDimitry Andric return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
237bdd1243dSDimitry Andric }
238bdd1243dSDimitry Andric
setBranchWeights(Instruction & I,ArrayRef<uint32_t> Weights,bool IsExpected)239*0fca6ea1SDimitry Andric void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
240*0fca6ea1SDimitry Andric bool IsExpected) {
2415f757f3fSDimitry Andric MDBuilder MDB(I.getContext());
242*0fca6ea1SDimitry Andric MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected);
2435f757f3fSDimitry Andric I.setMetadata(LLVMContext::MD_prof, BranchWeights);
2445f757f3fSDimitry Andric }
2455f757f3fSDimitry Andric
scaleProfData(Instruction & I,uint64_t S,uint64_t T)246*0fca6ea1SDimitry Andric void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
247*0fca6ea1SDimitry Andric assert(T != 0 && "Caller should guarantee");
248*0fca6ea1SDimitry Andric auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
249*0fca6ea1SDimitry Andric if (ProfileData == nullptr)
250*0fca6ea1SDimitry Andric return;
251*0fca6ea1SDimitry Andric
252*0fca6ea1SDimitry Andric auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
253*0fca6ea1SDimitry Andric if (!ProfDataName || (ProfDataName->getString() != "branch_weights" &&
254*0fca6ea1SDimitry Andric ProfDataName->getString() != "VP"))
255*0fca6ea1SDimitry Andric return;
256*0fca6ea1SDimitry Andric
257*0fca6ea1SDimitry Andric if (!hasCountTypeMD(I))
258*0fca6ea1SDimitry Andric return;
259*0fca6ea1SDimitry Andric
260*0fca6ea1SDimitry Andric LLVMContext &C = I.getContext();
261*0fca6ea1SDimitry Andric
262*0fca6ea1SDimitry Andric MDBuilder MDB(C);
263*0fca6ea1SDimitry Andric SmallVector<Metadata *, 3> Vals;
264*0fca6ea1SDimitry Andric Vals.push_back(ProfileData->getOperand(0));
265*0fca6ea1SDimitry Andric APInt APS(128, S), APT(128, T);
266*0fca6ea1SDimitry Andric if (ProfDataName->getString() == "branch_weights" &&
267*0fca6ea1SDimitry Andric ProfileData->getNumOperands() > 0) {
268*0fca6ea1SDimitry Andric // Using APInt::div may be expensive, but most cases should fit 64 bits.
269*0fca6ea1SDimitry Andric APInt Val(128,
270*0fca6ea1SDimitry Andric mdconst::dyn_extract<ConstantInt>(
271*0fca6ea1SDimitry Andric ProfileData->getOperand(getBranchWeightOffset(ProfileData)))
272*0fca6ea1SDimitry Andric ->getValue()
273*0fca6ea1SDimitry Andric .getZExtValue());
274*0fca6ea1SDimitry Andric Val *= APS;
275*0fca6ea1SDimitry Andric Vals.push_back(MDB.createConstant(ConstantInt::get(
276*0fca6ea1SDimitry Andric Type::getInt32Ty(C), Val.udiv(APT).getLimitedValue(UINT32_MAX))));
277*0fca6ea1SDimitry Andric } else if (ProfDataName->getString() == "VP")
278*0fca6ea1SDimitry Andric for (unsigned i = 1; i < ProfileData->getNumOperands(); i += 2) {
279*0fca6ea1SDimitry Andric // The first value is the key of the value profile, which will not change.
280*0fca6ea1SDimitry Andric Vals.push_back(ProfileData->getOperand(i));
281*0fca6ea1SDimitry Andric uint64_t Count =
282*0fca6ea1SDimitry Andric mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i + 1))
283*0fca6ea1SDimitry Andric ->getValue()
284*0fca6ea1SDimitry Andric .getZExtValue();
285*0fca6ea1SDimitry Andric // Don't scale the magic number.
286*0fca6ea1SDimitry Andric if (Count == NOMORE_ICP_MAGICNUM) {
287*0fca6ea1SDimitry Andric Vals.push_back(ProfileData->getOperand(i + 1));
288*0fca6ea1SDimitry Andric continue;
289*0fca6ea1SDimitry Andric }
290*0fca6ea1SDimitry Andric // Using APInt::div may be expensive, but most cases should fit 64 bits.
291*0fca6ea1SDimitry Andric APInt Val(128, Count);
292*0fca6ea1SDimitry Andric Val *= APS;
293*0fca6ea1SDimitry Andric Vals.push_back(MDB.createConstant(ConstantInt::get(
294*0fca6ea1SDimitry Andric Type::getInt64Ty(C), Val.udiv(APT).getLimitedValue())));
295*0fca6ea1SDimitry Andric }
296*0fca6ea1SDimitry Andric I.setMetadata(LLVMContext::MD_prof, MDNode::get(C, Vals));
297*0fca6ea1SDimitry Andric }
298*0fca6ea1SDimitry Andric
299bdd1243dSDimitry Andric } // namespace llvm
300