xref: /freebsd/contrib/llvm-project/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp (revision f126890ac5386406dadf7c4cfa9566cbb56537c5)
1 //===------------ BPFCheckAndAdjustIR.cpp - Check and Adjust IR -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
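// Check IR and adjust IR for verifier friendly codes.
// The following are done for IR checking:
//   - no relocation globals in PHI node.
// The following are done for IR adjustment:
//   - remove __builtin_bpf_passthrough builtins. Target independent IR
//     optimizations are done and those builtins can be removed.
//   - remove __builtin_bpf_compare builtins, replacing each one with the
//     equivalent icmp instruction.
//   - sink min/max intrinsic calls into the loop comparisons that use them
//     (see sinkMinMax()).
//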
16 //===----------------------------------------------------------------------===//
17 
#include "BPF.h"
#include "BPFCORE.h"
#include "BPFTargetMachine.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
33 
34 #define DEBUG_TYPE "bpf-check-and-opt-ir"
35 
36 using namespace llvm;
37 
38 namespace {
39 
40 class BPFCheckAndAdjustIR final : public ModulePass {
41   bool runOnModule(Module &F) override;
42 
43 public:
44   static char ID;
45   BPFCheckAndAdjustIR() : ModulePass(ID) {}
46   virtual void getAnalysisUsage(AnalysisUsage &AU) const override;
47 
48 private:
49   void checkIR(Module &M);
50   bool adjustIR(Module &M);
51   bool removePassThroughBuiltin(Module &M);
52   bool removeCompareBuiltin(Module &M);
53   bool sinkMinMax(Module &M);
54 };
55 } // End anonymous namespace
56 
// Storage for the pass identifier declared in the class; it is handed to the
// ModulePass base constructor. NOTE(review): per the legacy pass-manager
// convention only the address of ID is significant, not the value 0 — confirm.
char BPFCheckAndAdjustIR::ID = 0;
// Register the pass under DEBUG_TYPE ("bpf-check-and-opt-ir") with the
// description "BPF Check And Adjust IR". The trailing `false, false` are
// presumably INITIALIZE_PASS's "CFG only" / "is analysis" flags — verify
// against llvm/PassSupport.h.
INITIALIZE_PASS(BPFCheckAndAdjustIR, DEBUG_TYPE, "BPF Check And Adjust IR",
                false, false)
60 
61 ModulePass *llvm::createBPFCheckAndAdjustIR() {
62   return new BPFCheckAndAdjustIR();
63 }
64 
65 void BPFCheckAndAdjustIR::checkIR(Module &M) {
66   // Ensure relocation global won't appear in PHI node
67   // This may happen if the compiler generated the following code:
68   //   B1:
69   //      g1 = @llvm.skb_buff:0:1...
70   //      ...
71   //      goto B_COMMON
72   //   B2:
73   //      g2 = @llvm.skb_buff:0:2...
74   //      ...
75   //      goto B_COMMON
76   //   B_COMMON:
77   //      g = PHI(g1, g2)
78   //      x = load g
79   //      ...
80   // If anything likes the above "g = PHI(g1, g2)", issue a fatal error.
81   for (Function &F : M)
82     for (auto &BB : F)
83       for (auto &I : BB) {
84         PHINode *PN = dyn_cast<PHINode>(&I);
85         if (!PN || PN->use_empty())
86           continue;
87         for (int i = 0, e = PN->getNumIncomingValues(); i < e; ++i) {
88           auto *GV = dyn_cast<GlobalVariable>(PN->getIncomingValue(i));
89           if (!GV)
90             continue;
91           if (GV->hasAttribute(BPFCoreSharedInfo::AmaAttr) ||
92               GV->hasAttribute(BPFCoreSharedInfo::TypeIdAttr))
93             report_fatal_error("relocation global in PHI node");
94         }
95       }
96 }
97 
98 bool BPFCheckAndAdjustIR::removePassThroughBuiltin(Module &M) {
99   // Remove __builtin_bpf_passthrough()'s which are used to prevent
100   // certain IR optimizations. Now major IR optimizations are done,
101   // remove them.
102   bool Changed = false;
103   CallInst *ToBeDeleted = nullptr;
104   for (Function &F : M)
105     for (auto &BB : F)
106       for (auto &I : BB) {
107         if (ToBeDeleted) {
108           ToBeDeleted->eraseFromParent();
109           ToBeDeleted = nullptr;
110         }
111 
112         auto *Call = dyn_cast<CallInst>(&I);
113         if (!Call)
114           continue;
115         auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand());
116         if (!GV)
117           continue;
118         if (!GV->getName().startswith("llvm.bpf.passthrough"))
119           continue;
120         Changed = true;
121         Value *Arg = Call->getArgOperand(1);
122         Call->replaceAllUsesWith(Arg);
123         ToBeDeleted = Call;
124       }
125   return Changed;
126 }
127 
128 bool BPFCheckAndAdjustIR::removeCompareBuiltin(Module &M) {
129   // Remove __builtin_bpf_compare()'s which are used to prevent
130   // certain IR optimizations. Now major IR optimizations are done,
131   // remove them.
132   bool Changed = false;
133   CallInst *ToBeDeleted = nullptr;
134   for (Function &F : M)
135     for (auto &BB : F)
136       for (auto &I : BB) {
137         if (ToBeDeleted) {
138           ToBeDeleted->eraseFromParent();
139           ToBeDeleted = nullptr;
140         }
141 
142         auto *Call = dyn_cast<CallInst>(&I);
143         if (!Call)
144           continue;
145         auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand());
146         if (!GV)
147           continue;
148         if (!GV->getName().startswith("llvm.bpf.compare"))
149           continue;
150 
151         Changed = true;
152         Value *Arg0 = Call->getArgOperand(0);
153         Value *Arg1 = Call->getArgOperand(1);
154         Value *Arg2 = Call->getArgOperand(2);
155 
156         auto OpVal = cast<ConstantInt>(Arg0)->getValue().getZExtValue();
157         CmpInst::Predicate Opcode = (CmpInst::Predicate)OpVal;
158 
159         auto *ICmp = new ICmpInst(Opcode, Arg1, Arg2);
160         ICmp->insertBefore(Call);
161 
162         Call->replaceAllUsesWith(ICmp);
163         ToBeDeleted = Call;
164       }
165   return Changed;
166 }
167 
168 struct MinMaxSinkInfo {
169   ICmpInst *ICmp;
170   Value *Other;
171   ICmpInst::Predicate Predicate;
172   CallInst *MinMax;
173   ZExtInst *ZExt;
174   SExtInst *SExt;
175 
176   MinMaxSinkInfo(ICmpInst *ICmp, Value *Other, ICmpInst::Predicate Predicate)
177       : ICmp(ICmp), Other(Other), Predicate(Predicate), MinMax(nullptr),
178         ZExt(nullptr), SExt(nullptr) {}
179 };
180 
181 static bool sinkMinMaxInBB(BasicBlock &BB,
182                            const std::function<bool(Instruction *)> &Filter) {
183   // Check if V is:
184   //   (fn %a %b) or (ext (fn %a %b))
185   // Where:
186   //   ext := sext | zext
187   //   fn  := smin | umin | smax | umax
188   auto IsMinMaxCall = [=](Value *V, MinMaxSinkInfo &Info) {
189     if (auto *ZExt = dyn_cast<ZExtInst>(V)) {
190       V = ZExt->getOperand(0);
191       Info.ZExt = ZExt;
192     } else if (auto *SExt = dyn_cast<SExtInst>(V)) {
193       V = SExt->getOperand(0);
194       Info.SExt = SExt;
195     }
196 
197     auto *Call = dyn_cast<CallInst>(V);
198     if (!Call)
199       return false;
200 
201     auto *Called = dyn_cast<Function>(Call->getCalledOperand());
202     if (!Called)
203       return false;
204 
205     switch (Called->getIntrinsicID()) {
206     case Intrinsic::smin:
207     case Intrinsic::umin:
208     case Intrinsic::smax:
209     case Intrinsic::umax:
210       break;
211     default:
212       return false;
213     }
214 
215     if (!Filter(Call))
216       return false;
217 
218     Info.MinMax = Call;
219 
220     return true;
221   };
222 
223   auto ZeroOrSignExtend = [](IRBuilder<> &Builder, Value *V,
224                              MinMaxSinkInfo &Info) {
225     if (Info.SExt) {
226       if (Info.SExt->getType() == V->getType())
227         return V;
228       return Builder.CreateSExt(V, Info.SExt->getType());
229     }
230     if (Info.ZExt) {
231       if (Info.ZExt->getType() == V->getType())
232         return V;
233       return Builder.CreateZExt(V, Info.ZExt->getType());
234     }
235     return V;
236   };
237 
238   bool Changed = false;
239   SmallVector<MinMaxSinkInfo, 2> SinkList;
240 
241   // Check BB for instructions like:
242   //   insn := (icmp %a (fn ...)) | (icmp (fn ...)  %a)
243   //
244   // Where:
245   //   fn := min | max | (sext (min ...)) | (sext (max ...))
246   //
247   // Put such instructions to SinkList.
248   for (Instruction &I : BB) {
249     ICmpInst *ICmp = dyn_cast<ICmpInst>(&I);
250     if (!ICmp)
251       continue;
252     if (!ICmp->isRelational())
253       continue;
254     MinMaxSinkInfo First(ICmp, ICmp->getOperand(1),
255                          ICmpInst::getSwappedPredicate(ICmp->getPredicate()));
256     MinMaxSinkInfo Second(ICmp, ICmp->getOperand(0), ICmp->getPredicate());
257     bool FirstMinMax = IsMinMaxCall(ICmp->getOperand(0), First);
258     bool SecondMinMax = IsMinMaxCall(ICmp->getOperand(1), Second);
259     if (!(FirstMinMax ^ SecondMinMax))
260       continue;
261     SinkList.push_back(FirstMinMax ? First : Second);
262   }
263 
264   // Iterate SinkList and replace each (icmp ...) with corresponding
265   // `x < a && x < b` or similar expression.
266   for (auto &Info : SinkList) {
267     ICmpInst *ICmp = Info.ICmp;
268     CallInst *MinMax = Info.MinMax;
269     Intrinsic::ID IID = MinMax->getCalledFunction()->getIntrinsicID();
270     ICmpInst::Predicate P = Info.Predicate;
271     if (ICmpInst::isSigned(P) && IID != Intrinsic::smin &&
272         IID != Intrinsic::smax)
273       continue;
274 
275     IRBuilder<> Builder(ICmp);
276     Value *X = Info.Other;
277     Value *A = ZeroOrSignExtend(Builder, MinMax->getArgOperand(0), Info);
278     Value *B = ZeroOrSignExtend(Builder, MinMax->getArgOperand(1), Info);
279     bool IsMin = IID == Intrinsic::smin || IID == Intrinsic::umin;
280     bool IsMax = IID == Intrinsic::smax || IID == Intrinsic::umax;
281     bool IsLess = ICmpInst::isLE(P) || ICmpInst::isLT(P);
282     bool IsGreater = ICmpInst::isGE(P) || ICmpInst::isGT(P);
283     assert(IsMin ^ IsMax);
284     assert(IsLess ^ IsGreater);
285 
286     Value *Replacement;
287     Value *LHS = Builder.CreateICmp(P, X, A);
288     Value *RHS = Builder.CreateICmp(P, X, B);
289     if ((IsLess && IsMin) || (IsGreater && IsMax))
290       // x < min(a, b) -> x < a && x < b
291       // x > max(a, b) -> x > a && x > b
292       Replacement = Builder.CreateLogicalAnd(LHS, RHS);
293     else
294       // x > min(a, b) -> x > a || x > b
295       // x < max(a, b) -> x < a || x < b
296       Replacement = Builder.CreateLogicalOr(LHS, RHS);
297 
298     ICmp->replaceAllUsesWith(Replacement);
299 
300     Instruction *ToRemove[] = {ICmp, Info.ZExt, Info.SExt, MinMax};
301     for (Instruction *I : ToRemove)
302       if (I && I->use_empty())
303         I->eraseFromParent();
304 
305     Changed = true;
306   }
307 
308   return Changed;
309 }
310 
311 // Do the following transformation:
312 //
313 //   x < min(a, b) -> x < a && x < b
314 //   x > min(a, b) -> x > a || x > b
315 //   x < max(a, b) -> x < a || x < b
316 //   x > max(a, b) -> x > a && x > b
317 //
318 // Such patterns are introduced by LICM.cpp:hoistMinMax()
319 // transformation and might lead to BPF verification failures for
320 // older kernels.
321 //
322 // To minimize "collateral" changes only do it for icmp + min/max
323 // calls when icmp is inside a loop and min/max is outside of that
324 // loop.
325 //
326 // Verification failure happens when:
327 // - RHS operand of some `icmp LHS, RHS` is replaced by some RHS1;
328 // - verifier can recognize RHS as a constant scalar in some context;
329 // - verifier can't recognize RHS1 as a constant scalar in the same
330 //   context;
331 //
332 // The "constant scalar" is not a compile time constant, but a register
333 // that holds a scalar value known to verifier at some point in time
334 // during abstract interpretation.
335 //
336 // See also:
337 //   https://lore.kernel.org/bpf/20230406164505.1046801-1-yhs@fb.com/
338 bool BPFCheckAndAdjustIR::sinkMinMax(Module &M) {
339   bool Changed = false;
340 
341   for (Function &F : M) {
342     if (F.isDeclaration())
343       continue;
344 
345     LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>(F).getLoopInfo();
346     for (Loop *L : LI)
347       for (BasicBlock *BB : L->blocks()) {
348         // Filter out instructions coming from the same loop
349         Loop *BBLoop = LI.getLoopFor(BB);
350         auto OtherLoopFilter = [&](Instruction *I) {
351           return LI.getLoopFor(I->getParent()) != BBLoop;
352         };
353         Changed |= sinkMinMaxInBB(*BB, OtherLoopFilter);
354       }
355   }
356 
357   return Changed;
358 }
359 
360 void BPFCheckAndAdjustIR::getAnalysisUsage(AnalysisUsage &AU) const {
361   AU.addRequired<LoopInfoWrapperPass>();
362 }
363 
364 bool BPFCheckAndAdjustIR::adjustIR(Module &M) {
365   bool Changed = removePassThroughBuiltin(M);
366   Changed = removeCompareBuiltin(M) || Changed;
367   Changed = sinkMinMax(M) || Changed;
368   return Changed;
369 }
370 
371 bool BPFCheckAndAdjustIR::runOnModule(Module &M) {
372   checkIR(M);
373   return adjustIR(M);
374 }
375