1 //===- LowerAllowCheckPass.cpp ----------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "llvm/Transforms/Instrumentation/LowerAllowCheckPass.h"
10
11 #include "llvm/ADT/SmallVector.h"
12 #include "llvm/ADT/Statistic.h"
13 #include "llvm/ADT/StringRef.h"
14 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
15 #include "llvm/Analysis/ProfileSummaryInfo.h"
16 #include "llvm/IR/Constants.h"
17 #include "llvm/IR/DiagnosticInfo.h"
18 #include "llvm/IR/InstIterator.h"
19 #include "llvm/IR/Instructions.h"
20 #include "llvm/IR/IntrinsicInst.h"
21 #include "llvm/IR/Intrinsics.h"
22 #include "llvm/IR/Metadata.h"
23 #include "llvm/IR/Module.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/RandomNumberGenerator.h"
26 #include <memory>
27 #include <random>
28
29 using namespace llvm;
30
31 #define DEBUG_TYPE "lower-allow-check"
32
33 static cl::opt<int>
34 HotPercentileCutoff("lower-allow-check-percentile-cutoff-hot",
35 cl::desc("Hot percentile cutoff."));
36
37 static cl::opt<float>
38 RandomRate("lower-allow-check-random-rate",
39 cl::desc("Probability value in the range [0.0, 1.0] of "
40 "unconditional pseudo-random checks."));
41
42 STATISTIC(NumChecksTotal, "Number of checks");
43 STATISTIC(NumChecksRemoved, "Number of removed checks");
44
45 struct RemarkInfo {
46 ore::NV Kind;
47 ore::NV F;
48 ore::NV BB;
RemarkInfoRemarkInfo49 explicit RemarkInfo(IntrinsicInst *II)
50 : Kind("Kind", II->getArgOperand(0)),
51 F("Function", II->getParent()->getParent()),
52 BB("Block", II->getParent()->getName()) {}
53 };
54
emitRemark(IntrinsicInst * II,OptimizationRemarkEmitter & ORE,bool Removed)55 static void emitRemark(IntrinsicInst *II, OptimizationRemarkEmitter &ORE,
56 bool Removed) {
57 if (Removed) {
58 ORE.emit([&]() {
59 RemarkInfo Info(II);
60 return OptimizationRemark(DEBUG_TYPE, "Removed", II)
61 << "Removed check: Kind=" << Info.Kind << " F=" << Info.F
62 << " BB=" << Info.BB;
63 });
64 } else {
65 ORE.emit([&]() {
66 RemarkInfo Info(II);
67 return OptimizationRemarkMissed(DEBUG_TYPE, "Allowed", II)
68 << "Allowed check: Kind=" << Info.Kind << " F=" << Info.F
69 << " BB=" << Info.BB;
70 });
71 }
72 }
73
removeUbsanTraps(Function & F,const BlockFrequencyInfo & BFI,const ProfileSummaryInfo * PSI,OptimizationRemarkEmitter & ORE,const LowerAllowCheckPass::Options & Opts)74 static bool removeUbsanTraps(Function &F, const BlockFrequencyInfo &BFI,
75 const ProfileSummaryInfo *PSI,
76 OptimizationRemarkEmitter &ORE,
77 const LowerAllowCheckPass::Options &Opts) {
78 SmallVector<std::pair<IntrinsicInst *, bool>, 16> ReplaceWithValue;
79 std::unique_ptr<RandomNumberGenerator> Rng;
80
81 auto GetRng = [&]() -> RandomNumberGenerator & {
82 if (!Rng)
83 Rng = F.getParent()->createRNG(F.getName());
84 return *Rng;
85 };
86
87 auto GetCutoff = [&](const IntrinsicInst *II) -> unsigned {
88 if (HotPercentileCutoff.getNumOccurrences())
89 return HotPercentileCutoff;
90 else if (II->getIntrinsicID() == Intrinsic::allow_ubsan_check) {
91 auto *Kind = cast<ConstantInt>(II->getArgOperand(0));
92 if (Kind->getZExtValue() < Opts.cutoffs.size())
93 return Opts.cutoffs[Kind->getZExtValue()];
94 } else if (II->getIntrinsicID() == Intrinsic::allow_runtime_check) {
95 return Opts.runtime_check;
96 }
97
98 return 0;
99 };
100
101 auto ShouldRemoveHot = [&](const BasicBlock &BB, unsigned int cutoff) {
102 return (cutoff == 1000000) ||
103 (PSI && PSI->isHotCountNthPercentile(
104 cutoff, BFI.getBlockProfileCount(&BB).value_or(0)));
105 };
106
107 auto ShouldRemoveRandom = [&]() {
108 return RandomRate.getNumOccurrences() &&
109 !std::bernoulli_distribution(RandomRate)(GetRng());
110 };
111
112 auto ShouldRemove = [&](const IntrinsicInst *II) {
113 unsigned int cutoff = GetCutoff(II);
114 return ShouldRemoveRandom() || ShouldRemoveHot(*(II->getParent()), cutoff);
115 };
116
117 for (Instruction &I : instructions(F)) {
118 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
119 if (!II)
120 continue;
121 auto ID = II->getIntrinsicID();
122 switch (ID) {
123 case Intrinsic::allow_ubsan_check:
124 case Intrinsic::allow_runtime_check: {
125 ++NumChecksTotal;
126
127 bool ToRemove = ShouldRemove(II);
128
129 ReplaceWithValue.push_back({
130 II,
131 ToRemove,
132 });
133 if (ToRemove)
134 ++NumChecksRemoved;
135 emitRemark(II, ORE, ToRemove);
136 break;
137 }
138 default:
139 break;
140 }
141 }
142
143 for (auto [I, V] : ReplaceWithValue) {
144 I->replaceAllUsesWith(ConstantInt::getBool(I->getType(), !V));
145 I->eraseFromParent();
146 }
147
148 return !ReplaceWithValue.empty();
149 }
150
run(Function & F,FunctionAnalysisManager & AM)151 PreservedAnalyses LowerAllowCheckPass::run(Function &F,
152 FunctionAnalysisManager &AM) {
153 if (F.isDeclaration())
154 return PreservedAnalyses::all();
155 auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
156 ProfileSummaryInfo *PSI =
157 MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
158 BlockFrequencyInfo &BFI = AM.getResult<BlockFrequencyAnalysis>(F);
159 OptimizationRemarkEmitter &ORE =
160 AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
161
162 return removeUbsanTraps(F, BFI, PSI, ORE, Opts)
163 // We do not change the CFG, we only replace the intrinsics with
164 // true or false.
165 ? PreservedAnalyses::none().preserveSet<CFGAnalyses>()
166 : PreservedAnalyses::all();
167 }
168
IsRequested()169 bool LowerAllowCheckPass::IsRequested() {
170 return RandomRate.getNumOccurrences() ||
171 HotPercentileCutoff.getNumOccurrences();
172 }
173
printPipeline(raw_ostream & OS,function_ref<StringRef (StringRef)> MapClassName2PassName)174 void LowerAllowCheckPass::printPipeline(
175 raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
176 static_cast<PassInfoMixin<LowerAllowCheckPass> *>(this)->printPipeline(
177 OS, MapClassName2PassName);
178 OS << "<";
179
180 // Format is <cutoffs[0,1,2]=70000;cutoffs[5,6,8]=90000>
181 // but it's equally valid to specify
182 // cutoffs[0]=70000;cutoffs[1]=70000;cutoffs[2]=70000;cutoffs[5]=90000;...
183 // and that's what we do here. It is verbose but valid and easy to verify
184 // correctness.
185 // TODO: print shorter output by combining adjacent runs, etc.
186 int i = 0;
187 bool printed = false;
188 for (unsigned int cutoff : Opts.cutoffs) {
189 if (cutoff > 0) {
190 if (printed)
191 OS << ";";
192 OS << "cutoffs[" << i << "]=" << cutoff;
193 printed = true;
194 }
195
196 i++;
197 }
198 if (Opts.runtime_check) {
199 if (printed)
200 OS << ";";
201 OS << "runtime_check=" << Opts.runtime_check;
202 }
203
204 OS << '>';
205 }
206