xref: /freebsd/contrib/llvm-project/llvm/lib/Target/BPF/BPFAbstractMemberAccess.cpp (revision ca53e5aedfebcc1b4091b68e01b2d5cae923f85e)
1 //===------ BPFAbstractMemberAccess.cpp - Abstracting Member Accesses -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass abstracts struct/union member accesses in order to support
10 // compile-once run-everywhere (CO-RE). CO-RE aims to compile a program once
11 // so that it can run on different kernels. In particular, if a bpf program
12 // accesses a particular kernel data structure member, the details of the
13 // intermediate member accesses are recorded so that the bpf loader can make
14 // the necessary adjustments right before program loading.
15 //
16 // For example,
17 //
18 //   struct s {
19 //     int a;
20 //     int b;
21 //   };
22 //   struct t {
23 //     struct s c;
24 //     int d;
25 //   };
26 //   struct t e;
27 //
28 // For the member access e.c.b, the compiler will generate code
29 //   &e + 4
30 //
31 // The compile-once run-everywhere instead generates the following code
32 //   r = 4
33 //   &e + r
34 // The "4" in "r = 4" can be changed based on a particular kernel version.
35 // For example, on a particular kernel version, if struct s is changed to
36 //
37 //   struct s {
38 //     int new_field;
39 //     int a;
40 //     int b;
41 //   }
42 //
43 // By repeating the member access on the host, the bpf loader can
44 // adjust "r = 4" as "r = 8".
45 //
46 // This feature relies on the following three intrinsic calls:
47 //   addr = preserve_array_access_index(base, dimension, index)
48 //   addr = preserve_union_access_index(base, di_index)
49 //          !llvm.preserve.access.index <union_ditype>
50 //   addr = preserve_struct_access_index(base, gep_index, di_index)
51 //          !llvm.preserve.access.index <struct_ditype>
52 //
53 // Bitfield member access needs special attention. A user cannot take the
54 // address of a bitfield access. To facilitate kernel verifier
55 // for easy bitfield code optimization, a new clang intrinsic is introduced:
56 //   uint32_t __builtin_preserve_field_info(member_access, info_kind)
57 // In IR, a chain with two (or more) intrinsic calls will be generated:
58 //   ...
59 //   addr = preserve_struct_access_index(base, 1, 1) !struct s
60 //   uint32_t result = bpf_preserve_field_info(addr, info_kind)
61 //
62 // Suppose the info_kind is FIELD_SIGNEDNESS,
63 // the above two IR intrinsics will be replaced with
64 // a relocatable insn:
65 //   signedness = /* signedness of member_access */
66 // and signedness can be changed by the bpf loader based on the
67 // types on the host.
68 //
69 // User can also test whether a field exists or not with
70 //   uint32_t result = bpf_preserve_field_info(member_access, FIELD_EXISTENCE)
71 // The field will always be available (result = 1) during initial
72 // compilation, but the bpf loader can patch in the correct value
73 // on the target host, where the member_access may or may not be available.
74 //
75 //===----------------------------------------------------------------------===//
76 
77 #include "BPF.h"
78 #include "BPFCORE.h"
79 #include "BPFTargetMachine.h"
80 #include "llvm/IR/DebugInfoMetadata.h"
81 #include "llvm/IR/GlobalVariable.h"
82 #include "llvm/IR/Instruction.h"
83 #include "llvm/IR/Instructions.h"
84 #include "llvm/IR/Module.h"
85 #include "llvm/IR/Type.h"
86 #include "llvm/IR/User.h"
87 #include "llvm/IR/Value.h"
88 #include "llvm/Pass.h"
89 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
90 #include <stack>
91 
92 #define DEBUG_TYPE "bpf-abstract-member-access"
93 
94 namespace llvm {
95 constexpr StringRef BPFCoreSharedInfo::AmaAttr;
96 } // namespace llvm
97 
98 using namespace llvm;
99 
100 namespace {
101 
102 class BPFAbstractMemberAccess final : public ModulePass {
103   StringRef getPassName() const override {
104     return "BPF Abstract Member Access";
105   }
106 
107   bool runOnModule(Module &M) override;
108 
109 public:
110   static char ID;
111   TargetMachine *TM;
112   // Add optional BPFTargetMachine parameter so that BPF backend can add the phase
113   // with target machine to find out the endianness. The default constructor (without
114   // parameters) is used by the pass manager for managing purposes.
115   BPFAbstractMemberAccess(BPFTargetMachine *TM = nullptr) : ModulePass(ID), TM(TM) {}
116 
117   struct CallInfo {
118     uint32_t Kind;
119     uint32_t AccessIndex;
120     Align RecordAlignment;
121     MDNode *Metadata;
122     Value *Base;
123   };
124   typedef std::stack<std::pair<CallInst *, CallInfo>> CallInfoStack;
125 
126 private:
127   enum : uint32_t {
128     BPFPreserveArrayAI = 1,
129     BPFPreserveUnionAI = 2,
130     BPFPreserveStructAI = 3,
131     BPFPreserveFieldInfoAI = 4,
132   };
133 
134   const DataLayout *DL = nullptr;
135 
136   std::map<std::string, GlobalVariable *> GEPGlobals;
137   // A map to link preserve_*_access_index instrinsic calls.
138   std::map<CallInst *, std::pair<CallInst *, CallInfo>> AIChain;
139   // A map to hold all the base preserve_*_access_index instrinsic calls.
140   // The base call is not an input of any other preserve_*
141   // intrinsics.
142   std::map<CallInst *, CallInfo> BaseAICalls;
143 
144   bool doTransformation(Module &M);
145 
146   void traceAICall(CallInst *Call, CallInfo &ParentInfo);
147   void traceBitCast(BitCastInst *BitCast, CallInst *Parent,
148                     CallInfo &ParentInfo);
149   void traceGEP(GetElementPtrInst *GEP, CallInst *Parent,
150                 CallInfo &ParentInfo);
151   void collectAICallChains(Module &M, Function &F);
152 
153   bool IsPreserveDIAccessIndexCall(const CallInst *Call, CallInfo &Cinfo);
154   bool IsValidAIChain(const MDNode *ParentMeta, uint32_t ParentAI,
155                       const MDNode *ChildMeta);
156   bool removePreserveAccessIndexIntrinsic(Module &M);
157   void replaceWithGEP(std::vector<CallInst *> &CallList,
158                       uint32_t NumOfZerosIndex, uint32_t DIIndex);
159   bool HasPreserveFieldInfoCall(CallInfoStack &CallStack);
160   void GetStorageBitRange(DIDerivedType *MemberTy, Align RecordAlignment,
161                           uint32_t &StartBitOffset, uint32_t &EndBitOffset);
162   uint32_t GetFieldInfo(uint32_t InfoKind, DICompositeType *CTy,
163                         uint32_t AccessIndex, uint32_t PatchImm,
164                         Align RecordAlignment);
165 
166   Value *computeBaseAndAccessKey(CallInst *Call, CallInfo &CInfo,
167                                  std::string &AccessKey, MDNode *&BaseMeta);
168   uint64_t getConstant(const Value *IndexValue);
169   bool transformGEPChain(Module &M, CallInst *Call, CallInfo &CInfo);
170 };
171 } // End anonymous namespace
172 
173 char BPFAbstractMemberAccess::ID = 0;
174 INITIALIZE_PASS(BPFAbstractMemberAccess, DEBUG_TYPE,
175                 "abstracting struct/union member accessees", false, false)
176 
177 ModulePass *llvm::createBPFAbstractMemberAccess(BPFTargetMachine *TM) {
178   return new BPFAbstractMemberAccess(TM);
179 }
180 
181 bool BPFAbstractMemberAccess::runOnModule(Module &M) {
182   LLVM_DEBUG(dbgs() << "********** Abstract Member Accesses **********\n");
183 
184   // Bail out if no debug info.
185   if (M.debug_compile_units().empty())
186     return false;
187 
188   DL = &M.getDataLayout();
189   return doTransformation(M);
190 }
191 
192 static bool SkipDIDerivedTag(unsigned Tag, bool skipTypedef) {
193   if (Tag != dwarf::DW_TAG_typedef && Tag != dwarf::DW_TAG_const_type &&
194       Tag != dwarf::DW_TAG_volatile_type &&
195       Tag != dwarf::DW_TAG_restrict_type &&
196       Tag != dwarf::DW_TAG_member)
197     return false;
198   if (Tag == dwarf::DW_TAG_typedef && !skipTypedef)
199     return false;
200   return true;
201 }
202 
203 static DIType * stripQualifiers(DIType *Ty, bool skipTypedef = true) {
204   while (auto *DTy = dyn_cast<DIDerivedType>(Ty)) {
205     if (!SkipDIDerivedTag(DTy->getTag(), skipTypedef))
206       break;
207     Ty = DTy->getBaseType();
208   }
209   return Ty;
210 }
211 
212 static const DIType * stripQualifiers(const DIType *Ty) {
213   while (auto *DTy = dyn_cast<DIDerivedType>(Ty)) {
214     if (!SkipDIDerivedTag(DTy->getTag(), true))
215       break;
216     Ty = DTy->getBaseType();
217   }
218   return Ty;
219 }
220 
221 static uint32_t calcArraySize(const DICompositeType *CTy, uint32_t StartDim) {
222   DINodeArray Elements = CTy->getElements();
223   uint32_t DimSize = 1;
224   for (uint32_t I = StartDim; I < Elements.size(); ++I) {
225     if (auto *Element = dyn_cast_or_null<DINode>(Elements[I]))
226       if (Element->getTag() == dwarf::DW_TAG_subrange_type) {
227         const DISubrange *SR = cast<DISubrange>(Element);
228         auto *CI = SR->getCount().dyn_cast<ConstantInt *>();
229         DimSize *= CI->getSExtValue();
230       }
231   }
232 
233   return DimSize;
234 }
235 
236 /// Check whether a call is a preserve_*_access_index intrinsic call or not.
237 bool BPFAbstractMemberAccess::IsPreserveDIAccessIndexCall(const CallInst *Call,
238                                                           CallInfo &CInfo) {
239   if (!Call)
240     return false;
241 
242   const auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand());
243   if (!GV)
244     return false;
245   if (GV->getName().startswith("llvm.preserve.array.access.index")) {
246     CInfo.Kind = BPFPreserveArrayAI;
247     CInfo.Metadata = Call->getMetadata(LLVMContext::MD_preserve_access_index);
248     if (!CInfo.Metadata)
249       report_fatal_error("Missing metadata for llvm.preserve.array.access.index intrinsic");
250     CInfo.AccessIndex = getConstant(Call->getArgOperand(2));
251     CInfo.Base = Call->getArgOperand(0);
252     CInfo.RecordAlignment =
253         DL->getABITypeAlign(CInfo.Base->getType()->getPointerElementType());
254     return true;
255   }
256   if (GV->getName().startswith("llvm.preserve.union.access.index")) {
257     CInfo.Kind = BPFPreserveUnionAI;
258     CInfo.Metadata = Call->getMetadata(LLVMContext::MD_preserve_access_index);
259     if (!CInfo.Metadata)
260       report_fatal_error("Missing metadata for llvm.preserve.union.access.index intrinsic");
261     CInfo.AccessIndex = getConstant(Call->getArgOperand(1));
262     CInfo.Base = Call->getArgOperand(0);
263     CInfo.RecordAlignment =
264         DL->getABITypeAlign(CInfo.Base->getType()->getPointerElementType());
265     return true;
266   }
267   if (GV->getName().startswith("llvm.preserve.struct.access.index")) {
268     CInfo.Kind = BPFPreserveStructAI;
269     CInfo.Metadata = Call->getMetadata(LLVMContext::MD_preserve_access_index);
270     if (!CInfo.Metadata)
271       report_fatal_error("Missing metadata for llvm.preserve.struct.access.index intrinsic");
272     CInfo.AccessIndex = getConstant(Call->getArgOperand(2));
273     CInfo.Base = Call->getArgOperand(0);
274     CInfo.RecordAlignment =
275         DL->getABITypeAlign(CInfo.Base->getType()->getPointerElementType());
276     return true;
277   }
278   if (GV->getName().startswith("llvm.bpf.preserve.field.info")) {
279     CInfo.Kind = BPFPreserveFieldInfoAI;
280     CInfo.Metadata = nullptr;
281     // Check validity of info_kind as clang did not check this.
282     uint64_t InfoKind = getConstant(Call->getArgOperand(1));
283     if (InfoKind >= BPFCoreSharedInfo::MAX_FIELD_RELOC_KIND)
284       report_fatal_error("Incorrect info_kind for llvm.bpf.preserve.field.info intrinsic");
285     CInfo.AccessIndex = InfoKind;
286     return true;
287   }
288 
289   return false;
290 }
291 
292 void BPFAbstractMemberAccess::replaceWithGEP(std::vector<CallInst *> &CallList,
293                                              uint32_t DimensionIndex,
294                                              uint32_t GEPIndex) {
295   for (auto Call : CallList) {
296     uint32_t Dimension = 1;
297     if (DimensionIndex > 0)
298       Dimension = getConstant(Call->getArgOperand(DimensionIndex));
299 
300     Constant *Zero =
301         ConstantInt::get(Type::getInt32Ty(Call->getParent()->getContext()), 0);
302     SmallVector<Value *, 4> IdxList;
303     for (unsigned I = 0; I < Dimension; ++I)
304       IdxList.push_back(Zero);
305     IdxList.push_back(Call->getArgOperand(GEPIndex));
306 
307     auto *GEP = GetElementPtrInst::CreateInBounds(Call->getArgOperand(0),
308                                                   IdxList, "", Call);
309     Call->replaceAllUsesWith(GEP);
310     Call->eraseFromParent();
311   }
312 }
313 
314 bool BPFAbstractMemberAccess::removePreserveAccessIndexIntrinsic(Module &M) {
315   std::vector<CallInst *> PreserveArrayIndexCalls;
316   std::vector<CallInst *> PreserveUnionIndexCalls;
317   std::vector<CallInst *> PreserveStructIndexCalls;
318   bool Found = false;
319 
320   for (Function &F : M)
321     for (auto &BB : F)
322       for (auto &I : BB) {
323         auto *Call = dyn_cast<CallInst>(&I);
324         CallInfo CInfo;
325         if (!IsPreserveDIAccessIndexCall(Call, CInfo))
326           continue;
327 
328         Found = true;
329         if (CInfo.Kind == BPFPreserveArrayAI)
330           PreserveArrayIndexCalls.push_back(Call);
331         else if (CInfo.Kind == BPFPreserveUnionAI)
332           PreserveUnionIndexCalls.push_back(Call);
333         else
334           PreserveStructIndexCalls.push_back(Call);
335       }
336 
337   // do the following transformation:
338   // . addr = preserve_array_access_index(base, dimension, index)
339   //   is transformed to
340   //     addr = GEP(base, dimenion's zero's, index)
341   // . addr = preserve_union_access_index(base, di_index)
342   //   is transformed to
343   //     addr = base, i.e., all usages of "addr" are replaced by "base".
344   // . addr = preserve_struct_access_index(base, gep_index, di_index)
345   //   is transformed to
346   //     addr = GEP(base, 0, gep_index)
347   replaceWithGEP(PreserveArrayIndexCalls, 1, 2);
348   replaceWithGEP(PreserveStructIndexCalls, 0, 1);
349   for (auto Call : PreserveUnionIndexCalls) {
350     Call->replaceAllUsesWith(Call->getArgOperand(0));
351     Call->eraseFromParent();
352   }
353 
354   return Found;
355 }
356 
357 /// Check whether the access index chain is valid. We check
358 /// here because there may be type casts between two
359 /// access indexes. We want to ensure memory access still valid.
360 bool BPFAbstractMemberAccess::IsValidAIChain(const MDNode *ParentType,
361                                              uint32_t ParentAI,
362                                              const MDNode *ChildType) {
363   if (!ChildType)
364     return true; // preserve_field_info, no type comparison needed.
365 
366   const DIType *PType = stripQualifiers(cast<DIType>(ParentType));
367   const DIType *CType = stripQualifiers(cast<DIType>(ChildType));
368 
369   // Child is a derived/pointer type, which is due to type casting.
370   // Pointer type cannot be in the middle of chain.
371   if (isa<DIDerivedType>(CType))
372     return false;
373 
374   // Parent is a pointer type.
375   if (const auto *PtrTy = dyn_cast<DIDerivedType>(PType)) {
376     if (PtrTy->getTag() != dwarf::DW_TAG_pointer_type)
377       return false;
378     return stripQualifiers(PtrTy->getBaseType()) == CType;
379   }
380 
381   // Otherwise, struct/union/array types
382   const auto *PTy = dyn_cast<DICompositeType>(PType);
383   const auto *CTy = dyn_cast<DICompositeType>(CType);
384   assert(PTy && CTy && "ParentType or ChildType is null or not composite");
385 
386   uint32_t PTyTag = PTy->getTag();
387   assert(PTyTag == dwarf::DW_TAG_array_type ||
388          PTyTag == dwarf::DW_TAG_structure_type ||
389          PTyTag == dwarf::DW_TAG_union_type);
390 
391   uint32_t CTyTag = CTy->getTag();
392   assert(CTyTag == dwarf::DW_TAG_array_type ||
393          CTyTag == dwarf::DW_TAG_structure_type ||
394          CTyTag == dwarf::DW_TAG_union_type);
395 
396   // Multi dimensional arrays, base element should be the same
397   if (PTyTag == dwarf::DW_TAG_array_type && PTyTag == CTyTag)
398     return PTy->getBaseType() == CTy->getBaseType();
399 
400   DIType *Ty;
401   if (PTyTag == dwarf::DW_TAG_array_type)
402     Ty = PTy->getBaseType();
403   else
404     Ty = dyn_cast<DIType>(PTy->getElements()[ParentAI]);
405 
406   return dyn_cast<DICompositeType>(stripQualifiers(Ty)) == CTy;
407 }
408 
409 void BPFAbstractMemberAccess::traceAICall(CallInst *Call,
410                                           CallInfo &ParentInfo) {
411   for (User *U : Call->users()) {
412     Instruction *Inst = dyn_cast<Instruction>(U);
413     if (!Inst)
414       continue;
415 
416     if (auto *BI = dyn_cast<BitCastInst>(Inst)) {
417       traceBitCast(BI, Call, ParentInfo);
418     } else if (auto *CI = dyn_cast<CallInst>(Inst)) {
419       CallInfo ChildInfo;
420 
421       if (IsPreserveDIAccessIndexCall(CI, ChildInfo) &&
422           IsValidAIChain(ParentInfo.Metadata, ParentInfo.AccessIndex,
423                          ChildInfo.Metadata)) {
424         AIChain[CI] = std::make_pair(Call, ParentInfo);
425         traceAICall(CI, ChildInfo);
426       } else {
427         BaseAICalls[Call] = ParentInfo;
428       }
429     } else if (auto *GI = dyn_cast<GetElementPtrInst>(Inst)) {
430       if (GI->hasAllZeroIndices())
431         traceGEP(GI, Call, ParentInfo);
432       else
433         BaseAICalls[Call] = ParentInfo;
434     } else {
435       BaseAICalls[Call] = ParentInfo;
436     }
437   }
438 }
439 
440 void BPFAbstractMemberAccess::traceBitCast(BitCastInst *BitCast,
441                                            CallInst *Parent,
442                                            CallInfo &ParentInfo) {
443   for (User *U : BitCast->users()) {
444     Instruction *Inst = dyn_cast<Instruction>(U);
445     if (!Inst)
446       continue;
447 
448     if (auto *BI = dyn_cast<BitCastInst>(Inst)) {
449       traceBitCast(BI, Parent, ParentInfo);
450     } else if (auto *CI = dyn_cast<CallInst>(Inst)) {
451       CallInfo ChildInfo;
452       if (IsPreserveDIAccessIndexCall(CI, ChildInfo) &&
453           IsValidAIChain(ParentInfo.Metadata, ParentInfo.AccessIndex,
454                          ChildInfo.Metadata)) {
455         AIChain[CI] = std::make_pair(Parent, ParentInfo);
456         traceAICall(CI, ChildInfo);
457       } else {
458         BaseAICalls[Parent] = ParentInfo;
459       }
460     } else if (auto *GI = dyn_cast<GetElementPtrInst>(Inst)) {
461       if (GI->hasAllZeroIndices())
462         traceGEP(GI, Parent, ParentInfo);
463       else
464         BaseAICalls[Parent] = ParentInfo;
465     } else {
466       BaseAICalls[Parent] = ParentInfo;
467     }
468   }
469 }
470 
471 void BPFAbstractMemberAccess::traceGEP(GetElementPtrInst *GEP, CallInst *Parent,
472                                        CallInfo &ParentInfo) {
473   for (User *U : GEP->users()) {
474     Instruction *Inst = dyn_cast<Instruction>(U);
475     if (!Inst)
476       continue;
477 
478     if (auto *BI = dyn_cast<BitCastInst>(Inst)) {
479       traceBitCast(BI, Parent, ParentInfo);
480     } else if (auto *CI = dyn_cast<CallInst>(Inst)) {
481       CallInfo ChildInfo;
482       if (IsPreserveDIAccessIndexCall(CI, ChildInfo) &&
483           IsValidAIChain(ParentInfo.Metadata, ParentInfo.AccessIndex,
484                          ChildInfo.Metadata)) {
485         AIChain[CI] = std::make_pair(Parent, ParentInfo);
486         traceAICall(CI, ChildInfo);
487       } else {
488         BaseAICalls[Parent] = ParentInfo;
489       }
490     } else if (auto *GI = dyn_cast<GetElementPtrInst>(Inst)) {
491       if (GI->hasAllZeroIndices())
492         traceGEP(GI, Parent, ParentInfo);
493       else
494         BaseAICalls[Parent] = ParentInfo;
495     } else {
496       BaseAICalls[Parent] = ParentInfo;
497     }
498   }
499 }
500 
501 void BPFAbstractMemberAccess::collectAICallChains(Module &M, Function &F) {
502   AIChain.clear();
503   BaseAICalls.clear();
504 
505   for (auto &BB : F)
506     for (auto &I : BB) {
507       CallInfo CInfo;
508       auto *Call = dyn_cast<CallInst>(&I);
509       if (!IsPreserveDIAccessIndexCall(Call, CInfo) ||
510           AIChain.find(Call) != AIChain.end())
511         continue;
512 
513       traceAICall(Call, CInfo);
514     }
515 }
516 
517 uint64_t BPFAbstractMemberAccess::getConstant(const Value *IndexValue) {
518   const ConstantInt *CV = dyn_cast<ConstantInt>(IndexValue);
519   assert(CV);
520   return CV->getValue().getZExtValue();
521 }
522 
523 /// Get the start and the end of storage offset for \p MemberTy.
524 void BPFAbstractMemberAccess::GetStorageBitRange(DIDerivedType *MemberTy,
525                                                  Align RecordAlignment,
526                                                  uint32_t &StartBitOffset,
527                                                  uint32_t &EndBitOffset) {
528   uint32_t MemberBitSize = MemberTy->getSizeInBits();
529   uint32_t MemberBitOffset = MemberTy->getOffsetInBits();
530   uint32_t AlignBits = RecordAlignment.value() * 8;
531   if (RecordAlignment > 8 || MemberBitSize > AlignBits)
532     report_fatal_error("Unsupported field expression for llvm.bpf.preserve.field.info, "
533                        "requiring too big alignment");
534 
535   StartBitOffset = MemberBitOffset & ~(AlignBits - 1);
536   if ((StartBitOffset + AlignBits) < (MemberBitOffset + MemberBitSize))
537     report_fatal_error("Unsupported field expression for llvm.bpf.preserve.field.info, "
538                        "cross alignment boundary");
539   EndBitOffset = StartBitOffset + AlignBits;
540 }
541 
542 uint32_t BPFAbstractMemberAccess::GetFieldInfo(uint32_t InfoKind,
543                                                DICompositeType *CTy,
544                                                uint32_t AccessIndex,
545                                                uint32_t PatchImm,
546                                                Align RecordAlignment) {
547   if (InfoKind == BPFCoreSharedInfo::FIELD_EXISTENCE)
548       return 1;
549 
550   uint32_t Tag = CTy->getTag();
551   if (InfoKind == BPFCoreSharedInfo::FIELD_BYTE_OFFSET) {
552     if (Tag == dwarf::DW_TAG_array_type) {
553       auto *EltTy = stripQualifiers(CTy->getBaseType());
554       PatchImm += AccessIndex * calcArraySize(CTy, 1) *
555                   (EltTy->getSizeInBits() >> 3);
556     } else if (Tag == dwarf::DW_TAG_structure_type) {
557       auto *MemberTy = cast<DIDerivedType>(CTy->getElements()[AccessIndex]);
558       if (!MemberTy->isBitField()) {
559         PatchImm += MemberTy->getOffsetInBits() >> 3;
560       } else {
561         unsigned SBitOffset, NextSBitOffset;
562         GetStorageBitRange(MemberTy, RecordAlignment, SBitOffset,
563                            NextSBitOffset);
564         PatchImm += SBitOffset >> 3;
565       }
566     }
567     return PatchImm;
568   }
569 
570   if (InfoKind == BPFCoreSharedInfo::FIELD_BYTE_SIZE) {
571     if (Tag == dwarf::DW_TAG_array_type) {
572       auto *EltTy = stripQualifiers(CTy->getBaseType());
573       return calcArraySize(CTy, 1) * (EltTy->getSizeInBits() >> 3);
574     } else {
575       auto *MemberTy = cast<DIDerivedType>(CTy->getElements()[AccessIndex]);
576       uint32_t SizeInBits = MemberTy->getSizeInBits();
577       if (!MemberTy->isBitField())
578         return SizeInBits >> 3;
579 
580       unsigned SBitOffset, NextSBitOffset;
581       GetStorageBitRange(MemberTy, RecordAlignment, SBitOffset, NextSBitOffset);
582       SizeInBits = NextSBitOffset - SBitOffset;
583       if (SizeInBits & (SizeInBits - 1))
584         report_fatal_error("Unsupported field expression for llvm.bpf.preserve.field.info");
585       return SizeInBits >> 3;
586     }
587   }
588 
589   if (InfoKind == BPFCoreSharedInfo::FIELD_SIGNEDNESS) {
590     const DIType *BaseTy;
591     if (Tag == dwarf::DW_TAG_array_type) {
592       // Signedness only checked when final array elements are accessed.
593       if (CTy->getElements().size() != 1)
594         report_fatal_error("Invalid array expression for llvm.bpf.preserve.field.info");
595       BaseTy = stripQualifiers(CTy->getBaseType());
596     } else {
597       auto *MemberTy = cast<DIDerivedType>(CTy->getElements()[AccessIndex]);
598       BaseTy = stripQualifiers(MemberTy->getBaseType());
599     }
600 
601     // Only basic types and enum types have signedness.
602     const auto *BTy = dyn_cast<DIBasicType>(BaseTy);
603     while (!BTy) {
604       const auto *CompTy = dyn_cast<DICompositeType>(BaseTy);
605       // Report an error if the field expression does not have signedness.
606       if (!CompTy || CompTy->getTag() != dwarf::DW_TAG_enumeration_type)
607         report_fatal_error("Invalid field expression for llvm.bpf.preserve.field.info");
608       BaseTy = stripQualifiers(CompTy->getBaseType());
609       BTy = dyn_cast<DIBasicType>(BaseTy);
610     }
611     uint32_t Encoding = BTy->getEncoding();
612     return (Encoding == dwarf::DW_ATE_signed || Encoding == dwarf::DW_ATE_signed_char);
613   }
614 
615   if (InfoKind == BPFCoreSharedInfo::FIELD_LSHIFT_U64) {
616     // The value is loaded into a value with FIELD_BYTE_SIZE size,
617     // and then zero or sign extended to U64.
618     // FIELD_LSHIFT_U64 and FIELD_RSHIFT_U64 are operations
619     // to extract the original value.
620     const Triple &Triple = TM->getTargetTriple();
621     DIDerivedType *MemberTy = nullptr;
622     bool IsBitField = false;
623     uint32_t SizeInBits;
624 
625     if (Tag == dwarf::DW_TAG_array_type) {
626       auto *EltTy = stripQualifiers(CTy->getBaseType());
627       SizeInBits = calcArraySize(CTy, 1) * EltTy->getSizeInBits();
628     } else {
629       MemberTy = cast<DIDerivedType>(CTy->getElements()[AccessIndex]);
630       SizeInBits = MemberTy->getSizeInBits();
631       IsBitField = MemberTy->isBitField();
632     }
633 
634     if (!IsBitField) {
635       if (SizeInBits > 64)
636         report_fatal_error("too big field size for llvm.bpf.preserve.field.info");
637       return 64 - SizeInBits;
638     }
639 
640     unsigned SBitOffset, NextSBitOffset;
641     GetStorageBitRange(MemberTy, RecordAlignment, SBitOffset, NextSBitOffset);
642     if (NextSBitOffset - SBitOffset > 64)
643       report_fatal_error("too big field size for llvm.bpf.preserve.field.info");
644 
645     unsigned OffsetInBits = MemberTy->getOffsetInBits();
646     if (Triple.getArch() == Triple::bpfel)
647       return SBitOffset + 64 - OffsetInBits - SizeInBits;
648     else
649       return OffsetInBits + 64 - NextSBitOffset;
650   }
651 
652   if (InfoKind == BPFCoreSharedInfo::FIELD_RSHIFT_U64) {
653     DIDerivedType *MemberTy = nullptr;
654     bool IsBitField = false;
655     uint32_t SizeInBits;
656     if (Tag == dwarf::DW_TAG_array_type) {
657       auto *EltTy = stripQualifiers(CTy->getBaseType());
658       SizeInBits = calcArraySize(CTy, 1) * EltTy->getSizeInBits();
659     } else {
660       MemberTy = cast<DIDerivedType>(CTy->getElements()[AccessIndex]);
661       SizeInBits = MemberTy->getSizeInBits();
662       IsBitField = MemberTy->isBitField();
663     }
664 
665     if (!IsBitField) {
666       if (SizeInBits > 64)
667         report_fatal_error("too big field size for llvm.bpf.preserve.field.info");
668       return 64 - SizeInBits;
669     }
670 
671     unsigned SBitOffset, NextSBitOffset;
672     GetStorageBitRange(MemberTy, RecordAlignment, SBitOffset, NextSBitOffset);
673     if (NextSBitOffset - SBitOffset > 64)
674       report_fatal_error("too big field size for llvm.bpf.preserve.field.info");
675 
676     return 64 - SizeInBits;
677   }
678 
679   llvm_unreachable("Unknown llvm.bpf.preserve.field.info info kind");
680 }
681 
682 bool BPFAbstractMemberAccess::HasPreserveFieldInfoCall(CallInfoStack &CallStack) {
683   // This is called in error return path, no need to maintain CallStack.
684   while (CallStack.size()) {
685     auto StackElem = CallStack.top();
686     if (StackElem.second.Kind == BPFPreserveFieldInfoAI)
687       return true;
688     CallStack.pop();
689   }
690   return false;
691 }
692 
693 /// Compute the base of the whole preserve_* intrinsics chains, i.e., the base
694 /// pointer of the first preserve_*_access_index call, and construct the access
695 /// string, which will be the name of a global variable.
696 Value *BPFAbstractMemberAccess::computeBaseAndAccessKey(CallInst *Call,
697                                                         CallInfo &CInfo,
698                                                         std::string &AccessKey,
699                                                         MDNode *&TypeMeta) {
700   Value *Base = nullptr;
701   std::string TypeName;
702   CallInfoStack CallStack;
703 
704   // Put the access chain into a stack with the top as the head of the chain.
705   while (Call) {
706     CallStack.push(std::make_pair(Call, CInfo));
707     CInfo = AIChain[Call].second;
708     Call = AIChain[Call].first;
709   }
710 
711   // The access offset from the base of the head of chain is also
712   // calculated here as all debuginfo types are available.
713 
714   // Get type name and calculate the first index.
715   // We only want to get type name from typedef, structure or union.
716   // If user wants a relocation like
717   //    int *p; ... __builtin_preserve_access_index(&p[4]) ...
718   // or
719   //    int a[10][20]; ... __builtin_preserve_access_index(&a[2][3]) ...
720   // we will skip them.
721   uint32_t FirstIndex = 0;
722   uint32_t PatchImm = 0; // AccessOffset or the requested field info
723   uint32_t InfoKind = BPFCoreSharedInfo::FIELD_BYTE_OFFSET;
724   while (CallStack.size()) {
725     auto StackElem = CallStack.top();
726     Call = StackElem.first;
727     CInfo = StackElem.second;
728 
729     if (!Base)
730       Base = CInfo.Base;
731 
732     DIType *PossibleTypeDef = stripQualifiers(cast<DIType>(CInfo.Metadata),
733                                               false);
734     DIType *Ty = stripQualifiers(PossibleTypeDef);
735     if (CInfo.Kind == BPFPreserveUnionAI ||
736         CInfo.Kind == BPFPreserveStructAI) {
737       // struct or union type. If the typedef is in the metadata, always
738       // use the typedef.
739       TypeName = std::string(PossibleTypeDef->getName());
740       TypeMeta = PossibleTypeDef;
741       PatchImm += FirstIndex * (Ty->getSizeInBits() >> 3);
742       break;
743     }
744 
745     assert(CInfo.Kind == BPFPreserveArrayAI);
746 
747     // Array entries will always be consumed for accumulative initial index.
748     CallStack.pop();
749 
750     // BPFPreserveArrayAI
751     uint64_t AccessIndex = CInfo.AccessIndex;
752 
753     DIType *BaseTy = nullptr;
754     bool CheckElemType = false;
755     if (const auto *CTy = dyn_cast<DICompositeType>(Ty)) {
756       // array type
757       assert(CTy->getTag() == dwarf::DW_TAG_array_type);
758 
759 
760       FirstIndex += AccessIndex * calcArraySize(CTy, 1);
761       BaseTy = stripQualifiers(CTy->getBaseType());
762       CheckElemType = CTy->getElements().size() == 1;
763     } else {
764       // pointer type
765       auto *DTy = cast<DIDerivedType>(Ty);
766       assert(DTy->getTag() == dwarf::DW_TAG_pointer_type);
767 
768       BaseTy = stripQualifiers(DTy->getBaseType());
769       CTy = dyn_cast<DICompositeType>(BaseTy);
770       if (!CTy) {
771         CheckElemType = true;
772       } else if (CTy->getTag() != dwarf::DW_TAG_array_type) {
773         FirstIndex += AccessIndex;
774         CheckElemType = true;
775       } else {
776         FirstIndex += AccessIndex * calcArraySize(CTy, 0);
777       }
778     }
779 
780     if (CheckElemType) {
781       auto *CTy = dyn_cast<DICompositeType>(BaseTy);
782       if (!CTy) {
783         if (HasPreserveFieldInfoCall(CallStack))
784           report_fatal_error("Invalid field access for llvm.preserve.field.info intrinsic");
785         return nullptr;
786       }
787 
788       unsigned CTag = CTy->getTag();
789       if (CTag == dwarf::DW_TAG_structure_type || CTag == dwarf::DW_TAG_union_type) {
790         TypeName = std::string(CTy->getName());
791       } else {
792         if (HasPreserveFieldInfoCall(CallStack))
793           report_fatal_error("Invalid field access for llvm.preserve.field.info intrinsic");
794         return nullptr;
795       }
796       TypeMeta = CTy;
797       PatchImm += FirstIndex * (CTy->getSizeInBits() >> 3);
798       break;
799     }
800   }
801   assert(TypeName.size());
802   AccessKey += std::to_string(FirstIndex);
803 
804   // Traverse the rest of access chain to complete offset calculation
805   // and access key construction.
806   while (CallStack.size()) {
807     auto StackElem = CallStack.top();
808     CInfo = StackElem.second;
809     CallStack.pop();
810 
811     if (CInfo.Kind == BPFPreserveFieldInfoAI) {
812       InfoKind = CInfo.AccessIndex;
813       break;
814     }
815 
816     // If the next Call (the top of the stack) is a BPFPreserveFieldInfoAI,
817     // the action will be extracting field info.
818     if (CallStack.size()) {
819       auto StackElem2 = CallStack.top();
820       CallInfo CInfo2 = StackElem2.second;
821       if (CInfo2.Kind == BPFPreserveFieldInfoAI) {
822         InfoKind = CInfo2.AccessIndex;
823         assert(CallStack.size() == 1);
824       }
825     }
826 
827     // Access Index
828     uint64_t AccessIndex = CInfo.AccessIndex;
829     AccessKey += ":" + std::to_string(AccessIndex);
830 
831     MDNode *MDN = CInfo.Metadata;
832     // At this stage, it cannot be pointer type.
833     auto *CTy = cast<DICompositeType>(stripQualifiers(cast<DIType>(MDN)));
834     PatchImm = GetFieldInfo(InfoKind, CTy, AccessIndex, PatchImm,
835                             CInfo.RecordAlignment);
836   }
837 
838   // Access key is the
839   //   "llvm." + type name + ":" + reloc type + ":" + patched imm + "$" +
840   //   access string,
841   // uniquely identifying one relocation.
842   // The prefix "llvm." indicates this is a temporary global, which should
843   // not be emitted to ELF file.
844   AccessKey = "llvm." + TypeName + ":" + std::to_string(InfoKind) + ":" +
845               std::to_string(PatchImm) + "$" + AccessKey;
846 
847   return Base;
848 }
849 
850 /// Call/Kind is the base preserve_*_access_index() call. Attempts to do
851 /// transformation to a chain of relocable GEPs.
852 bool BPFAbstractMemberAccess::transformGEPChain(Module &M, CallInst *Call,
853                                                 CallInfo &CInfo) {
854   std::string AccessKey;
855   MDNode *TypeMeta;
856   Value *Base =
857       computeBaseAndAccessKey(Call, CInfo, AccessKey, TypeMeta);
858   if (!Base)
859     return false;
860 
861   BasicBlock *BB = Call->getParent();
862   GlobalVariable *GV;
863 
864   if (GEPGlobals.find(AccessKey) == GEPGlobals.end()) {
865     IntegerType *VarType;
866     if (CInfo.Kind == BPFPreserveFieldInfoAI)
867       VarType = Type::getInt32Ty(BB->getContext()); // 32bit return value
868     else
869       VarType = Type::getInt64Ty(BB->getContext()); // 64bit ptr arith
870 
871     GV = new GlobalVariable(M, VarType, false, GlobalVariable::ExternalLinkage,
872                             NULL, AccessKey);
873     GV->addAttribute(BPFCoreSharedInfo::AmaAttr);
874     GV->setMetadata(LLVMContext::MD_preserve_access_index, TypeMeta);
875     GEPGlobals[AccessKey] = GV;
876   } else {
877     GV = GEPGlobals[AccessKey];
878   }
879 
880   if (CInfo.Kind == BPFPreserveFieldInfoAI) {
881     // Load the global variable which represents the returned field info.
882     auto *LDInst = new LoadInst(Type::getInt32Ty(BB->getContext()), GV, "",
883                                 Call);
884     Call->replaceAllUsesWith(LDInst);
885     Call->eraseFromParent();
886     return true;
887   }
888 
889   // For any original GEP Call and Base %2 like
890   //   %4 = bitcast %struct.net_device** %dev1 to i64*
891   // it is transformed to:
892   //   %6 = load sk_buff:50:$0:0:0:2:0
893   //   %7 = bitcast %struct.sk_buff* %2 to i8*
894   //   %8 = getelementptr i8, i8* %7, %6
895   //   %9 = bitcast i8* %8 to i64*
896   //   using %9 instead of %4
897   // The original Call inst is removed.
898 
899   // Load the global variable.
900   auto *LDInst = new LoadInst(Type::getInt64Ty(BB->getContext()), GV, "", Call);
901 
902   // Generate a BitCast
903   auto *BCInst = new BitCastInst(Base, Type::getInt8PtrTy(BB->getContext()));
904   BB->getInstList().insert(Call->getIterator(), BCInst);
905 
906   // Generate a GetElementPtr
907   auto *GEP = GetElementPtrInst::Create(Type::getInt8Ty(BB->getContext()),
908                                         BCInst, LDInst);
909   BB->getInstList().insert(Call->getIterator(), GEP);
910 
911   // Generate a BitCast
912   auto *BCInst2 = new BitCastInst(GEP, Call->getType());
913   BB->getInstList().insert(Call->getIterator(), BCInst2);
914 
915   Call->replaceAllUsesWith(BCInst2);
916   Call->eraseFromParent();
917 
918   return true;
919 }
920 
921 bool BPFAbstractMemberAccess::doTransformation(Module &M) {
922   bool Transformed = false;
923 
924   for (Function &F : M) {
925     // Collect PreserveDIAccessIndex Intrinsic call chains.
926     // The call chains will be used to generate the access
927     // patterns similar to GEP.
928     collectAICallChains(M, F);
929 
930     for (auto &C : BaseAICalls)
931       Transformed = transformGEPChain(M, C.first, C.second) || Transformed;
932   }
933 
934   return removePreserveAccessIndexIntrinsic(M) || Transformed;
935 }
936