1 //===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements utilities for working with Profiling Metadata. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "llvm/IR/ProfDataUtils.h" 14 15 #include "llvm/ADT/SmallVector.h" 16 #include "llvm/IR/Constants.h" 17 #include "llvm/IR/Function.h" 18 #include "llvm/IR/Instructions.h" 19 #include "llvm/IR/LLVMContext.h" 20 #include "llvm/IR/MDBuilder.h" 21 #include "llvm/IR/Metadata.h" 22 23 using namespace llvm; 24 25 namespace { 26 27 // MD_prof nodes have the following layout 28 // 29 // In general: 30 // { String name, Array of i32 } 31 // 32 // In terms of Types: 33 // { MDString, [i32, i32, ...]} 34 // 35 // Concretely for Branch Weights 36 // { "branch_weights", [i32 1, i32 10000]} 37 // 38 // We maintain some constants here to ensure that we access the branch weights 39 // correctly, and can change the behavior in the future if the layout changes 40 41 // the minimum number of operands for MD_prof nodes with branch weights 42 constexpr unsigned MinBWOps = 3; 43 44 // the minimum number of operands for MD_prof nodes with value profiles 45 constexpr unsigned MinVPOps = 5; 46 47 // We may want to add support for other MD_prof types, so provide an abstraction 48 // for checking the metadata type. 49 bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) { 50 // TODO: This routine may be simplified if MD_prof used an enum instead of a 51 // string to differentiate the types of MD_prof nodes. 52 if (!ProfData || !Name || MinOps < 2) 53 return false; 54 55 unsigned NOps = ProfData->getNumOperands(); 56 if (NOps < MinOps) 57 return false; 58 59 auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0)); 60 if (!ProfDataName) 61 return false; 62 63 return ProfDataName->getString() == Name; 64 } 65 66 template <typename T, 67 typename = typename std::enable_if<std::is_arithmetic_v<T>>> 68 static void extractFromBranchWeightMD(const MDNode *ProfileData, 69 SmallVectorImpl<T> &Weights) { 70 assert(isBranchWeightMD(ProfileData) && "wrong metadata"); 71 72 unsigned NOps = ProfileData->getNumOperands(); 73 unsigned WeightsIdx = getBranchWeightOffset(ProfileData); 74 assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); 75 Weights.resize(NOps - WeightsIdx); 76 77 for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { 78 ConstantInt *Weight = 79 mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); 80 assert(Weight && "Malformed branch_weight in MD_prof node"); 81 assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) && 82 "Too many bits for MD_prof branch_weight"); 83 Weights[Idx - WeightsIdx] = Weight->getZExtValue(); 84 } 85 } 86 87 } // namespace 88 89 namespace llvm { 90 91 const char *MDProfLabels::BranchWeights = "branch_weights"; 92 const char *MDProfLabels::ExpectedBranchWeights = "expected"; 93 const char *MDProfLabels::ValueProfile = "VP"; 94 const char *MDProfLabels::FunctionEntryCount = "function_entry_count"; 95 const char *MDProfLabels::SyntheticFunctionEntryCount = 96 "synthetic_function_entry_count"; 97 const char *MDProfLabels::UnknownBranchWeightsMarker = "unknown"; 98 99 bool hasProfMD(const Instruction &I) { 100 return I.hasMetadata(LLVMContext::MD_prof); 101 } 102 103 bool isBranchWeightMD(const MDNode *ProfileData) { 104 return isTargetMD(ProfileData, MDProfLabels::BranchWeights, MinBWOps); 105 } 106 107 bool isValueProfileMD(const MDNode *ProfileData) { 108 return isTargetMD(ProfileData, MDProfLabels::ValueProfile, MinVPOps); 109 } 110 111 bool hasBranchWeightMD(const Instruction &I) { 112 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 113 return isBranchWeightMD(ProfileData); 114 } 115 116 static bool hasCountTypeMD(const Instruction &I) { 117 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 118 // Value profiles record count-type information. 119 if (isValueProfileMD(ProfileData)) 120 return true; 121 // Conservatively assume non CallBase instruction only get taken/not-taken 122 // branch probability, so not interpret them as count. 123 return isa<CallBase>(I) && !isBranchWeightMD(ProfileData); 124 } 125 126 bool hasValidBranchWeightMD(const Instruction &I) { 127 return getValidBranchWeightMDNode(I); 128 } 129 130 bool hasBranchWeightOrigin(const Instruction &I) { 131 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 132 return hasBranchWeightOrigin(ProfileData); 133 } 134 135 bool hasBranchWeightOrigin(const MDNode *ProfileData) { 136 if (!isBranchWeightMD(ProfileData)) 137 return false; 138 auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1)); 139 // NOTE: if we ever have more types of branch weight provenance, 140 // we need to check the string value is "expected". For now, we 141 // supply a more generic API, and avoid the spurious comparisons. 142 assert(ProfDataName == nullptr || 143 ProfDataName->getString() == MDProfLabels::ExpectedBranchWeights); 144 return ProfDataName != nullptr; 145 } 146 147 unsigned getBranchWeightOffset(const MDNode *ProfileData) { 148 return hasBranchWeightOrigin(ProfileData) ? 2 : 1; 149 } 150 151 unsigned getNumBranchWeights(const MDNode &ProfileData) { 152 return ProfileData.getNumOperands() - getBranchWeightOffset(&ProfileData); 153 } 154 155 MDNode *getBranchWeightMDNode(const Instruction &I) { 156 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 157 if (!isBranchWeightMD(ProfileData)) 158 return nullptr; 159 return ProfileData; 160 } 161 162 MDNode *getValidBranchWeightMDNode(const Instruction &I) { 163 auto *ProfileData = getBranchWeightMDNode(I); 164 if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors()) 165 return ProfileData; 166 return nullptr; 167 } 168 169 void extractFromBranchWeightMD32(const MDNode *ProfileData, 170 SmallVectorImpl<uint32_t> &Weights) { 171 extractFromBranchWeightMD(ProfileData, Weights); 172 } 173 174 void extractFromBranchWeightMD64(const MDNode *ProfileData, 175 SmallVectorImpl<uint64_t> &Weights) { 176 extractFromBranchWeightMD(ProfileData, Weights); 177 } 178 179 bool extractBranchWeights(const MDNode *ProfileData, 180 SmallVectorImpl<uint32_t> &Weights) { 181 if (!isBranchWeightMD(ProfileData)) 182 return false; 183 extractFromBranchWeightMD(ProfileData, Weights); 184 return true; 185 } 186 187 bool extractBranchWeights(const Instruction &I, 188 SmallVectorImpl<uint32_t> &Weights) { 189 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 190 return extractBranchWeights(ProfileData, Weights); 191 } 192 193 bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal, 194 uint64_t &FalseVal) { 195 assert((I.getOpcode() == Instruction::Br || 196 I.getOpcode() == Instruction::Select) && 197 "Looking for branch weights on something besides branch, select, or " 198 "switch"); 199 200 SmallVector<uint32_t, 2> Weights; 201 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 202 if (!extractBranchWeights(ProfileData, Weights)) 203 return false; 204 205 if (Weights.size() > 2) 206 return false; 207 208 TrueVal = Weights[0]; 209 FalseVal = Weights[1]; 210 return true; 211 } 212 213 bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) { 214 TotalVal = 0; 215 if (!ProfileData) 216 return false; 217 218 auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0)); 219 if (!ProfDataName) 220 return false; 221 222 if (ProfDataName->getString() == MDProfLabels::BranchWeights) { 223 unsigned Offset = getBranchWeightOffset(ProfileData); 224 for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) { 225 auto *V = mdconst::extract<ConstantInt>(ProfileData->getOperand(Idx)); 226 TotalVal += V->getValue().getZExtValue(); 227 } 228 return true; 229 } 230 231 if (ProfDataName->getString() == MDProfLabels::ValueProfile && 232 ProfileData->getNumOperands() > 3) { 233 TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2)) 234 ->getValue() 235 .getZExtValue(); 236 return true; 237 } 238 return false; 239 } 240 241 bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) { 242 return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal); 243 } 244 245 void setExplicitlyUnknownBranchWeights(Instruction &I) { 246 MDBuilder MDB(I.getContext()); 247 I.setMetadata( 248 LLVMContext::MD_prof, 249 MDNode::get(I.getContext(), 250 MDB.createString(MDProfLabels::UnknownBranchWeightsMarker))); 251 } 252 253 bool isExplicitlyUnknownBranchWeightsMetadata(const MDNode &MD) { 254 if (MD.getNumOperands() != 1) 255 return false; 256 return MD.getOperand(0).equalsStr(MDProfLabels::UnknownBranchWeightsMarker); 257 } 258 259 bool hasExplicitlyUnknownBranchWeights(const Instruction &I) { 260 auto *MD = I.getMetadata(LLVMContext::MD_prof); 261 if (!MD) 262 return false; 263 return isExplicitlyUnknownBranchWeightsMetadata(*MD); 264 } 265 266 void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights, 267 bool IsExpected) { 268 MDBuilder MDB(I.getContext()); 269 MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected); 270 I.setMetadata(LLVMContext::MD_prof, BranchWeights); 271 } 272 273 void scaleProfData(Instruction &I, uint64_t S, uint64_t T) { 274 assert(T != 0 && "Caller should guarantee"); 275 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 276 if (ProfileData == nullptr) 277 return; 278 279 auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0)); 280 if (!ProfDataName || 281 (ProfDataName->getString() != MDProfLabels::BranchWeights && 282 ProfDataName->getString() != MDProfLabels::ValueProfile)) 283 return; 284 285 if (!hasCountTypeMD(I)) 286 return; 287 288 LLVMContext &C = I.getContext(); 289 290 MDBuilder MDB(C); 291 SmallVector<Metadata *, 3> Vals; 292 Vals.push_back(ProfileData->getOperand(0)); 293 APInt APS(128, S), APT(128, T); 294 if (ProfDataName->getString() == MDProfLabels::BranchWeights && 295 ProfileData->getNumOperands() > 0) { 296 // Using APInt::div may be expensive, but most cases should fit 64 bits. 297 APInt Val(128, 298 mdconst::dyn_extract<ConstantInt>( 299 ProfileData->getOperand(getBranchWeightOffset(ProfileData))) 300 ->getValue() 301 .getZExtValue()); 302 Val *= APS; 303 Vals.push_back(MDB.createConstant(ConstantInt::get( 304 Type::getInt32Ty(C), Val.udiv(APT).getLimitedValue(UINT32_MAX)))); 305 } else if (ProfDataName->getString() == MDProfLabels::ValueProfile) 306 for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx += 2) { 307 // The first value is the key of the value profile, which will not change. 308 Vals.push_back(ProfileData->getOperand(Idx)); 309 uint64_t Count = 310 mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx + 1)) 311 ->getValue() 312 .getZExtValue(); 313 // Don't scale the magic number. 314 if (Count == NOMORE_ICP_MAGICNUM) { 315 Vals.push_back(ProfileData->getOperand(Idx + 1)); 316 continue; 317 } 318 // Using APInt::div may be expensive, but most cases should fit 64 bits. 319 APInt Val(128, Count); 320 Val *= APS; 321 Vals.push_back(MDB.createConstant(ConstantInt::get( 322 Type::getInt64Ty(C), Val.udiv(APT).getLimitedValue()))); 323 } 324 I.setMetadata(LLVMContext::MD_prof, MDNode::get(C, Vals)); 325 } 326 327 } // namespace llvm 328