//===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for working with Profiling Metadata.
//
//===----------------------------------------------------------------------===//

#include "llvm/IR/ProfDataUtils.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/Support/BranchProbability.h"
#include "llvm/Support/CommandLine.h"

using namespace llvm;

namespace {

// MD_prof nodes have the following layout
//
// In general:
// { String name, Array of i32 }
//
// In terms of Types:
// { MDString, [i32, i32, ...]}
//
// Concretely for Branch Weights
// { "branch_weights", [i32 1, i32 10000]}
//
// A branch_weights node may additionally carry an MDString at operand 1 (see
// hasBranchWeightOrigin); the weights then start at operand 2 instead of 1.
//
// We maintain some constants here to ensure that we access the branch weights
// correctly, and can change the behavior in the future if the layout changes.

// The minimum number of operands for MD_prof nodes with branch weights.
// isBranchWeightMD rejects any node with fewer operands than this.
constexpr unsigned MinBWOps = 3;

// The minimum number of operands for MD_prof nodes with value profiles.
// isValueProfileMD rejects any node with fewer operands than this.
constexpr unsigned MinVPOps = 5;

// We may want to add support for other MD_prof types, so provide an abstraction
// for checking the metadata type.
52 bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) { 53 // TODO: This routine may be simplified if MD_prof used an enum instead of a 54 // string to differentiate the types of MD_prof nodes. 55 if (!ProfData || !Name || MinOps < 2) 56 return false; 57 58 unsigned NOps = ProfData->getNumOperands(); 59 if (NOps < MinOps) 60 return false; 61 62 auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0)); 63 if (!ProfDataName) 64 return false; 65 66 return ProfDataName->getString() == Name; 67 } 68 69 template <typename T, 70 typename = typename std::enable_if<std::is_arithmetic_v<T>>> 71 static void extractFromBranchWeightMD(const MDNode *ProfileData, 72 SmallVectorImpl<T> &Weights) { 73 assert(isBranchWeightMD(ProfileData) && "wrong metadata"); 74 75 unsigned NOps = ProfileData->getNumOperands(); 76 unsigned WeightsIdx = getBranchWeightOffset(ProfileData); 77 assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); 78 Weights.resize(NOps - WeightsIdx); 79 80 for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { 81 ConstantInt *Weight = 82 mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); 83 assert(Weight && "Malformed branch_weight in MD_prof node"); 84 assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) && 85 "Too many bits for MD_prof branch_weight"); 86 Weights[Idx - WeightsIdx] = Weight->getZExtValue(); 87 } 88 } 89 90 } // namespace 91 92 namespace llvm { 93 94 bool hasProfMD(const Instruction &I) { 95 return I.hasMetadata(LLVMContext::MD_prof); 96 } 97 98 bool isBranchWeightMD(const MDNode *ProfileData) { 99 return isTargetMD(ProfileData, "branch_weights", MinBWOps); 100 } 101 102 bool isValueProfileMD(const MDNode *ProfileData) { 103 return isTargetMD(ProfileData, "VP", MinVPOps); 104 } 105 106 bool hasBranchWeightMD(const Instruction &I) { 107 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 108 return isBranchWeightMD(ProfileData); 109 } 110 111 bool hasCountTypeMD(const 
Instruction &I) { 112 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 113 // Value profiles record count-type information. 114 if (isValueProfileMD(ProfileData)) 115 return true; 116 // Conservatively assume non CallBase instruction only get taken/not-taken 117 // branch probability, so not interpret them as count. 118 return isa<CallBase>(I) && !isBranchWeightMD(ProfileData); 119 } 120 121 bool hasValidBranchWeightMD(const Instruction &I) { 122 return getValidBranchWeightMDNode(I); 123 } 124 125 bool hasBranchWeightOrigin(const Instruction &I) { 126 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 127 return hasBranchWeightOrigin(ProfileData); 128 } 129 130 bool hasBranchWeightOrigin(const MDNode *ProfileData) { 131 if (!isBranchWeightMD(ProfileData)) 132 return false; 133 auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1)); 134 // NOTE: if we ever have more types of branch weight provenance, 135 // we need to check the string value is "expected". For now, we 136 // supply a more generic API, and avoid the spurious comparisons. 137 assert(ProfDataName == nullptr || ProfDataName->getString() == "expected"); 138 return ProfDataName != nullptr; 139 } 140 141 unsigned getBranchWeightOffset(const MDNode *ProfileData) { 142 return hasBranchWeightOrigin(ProfileData) ? 
2 : 1; 143 } 144 145 unsigned getNumBranchWeights(const MDNode &ProfileData) { 146 return ProfileData.getNumOperands() - getBranchWeightOffset(&ProfileData); 147 } 148 149 MDNode *getBranchWeightMDNode(const Instruction &I) { 150 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 151 if (!isBranchWeightMD(ProfileData)) 152 return nullptr; 153 return ProfileData; 154 } 155 156 MDNode *getValidBranchWeightMDNode(const Instruction &I) { 157 auto *ProfileData = getBranchWeightMDNode(I); 158 if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors()) 159 return ProfileData; 160 return nullptr; 161 } 162 163 void extractFromBranchWeightMD32(const MDNode *ProfileData, 164 SmallVectorImpl<uint32_t> &Weights) { 165 extractFromBranchWeightMD(ProfileData, Weights); 166 } 167 168 void extractFromBranchWeightMD64(const MDNode *ProfileData, 169 SmallVectorImpl<uint64_t> &Weights) { 170 extractFromBranchWeightMD(ProfileData, Weights); 171 } 172 173 bool extractBranchWeights(const MDNode *ProfileData, 174 SmallVectorImpl<uint32_t> &Weights) { 175 if (!isBranchWeightMD(ProfileData)) 176 return false; 177 extractFromBranchWeightMD(ProfileData, Weights); 178 return true; 179 } 180 181 bool extractBranchWeights(const Instruction &I, 182 SmallVectorImpl<uint32_t> &Weights) { 183 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 184 return extractBranchWeights(ProfileData, Weights); 185 } 186 187 bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal, 188 uint64_t &FalseVal) { 189 assert((I.getOpcode() == Instruction::Br || 190 I.getOpcode() == Instruction::Select) && 191 "Looking for branch weights on something besides branch, select, or " 192 "switch"); 193 194 SmallVector<uint32_t, 2> Weights; 195 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 196 if (!extractBranchWeights(ProfileData, Weights)) 197 return false; 198 199 if (Weights.size() > 2) 200 return false; 201 202 TrueVal = Weights[0]; 203 FalseVal = Weights[1]; 204 return 
true; 205 } 206 207 bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) { 208 TotalVal = 0; 209 if (!ProfileData) 210 return false; 211 212 auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0)); 213 if (!ProfDataName) 214 return false; 215 216 if (ProfDataName->getString() == "branch_weights") { 217 unsigned Offset = getBranchWeightOffset(ProfileData); 218 for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) { 219 auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); 220 assert(V && "Malformed branch_weight in MD_prof node"); 221 TotalVal += V->getValue().getZExtValue(); 222 } 223 return true; 224 } 225 226 if (ProfDataName->getString() == "VP" && ProfileData->getNumOperands() > 3) { 227 TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2)) 228 ->getValue() 229 .getZExtValue(); 230 return true; 231 } 232 return false; 233 } 234 235 bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) { 236 return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal); 237 } 238 239 void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights, 240 bool IsExpected) { 241 MDBuilder MDB(I.getContext()); 242 MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected); 243 I.setMetadata(LLVMContext::MD_prof, BranchWeights); 244 } 245 246 void scaleProfData(Instruction &I, uint64_t S, uint64_t T) { 247 assert(T != 0 && "Caller should guarantee"); 248 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 249 if (ProfileData == nullptr) 250 return; 251 252 auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0)); 253 if (!ProfDataName || (ProfDataName->getString() != "branch_weights" && 254 ProfDataName->getString() != "VP")) 255 return; 256 257 if (!hasCountTypeMD(I)) 258 return; 259 260 LLVMContext &C = I.getContext(); 261 262 MDBuilder MDB(C); 263 SmallVector<Metadata *, 3> Vals; 264 Vals.push_back(ProfileData->getOperand(0)); 265 
APInt APS(128, S), APT(128, T); 266 if (ProfDataName->getString() == "branch_weights" && 267 ProfileData->getNumOperands() > 0) { 268 // Using APInt::div may be expensive, but most cases should fit 64 bits. 269 APInt Val(128, 270 mdconst::dyn_extract<ConstantInt>( 271 ProfileData->getOperand(getBranchWeightOffset(ProfileData))) 272 ->getValue() 273 .getZExtValue()); 274 Val *= APS; 275 Vals.push_back(MDB.createConstant(ConstantInt::get( 276 Type::getInt32Ty(C), Val.udiv(APT).getLimitedValue(UINT32_MAX)))); 277 } else if (ProfDataName->getString() == "VP") 278 for (unsigned i = 1; i < ProfileData->getNumOperands(); i += 2) { 279 // The first value is the key of the value profile, which will not change. 280 Vals.push_back(ProfileData->getOperand(i)); 281 uint64_t Count = 282 mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i + 1)) 283 ->getValue() 284 .getZExtValue(); 285 // Don't scale the magic number. 286 if (Count == NOMORE_ICP_MAGICNUM) { 287 Vals.push_back(ProfileData->getOperand(i + 1)); 288 continue; 289 } 290 // Using APInt::div may be expensive, but most cases should fit 64 bits. 291 APInt Val(128, Count); 292 Val *= APS; 293 Vals.push_back(MDB.createConstant(ConstantInt::get( 294 Type::getInt64Ty(C), Val.udiv(APT).getLimitedValue()))); 295 } 296 I.setMetadata(LLVMContext::MD_prof, MDNode::get(C, Vals)); 297 } 298 299 } // namespace llvm 300