1 //===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements utilities for working with Profiling Metadata. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "llvm/IR/ProfDataUtils.h" 14 #include "llvm/ADT/SmallVector.h" 15 #include "llvm/ADT/Twine.h" 16 #include "llvm/IR/Constants.h" 17 #include "llvm/IR/Function.h" 18 #include "llvm/IR/Instructions.h" 19 #include "llvm/IR/LLVMContext.h" 20 #include "llvm/IR/MDBuilder.h" 21 #include "llvm/IR/Metadata.h" 22 #include "llvm/Support/BranchProbability.h" 23 #include "llvm/Support/CommandLine.h" 24 25 using namespace llvm; 26 27 namespace { 28 29 // MD_prof nodes have the following layout 30 // 31 // In general: 32 // { String name, Array of i32 } 33 // 34 // In terms of Types: 35 // { MDString, [i32, i32, ...]} 36 // 37 // Concretely for Branch Weights 38 // { "branch_weights", [i32 1, i32 10000]} 39 // 40 // We maintain some constants here to ensure that we access the branch weights 41 // correctly, and can change the behavior in the future if the layout changes 42 43 // The index at which the weights vector starts 44 constexpr unsigned WeightsIdx = 1; 45 46 // the minimum number of operands for MD_prof nodes with branch weights 47 constexpr unsigned MinBWOps = 3; 48 49 // We may want to add support for other MD_prof types, so provide an abstraction 50 // for checking the metadata type. 51 bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) { 52 // TODO: This routine may be simplified if MD_prof used an enum instead of a 53 // string to differentiate the types of MD_prof nodes. 54 if (!ProfData || !Name || MinOps < 2) 55 return false; 56 57 unsigned NOps = ProfData->getNumOperands(); 58 if (NOps < MinOps) 59 return false; 60 61 auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0)); 62 if (!ProfDataName) 63 return false; 64 65 return ProfDataName->getString().equals(Name); 66 } 67 68 } // namespace 69 70 namespace llvm { 71 72 bool hasProfMD(const Instruction &I) { 73 return nullptr != I.getMetadata(LLVMContext::MD_prof); 74 } 75 76 bool isBranchWeightMD(const MDNode *ProfileData) { 77 return isTargetMD(ProfileData, "branch_weights", MinBWOps); 78 } 79 80 bool hasBranchWeightMD(const Instruction &I) { 81 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 82 return isBranchWeightMD(ProfileData); 83 } 84 85 bool hasValidBranchWeightMD(const Instruction &I) { 86 return getValidBranchWeightMDNode(I); 87 } 88 89 MDNode *getBranchWeightMDNode(const Instruction &I) { 90 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 91 if (!isBranchWeightMD(ProfileData)) 92 return nullptr; 93 return ProfileData; 94 } 95 96 MDNode *getValidBranchWeightMDNode(const Instruction &I) { 97 auto *ProfileData = getBranchWeightMDNode(I); 98 if (ProfileData && ProfileData->getNumOperands() == 1 + I.getNumSuccessors()) 99 return ProfileData; 100 return nullptr; 101 } 102 103 void extractFromBranchWeightMD(const MDNode *ProfileData, 104 SmallVectorImpl<uint32_t> &Weights) { 105 assert(isBranchWeightMD(ProfileData) && "wrong metadata"); 106 107 unsigned NOps = ProfileData->getNumOperands(); 108 assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); 109 Weights.resize(NOps - WeightsIdx); 110 111 for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { 112 ConstantInt *Weight = 113 mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); 114 assert(Weight && "Malformed branch_weight in MD_prof node"); 115 assert(Weight->getValue().getActiveBits() <= 32 && 116 "Too many bits for uint32_t"); 117 Weights[Idx - WeightsIdx] = Weight->getZExtValue(); 118 } 119 } 120 121 bool extractBranchWeights(const MDNode *ProfileData, 122 SmallVectorImpl<uint32_t> &Weights) { 123 if (!isBranchWeightMD(ProfileData)) 124 return false; 125 extractFromBranchWeightMD(ProfileData, Weights); 126 return true; 127 } 128 129 bool extractBranchWeights(const Instruction &I, 130 SmallVectorImpl<uint32_t> &Weights) { 131 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 132 return extractBranchWeights(ProfileData, Weights); 133 } 134 135 bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal, 136 uint64_t &FalseVal) { 137 assert((I.getOpcode() == Instruction::Br || 138 I.getOpcode() == Instruction::Select) && 139 "Looking for branch weights on something besides branch, select, or " 140 "switch"); 141 142 SmallVector<uint32_t, 2> Weights; 143 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); 144 if (!extractBranchWeights(ProfileData, Weights)) 145 return false; 146 147 if (Weights.size() > 2) 148 return false; 149 150 TrueVal = Weights[0]; 151 FalseVal = Weights[1]; 152 return true; 153 } 154 155 bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) { 156 TotalVal = 0; 157 if (!ProfileData) 158 return false; 159 160 auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0)); 161 if (!ProfDataName) 162 return false; 163 164 if (ProfDataName->getString().equals("branch_weights")) { 165 for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) { 166 auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); 167 assert(V && "Malformed branch_weight in MD_prof node"); 168 TotalVal += V->getValue().getZExtValue(); 169 } 170 return true; 171 } 172 173 if (ProfDataName->getString().equals("VP") && 174 ProfileData->getNumOperands() > 3) { 175 TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2)) 176 ->getValue() 177 .getZExtValue(); 178 return true; 179 } 180 return false; 181 } 182 183 bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) { 184 return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal); 185 } 186 187 void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) { 188 MDBuilder MDB(I.getContext()); 189 MDNode *BranchWeights = MDB.createBranchWeights(Weights); 190 I.setMetadata(LLVMContext::MD_prof, BranchWeights); 191 } 192 193 } // namespace llvm 194