1 //===- LowerAllowCheckPass.cpp ----------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Instrumentation/LowerAllowCheckPass.h" 10 11 #include "llvm/ADT/SmallVector.h" 12 #include "llvm/ADT/Statistic.h" 13 #include "llvm/ADT/StringRef.h" 14 #include "llvm/Analysis/OptimizationRemarkEmitter.h" 15 #include "llvm/Analysis/ProfileSummaryInfo.h" 16 #include "llvm/IR/Constants.h" 17 #include "llvm/IR/DiagnosticInfo.h" 18 #include "llvm/IR/InstIterator.h" 19 #include "llvm/IR/Instructions.h" 20 #include "llvm/IR/IntrinsicInst.h" 21 #include "llvm/IR/Intrinsics.h" 22 #include "llvm/IR/Metadata.h" 23 #include "llvm/IR/Module.h" 24 #include "llvm/Support/Debug.h" 25 #include "llvm/Support/RandomNumberGenerator.h" 26 #include <memory> 27 #include <random> 28 29 using namespace llvm; 30 31 #define DEBUG_TYPE "lower-allow-check" 32 33 static cl::opt<int> 34 HotPercentileCutoff("lower-allow-check-percentile-cutoff-hot", 35 cl::desc("Hot percentile cutoff.")); 36 37 static cl::opt<float> 38 RandomRate("lower-allow-check-random-rate", 39 cl::desc("Probability value in the range [0.0, 1.0] of " 40 "unconditional pseudo-random checks.")); 41 42 STATISTIC(NumChecksTotal, "Number of checks"); 43 STATISTIC(NumChecksRemoved, "Number of removed checks"); 44 45 struct RemarkInfo { 46 ore::NV Kind; 47 ore::NV F; 48 ore::NV BB; 49 explicit RemarkInfo(IntrinsicInst *II) 50 : Kind("Kind", II->getArgOperand(0)), 51 F("Function", II->getParent()->getParent()), 52 BB("Block", II->getParent()->getName()) {} 53 }; 54 55 static void emitRemark(IntrinsicInst *II, OptimizationRemarkEmitter &ORE, 56 bool Removed) { 57 if (Removed) { 58 ORE.emit([&]() { 59 RemarkInfo Info(II); 60 return OptimizationRemark(DEBUG_TYPE, "Removed", II) 61 << "Removed check: Kind=" << Info.Kind << " F=" << Info.F 62 << " BB=" << Info.BB; 63 }); 64 } else { 65 ORE.emit([&]() { 66 RemarkInfo Info(II); 67 return OptimizationRemarkMissed(DEBUG_TYPE, "Allowed", II) 68 << "Allowed check: Kind=" << Info.Kind << " F=" << Info.F 69 << " BB=" << Info.BB; 70 }); 71 } 72 } 73 74 static bool removeUbsanTraps(Function &F, const BlockFrequencyInfo &BFI, 75 const ProfileSummaryInfo *PSI, 76 OptimizationRemarkEmitter &ORE, 77 const LowerAllowCheckPass::Options &Opts) { 78 SmallVector<std::pair<IntrinsicInst *, bool>, 16> ReplaceWithValue; 79 std::unique_ptr<RandomNumberGenerator> Rng; 80 81 auto GetRng = [&]() -> RandomNumberGenerator & { 82 if (!Rng) 83 Rng = F.getParent()->createRNG(F.getName()); 84 return *Rng; 85 }; 86 87 auto GetCutoff = [&](const IntrinsicInst *II) -> unsigned { 88 if (HotPercentileCutoff.getNumOccurrences()) 89 return HotPercentileCutoff; 90 else if (II->getIntrinsicID() == Intrinsic::allow_ubsan_check) { 91 auto *Kind = cast<ConstantInt>(II->getArgOperand(0)); 92 if (Kind->getZExtValue() < Opts.cutoffs.size()) 93 return Opts.cutoffs[Kind->getZExtValue()]; 94 } else if (II->getIntrinsicID() == Intrinsic::allow_runtime_check) { 95 return Opts.runtime_check; 96 } 97 98 return 0; 99 }; 100 101 auto ShouldRemoveHot = [&](const BasicBlock &BB, unsigned int cutoff) { 102 return (cutoff == 1000000) || 103 (PSI && PSI->isHotCountNthPercentile( 104 cutoff, BFI.getBlockProfileCount(&BB).value_or(0))); 105 }; 106 107 auto ShouldRemoveRandom = [&]() { 108 return RandomRate.getNumOccurrences() && 109 !std::bernoulli_distribution(RandomRate)(GetRng()); 110 }; 111 112 auto ShouldRemove = [&](const IntrinsicInst *II) { 113 unsigned int cutoff = GetCutoff(II); 114 return ShouldRemoveRandom() || ShouldRemoveHot(*(II->getParent()), cutoff); 115 }; 116 117 for (Instruction &I : instructions(F)) { 118 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I); 119 if (!II) 120 continue; 121 auto ID = II->getIntrinsicID(); 122 switch (ID) { 123 case Intrinsic::allow_ubsan_check: 124 case Intrinsic::allow_runtime_check: { 125 ++NumChecksTotal; 126 127 bool ToRemove = ShouldRemove(II); 128 129 ReplaceWithValue.push_back({ 130 II, 131 ToRemove, 132 }); 133 if (ToRemove) 134 ++NumChecksRemoved; 135 emitRemark(II, ORE, ToRemove); 136 break; 137 } 138 default: 139 break; 140 } 141 } 142 143 for (auto [I, V] : ReplaceWithValue) { 144 I->replaceAllUsesWith(ConstantInt::getBool(I->getType(), !V)); 145 I->eraseFromParent(); 146 } 147 148 return !ReplaceWithValue.empty(); 149 } 150 151 PreservedAnalyses LowerAllowCheckPass::run(Function &F, 152 FunctionAnalysisManager &AM) { 153 if (F.isDeclaration()) 154 return PreservedAnalyses::all(); 155 auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F); 156 ProfileSummaryInfo *PSI = 157 MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); 158 BlockFrequencyInfo &BFI = AM.getResult<BlockFrequencyAnalysis>(F); 159 OptimizationRemarkEmitter &ORE = 160 AM.getResult<OptimizationRemarkEmitterAnalysis>(F); 161 162 return removeUbsanTraps(F, BFI, PSI, ORE, Opts) 163 // We do not change the CFG, we only replace the intrinsics with 164 // true or false. 165 ? PreservedAnalyses::none().preserveSet<CFGAnalyses>() 166 : PreservedAnalyses::all(); 167 } 168 169 bool LowerAllowCheckPass::IsRequested() { 170 return RandomRate.getNumOccurrences() || 171 HotPercentileCutoff.getNumOccurrences(); 172 } 173 174 void LowerAllowCheckPass::printPipeline( 175 raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { 176 static_cast<PassInfoMixin<LowerAllowCheckPass> *>(this)->printPipeline( 177 OS, MapClassName2PassName); 178 OS << "<"; 179 180 // Format is <cutoffs[0,1,2]=70000;cutoffs[5,6,8]=90000> 181 // but it's equally valid to specify 182 // cutoffs[0]=70000;cutoffs[1]=70000;cutoffs[2]=70000;cutoffs[5]=90000;... 183 // and that's what we do here. It is verbose but valid and easy to verify 184 // correctness. 185 // TODO: print shorter output by combining adjacent runs, etc. 186 int i = 0; 187 bool printed = false; 188 for (unsigned int cutoff : Opts.cutoffs) { 189 if (cutoff > 0) { 190 if (printed) 191 OS << ";"; 192 OS << "cutoffs[" << i << "]=" << cutoff; 193 printed = true; 194 } 195 196 i++; 197 } 198 if (Opts.runtime_check) { 199 if (printed) 200 OS << ";"; 201 OS << "runtime_check=" << Opts.runtime_check; 202 } 203 204 OS << '>'; 205 } 206