xref: /freebsd/contrib/llvm-project/llvm/lib/Frontend/Offloading/OffloadWrapper.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1  //===- OffloadWrapper.cpp ---------------------------------------*- C++ -*-===//
2  //
3  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4  // See https://llvm.org/LICENSE.txt for license information.
5  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6  //
7  //===----------------------------------------------------------------------===//
8  
#include "llvm/Frontend/Offloading/OffloadWrapper.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/BinaryFormat/Magic.h"
#include "llvm/Frontend/Offloading/Utility.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Object/OffloadBinary.h"
#include "llvm/Support/Error.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include <cstdint>
#include <cstring>

using namespace llvm;
using namespace llvm::offloading;
25  
26  namespace {
/// Magic numbers stored in the first field of the fatbinary wrapper struct
/// that createFatbinDesc emits into the fatbinary segment section: one for
/// CUDA images and one for HIP images. The HIP value's bytes spell "HIPF" in
/// ASCII.
constexpr unsigned CudaFatMagic{0x466243b1};
constexpr unsigned HIPFatMagic{0x48495046};
30  
getSizeTTy(Module & M)31  IntegerType *getSizeTTy(Module &M) {
32    return M.getDataLayout().getIntPtrType(M.getContext());
33  }
34  
35  // struct __tgt_device_image {
36  //   void *ImageStart;
37  //   void *ImageEnd;
38  //   __tgt_offload_entry *EntriesBegin;
39  //   __tgt_offload_entry *EntriesEnd;
40  // };
getDeviceImageTy(Module & M)41  StructType *getDeviceImageTy(Module &M) {
42    LLVMContext &C = M.getContext();
43    StructType *ImageTy = StructType::getTypeByName(C, "__tgt_device_image");
44    if (!ImageTy)
45      ImageTy =
46          StructType::create("__tgt_device_image", PointerType::getUnqual(C),
47                             PointerType::getUnqual(C), PointerType::getUnqual(C),
48                             PointerType::getUnqual(C));
49    return ImageTy;
50  }
51  
getDeviceImagePtrTy(Module & M)52  PointerType *getDeviceImagePtrTy(Module &M) {
53    return PointerType::getUnqual(getDeviceImageTy(M));
54  }
55  
56  // struct __tgt_bin_desc {
57  //   int32_t NumDeviceImages;
58  //   __tgt_device_image *DeviceImages;
59  //   __tgt_offload_entry *HostEntriesBegin;
60  //   __tgt_offload_entry *HostEntriesEnd;
61  // };
getBinDescTy(Module & M)62  StructType *getBinDescTy(Module &M) {
63    LLVMContext &C = M.getContext();
64    StructType *DescTy = StructType::getTypeByName(C, "__tgt_bin_desc");
65    if (!DescTy)
66      DescTy = StructType::create(
67          "__tgt_bin_desc", Type::getInt32Ty(C), getDeviceImagePtrTy(M),
68          PointerType::getUnqual(C), PointerType::getUnqual(C));
69    return DescTy;
70  }
71  
getBinDescPtrTy(Module & M)72  PointerType *getBinDescPtrTy(Module &M) {
73    return PointerType::getUnqual(getBinDescTy(M));
74  }
75  
76  /// Creates binary descriptor for the given device images. Binary descriptor
77  /// is an object that is passed to the offloading runtime at program startup
78  /// and it describes all device images available in the executable or shared
79  /// library. It is defined as follows
80  ///
81  /// __attribute__((visibility("hidden")))
82  /// extern __tgt_offload_entry *__start_omp_offloading_entries;
83  /// __attribute__((visibility("hidden")))
84  /// extern __tgt_offload_entry *__stop_omp_offloading_entries;
85  ///
86  /// static const char Image0[] = { <Bufs.front() contents> };
87  ///  ...
88  /// static const char ImageN[] = { <Bufs.back() contents> };
89  ///
90  /// static const __tgt_device_image Images[] = {
91  ///   {
92  ///     Image0,                            /*ImageStart*/
93  ///     Image0 + sizeof(Image0),           /*ImageEnd*/
94  ///     __start_omp_offloading_entries,    /*EntriesBegin*/
95  ///     __stop_omp_offloading_entries      /*EntriesEnd*/
96  ///   },
97  ///   ...
98  ///   {
99  ///     ImageN,                            /*ImageStart*/
100  ///     ImageN + sizeof(ImageN),           /*ImageEnd*/
101  ///     __start_omp_offloading_entries,    /*EntriesBegin*/
102  ///     __stop_omp_offloading_entries      /*EntriesEnd*/
103  ///   }
104  /// };
105  ///
106  /// static const __tgt_bin_desc BinDesc = {
107  ///   sizeof(Images) / sizeof(Images[0]),  /*NumDeviceImages*/
108  ///   Images,                              /*DeviceImages*/
109  ///   __start_omp_offloading_entries,      /*HostEntriesBegin*/
110  ///   __stop_omp_offloading_entries        /*HostEntriesEnd*/
111  /// };
112  ///
113  /// Global variable that represents BinDesc is returned.
createBinDesc(Module & M,ArrayRef<ArrayRef<char>> Bufs,EntryArrayTy EntryArray,StringRef Suffix,bool Relocatable)114  GlobalVariable *createBinDesc(Module &M, ArrayRef<ArrayRef<char>> Bufs,
115                                EntryArrayTy EntryArray, StringRef Suffix,
116                                bool Relocatable) {
117    LLVMContext &C = M.getContext();
118    auto [EntriesB, EntriesE] = EntryArray;
119  
120    auto *Zero = ConstantInt::get(getSizeTTy(M), 0u);
121    Constant *ZeroZero[] = {Zero, Zero};
122  
123    // Create initializer for the images array.
124    SmallVector<Constant *, 4u> ImagesInits;
125    ImagesInits.reserve(Bufs.size());
126    for (ArrayRef<char> Buf : Bufs) {
127      // We embed the full offloading entry so the binary utilities can parse it.
128      auto *Data = ConstantDataArray::get(C, Buf);
129      auto *Image = new GlobalVariable(M, Data->getType(), /*isConstant=*/true,
130                                       GlobalVariable::InternalLinkage, Data,
131                                       ".omp_offloading.device_image" + Suffix);
132      Image->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
133      Image->setSection(Relocatable ? ".llvm.offloading.relocatable"
134                                    : ".llvm.offloading");
135      Image->setAlignment(Align(object::OffloadBinary::getAlignment()));
136  
137      StringRef Binary(Buf.data(), Buf.size());
138      assert(identify_magic(Binary) == file_magic::offload_binary &&
139             "Invalid binary format");
140  
141      // The device image struct contains the pointer to the beginning and end of
142      // the image stored inside of the offload binary. There should only be one
143      // of these for each buffer so we parse it out manually.
144      const auto *Header =
145          reinterpret_cast<const object::OffloadBinary::Header *>(
146              Binary.bytes_begin());
147      const auto *Entry = reinterpret_cast<const object::OffloadBinary::Entry *>(
148          Binary.bytes_begin() + Header->EntryOffset);
149  
150      auto *Begin = ConstantInt::get(getSizeTTy(M), Entry->ImageOffset);
151      auto *Size =
152          ConstantInt::get(getSizeTTy(M), Entry->ImageOffset + Entry->ImageSize);
153      Constant *ZeroBegin[] = {Zero, Begin};
154      Constant *ZeroSize[] = {Zero, Size};
155  
156      auto *ImageB =
157          ConstantExpr::getGetElementPtr(Image->getValueType(), Image, ZeroBegin);
158      auto *ImageE =
159          ConstantExpr::getGetElementPtr(Image->getValueType(), Image, ZeroSize);
160  
161      ImagesInits.push_back(ConstantStruct::get(getDeviceImageTy(M), ImageB,
162                                                ImageE, EntriesB, EntriesE));
163    }
164  
165    // Then create images array.
166    auto *ImagesData = ConstantArray::get(
167        ArrayType::get(getDeviceImageTy(M), ImagesInits.size()), ImagesInits);
168  
169    auto *Images =
170        new GlobalVariable(M, ImagesData->getType(), /*isConstant*/ true,
171                           GlobalValue::InternalLinkage, ImagesData,
172                           ".omp_offloading.device_images" + Suffix);
173    Images->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
174  
175    auto *ImagesB =
176        ConstantExpr::getGetElementPtr(Images->getValueType(), Images, ZeroZero);
177  
178    // And finally create the binary descriptor object.
179    auto *DescInit = ConstantStruct::get(
180        getBinDescTy(M),
181        ConstantInt::get(Type::getInt32Ty(C), ImagesInits.size()), ImagesB,
182        EntriesB, EntriesE);
183  
184    return new GlobalVariable(M, DescInit->getType(), /*isConstant*/ true,
185                              GlobalValue::InternalLinkage, DescInit,
186                              ".omp_offloading.descriptor" + Suffix);
187  }
188  
createUnregisterFunction(Module & M,GlobalVariable * BinDesc,StringRef Suffix)189  Function *createUnregisterFunction(Module &M, GlobalVariable *BinDesc,
190                                     StringRef Suffix) {
191    LLVMContext &C = M.getContext();
192    auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
193    auto *Func =
194        Function::Create(FuncTy, GlobalValue::InternalLinkage,
195                         ".omp_offloading.descriptor_unreg" + Suffix, &M);
196    Func->setSection(".text.startup");
197  
198    // Get __tgt_unregister_lib function declaration.
199    auto *UnRegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(M),
200                                          /*isVarArg*/ false);
201    FunctionCallee UnRegFuncC =
202        M.getOrInsertFunction("__tgt_unregister_lib", UnRegFuncTy);
203  
204    // Construct function body
205    IRBuilder<> Builder(BasicBlock::Create(C, "entry", Func));
206    Builder.CreateCall(UnRegFuncC, BinDesc);
207    Builder.CreateRetVoid();
208  
209    return Func;
210  }
211  
createRegisterFunction(Module & M,GlobalVariable * BinDesc,StringRef Suffix)212  void createRegisterFunction(Module &M, GlobalVariable *BinDesc,
213                              StringRef Suffix) {
214    LLVMContext &C = M.getContext();
215    auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
216    auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage,
217                                  ".omp_offloading.descriptor_reg" + Suffix, &M);
218    Func->setSection(".text.startup");
219  
220    // Get __tgt_register_lib function declaration.
221    auto *RegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(M),
222                                        /*isVarArg*/ false);
223    FunctionCallee RegFuncC =
224        M.getOrInsertFunction("__tgt_register_lib", RegFuncTy);
225  
226    auto *AtExitTy = FunctionType::get(
227        Type::getInt32Ty(C), PointerType::getUnqual(C), /*isVarArg=*/false);
228    FunctionCallee AtExit = M.getOrInsertFunction("atexit", AtExitTy);
229  
230    Function *UnregFunc = createUnregisterFunction(M, BinDesc, Suffix);
231  
232    // Construct function body
233    IRBuilder<> Builder(BasicBlock::Create(C, "entry", Func));
234  
235    Builder.CreateCall(RegFuncC, BinDesc);
236  
237    // Register the destructors with 'atexit'. This is expected by the CUDA
238    // runtime and ensures that we clean up before dynamic objects are destroyed.
239    // This needs to be done after plugin initialization to ensure that it is
240    // called before the plugin runtime is destroyed.
241    Builder.CreateCall(AtExit, UnregFunc);
242    Builder.CreateRetVoid();
243  
244    // Add this function to constructors.
245    appendToGlobalCtors(M, Func, /*Priority=*/101);
246  }
247  
248  // struct fatbin_wrapper {
249  //  int32_t magic;
250  //  int32_t version;
251  //  void *image;
252  //  void *reserved;
253  //};
getFatbinWrapperTy(Module & M)254  StructType *getFatbinWrapperTy(Module &M) {
255    LLVMContext &C = M.getContext();
256    StructType *FatbinTy = StructType::getTypeByName(C, "fatbin_wrapper");
257    if (!FatbinTy)
258      FatbinTy = StructType::create(
259          "fatbin_wrapper", Type::getInt32Ty(C), Type::getInt32Ty(C),
260          PointerType::getUnqual(C), PointerType::getUnqual(C));
261    return FatbinTy;
262  }
263  
264  /// Embed the image \p Image into the module \p M so it can be found by the
265  /// runtime.
createFatbinDesc(Module & M,ArrayRef<char> Image,bool IsHIP,StringRef Suffix)266  GlobalVariable *createFatbinDesc(Module &M, ArrayRef<char> Image, bool IsHIP,
267                                   StringRef Suffix) {
268    LLVMContext &C = M.getContext();
269    llvm::Type *Int8PtrTy = PointerType::getUnqual(C);
270    llvm::Triple Triple = llvm::Triple(M.getTargetTriple());
271  
272    // Create the global string containing the fatbinary.
273    StringRef FatbinConstantSection =
274        IsHIP ? ".hip_fatbin"
275              : (Triple.isMacOSX() ? "__NV_CUDA,__nv_fatbin" : ".nv_fatbin");
276    auto *Data = ConstantDataArray::get(C, Image);
277    auto *Fatbin = new GlobalVariable(M, Data->getType(), /*isConstant*/ true,
278                                      GlobalVariable::InternalLinkage, Data,
279                                      ".fatbin_image" + Suffix);
280    Fatbin->setSection(FatbinConstantSection);
281  
282    // Create the fatbinary wrapper
283    StringRef FatbinWrapperSection = IsHIP               ? ".hipFatBinSegment"
284                                     : Triple.isMacOSX() ? "__NV_CUDA,__fatbin"
285                                                         : ".nvFatBinSegment";
286    Constant *FatbinWrapper[] = {
287        ConstantInt::get(Type::getInt32Ty(C), IsHIP ? HIPFatMagic : CudaFatMagic),
288        ConstantInt::get(Type::getInt32Ty(C), 1),
289        ConstantExpr::getPointerBitCastOrAddrSpaceCast(Fatbin, Int8PtrTy),
290        ConstantPointerNull::get(PointerType::getUnqual(C))};
291  
292    Constant *FatbinInitializer =
293        ConstantStruct::get(getFatbinWrapperTy(M), FatbinWrapper);
294  
295    auto *FatbinDesc =
296        new GlobalVariable(M, getFatbinWrapperTy(M),
297                           /*isConstant*/ true, GlobalValue::InternalLinkage,
298                           FatbinInitializer, ".fatbin_wrapper" + Suffix);
299    FatbinDesc->setSection(FatbinWrapperSection);
300    FatbinDesc->setAlignment(Align(8));
301  
302    return FatbinDesc;
303  }
304  
305  /// Create the register globals function. We will iterate all of the offloading
306  /// entries stored at the begin / end symbols and register them according to
307  /// their type. This creates the following function in IR:
308  ///
309  /// extern struct __tgt_offload_entry __start_cuda_offloading_entries;
310  /// extern struct __tgt_offload_entry __stop_cuda_offloading_entries;
311  ///
312  /// extern void __cudaRegisterFunction(void **, void *, void *, void *, int,
313  ///                                    void *, void *, void *, void *, int *);
314  /// extern void __cudaRegisterVar(void **, void *, void *, void *, int32_t,
315  ///                               int64_t, int32_t, int32_t);
316  ///
317  /// void __cudaRegisterTest(void **fatbinHandle) {
318  ///   for (struct __tgt_offload_entry *entry = &__start_cuda_offloading_entries;
319  ///        entry != &__stop_cuda_offloading_entries; ++entry) {
320  ///     if (!entry->size)
321  ///       __cudaRegisterFunction(fatbinHandle, entry->addr, entry->name,
322  ///                              entry->name, -1, 0, 0, 0, 0, 0);
323  ///     else
324  ///       __cudaRegisterVar(fatbinHandle, entry->addr, entry->name, entry->name,
325  ///                         0, entry->size, 0, 0);
326  ///   }
327  /// }
createRegisterGlobalsFunction(Module & M,bool IsHIP,EntryArrayTy EntryArray,StringRef Suffix,bool EmitSurfacesAndTextures)328  Function *createRegisterGlobalsFunction(Module &M, bool IsHIP,
329                                          EntryArrayTy EntryArray,
330                                          StringRef Suffix,
331                                          bool EmitSurfacesAndTextures) {
332    LLVMContext &C = M.getContext();
333    auto [EntriesB, EntriesE] = EntryArray;
334  
335    // Get the __cudaRegisterFunction function declaration.
336    PointerType *Int8PtrTy = PointerType::get(C, 0);
337    PointerType *Int8PtrPtrTy = PointerType::get(C, 0);
338    PointerType *Int32PtrTy = PointerType::get(C, 0);
339    auto *RegFuncTy = FunctionType::get(
340        Type::getInt32Ty(C),
341        {Int8PtrPtrTy, Int8PtrTy, Int8PtrTy, Int8PtrTy, Type::getInt32Ty(C),
342         Int8PtrTy, Int8PtrTy, Int8PtrTy, Int8PtrTy, Int32PtrTy},
343        /*isVarArg*/ false);
344    FunctionCallee RegFunc = M.getOrInsertFunction(
345        IsHIP ? "__hipRegisterFunction" : "__cudaRegisterFunction", RegFuncTy);
346  
347    // Get the __cudaRegisterVar function declaration.
348    auto *RegVarTy = FunctionType::get(
349        Type::getVoidTy(C),
350        {Int8PtrPtrTy, Int8PtrTy, Int8PtrTy, Int8PtrTy, Type::getInt32Ty(C),
351         getSizeTTy(M), Type::getInt32Ty(C), Type::getInt32Ty(C)},
352        /*isVarArg*/ false);
353    FunctionCallee RegVar = M.getOrInsertFunction(
354        IsHIP ? "__hipRegisterVar" : "__cudaRegisterVar", RegVarTy);
355  
356    // Get the __cudaRegisterSurface function declaration.
357    FunctionType *RegSurfaceTy =
358        FunctionType::get(Type::getVoidTy(C),
359                          {Int8PtrPtrTy, Int8PtrTy, Int8PtrTy, Int8PtrTy,
360                           Type::getInt32Ty(C), Type::getInt32Ty(C)},
361                          /*isVarArg=*/false);
362    FunctionCallee RegSurface = M.getOrInsertFunction(
363        IsHIP ? "__hipRegisterSurface" : "__cudaRegisterSurface", RegSurfaceTy);
364  
365    // Get the __cudaRegisterTexture function declaration.
366    FunctionType *RegTextureTy = FunctionType::get(
367        Type::getVoidTy(C),
368        {Int8PtrPtrTy, Int8PtrTy, Int8PtrTy, Int8PtrTy, Type::getInt32Ty(C),
369         Type::getInt32Ty(C), Type::getInt32Ty(C)},
370        /*isVarArg=*/false);
371    FunctionCallee RegTexture = M.getOrInsertFunction(
372        IsHIP ? "__hipRegisterTexture" : "__cudaRegisterTexture", RegTextureTy);
373  
374    auto *RegGlobalsTy = FunctionType::get(Type::getVoidTy(C), Int8PtrPtrTy,
375                                           /*isVarArg*/ false);
376    auto *RegGlobalsFn =
377        Function::Create(RegGlobalsTy, GlobalValue::InternalLinkage,
378                         IsHIP ? ".hip.globals_reg" : ".cuda.globals_reg", &M);
379    RegGlobalsFn->setSection(".text.startup");
380  
381    // Create the loop to register all the entries.
382    IRBuilder<> Builder(BasicBlock::Create(C, "entry", RegGlobalsFn));
383    auto *EntryBB = BasicBlock::Create(C, "while.entry", RegGlobalsFn);
384    auto *IfThenBB = BasicBlock::Create(C, "if.then", RegGlobalsFn);
385    auto *IfElseBB = BasicBlock::Create(C, "if.else", RegGlobalsFn);
386    auto *SwGlobalBB = BasicBlock::Create(C, "sw.global", RegGlobalsFn);
387    auto *SwManagedBB = BasicBlock::Create(C, "sw.managed", RegGlobalsFn);
388    auto *SwSurfaceBB = BasicBlock::Create(C, "sw.surface", RegGlobalsFn);
389    auto *SwTextureBB = BasicBlock::Create(C, "sw.texture", RegGlobalsFn);
390    auto *IfEndBB = BasicBlock::Create(C, "if.end", RegGlobalsFn);
391    auto *ExitBB = BasicBlock::Create(C, "while.end", RegGlobalsFn);
392  
393    auto *EntryCmp = Builder.CreateICmpNE(EntriesB, EntriesE);
394    Builder.CreateCondBr(EntryCmp, EntryBB, ExitBB);
395    Builder.SetInsertPoint(EntryBB);
396    auto *Entry = Builder.CreatePHI(PointerType::getUnqual(C), 2, "entry");
397    auto *AddrPtr =
398        Builder.CreateInBoundsGEP(offloading::getEntryTy(M), Entry,
399                                  {ConstantInt::get(getSizeTTy(M), 0),
400                                   ConstantInt::get(Type::getInt32Ty(C), 0)});
401    auto *Addr = Builder.CreateLoad(Int8PtrTy, AddrPtr, "addr");
402    auto *NamePtr =
403        Builder.CreateInBoundsGEP(offloading::getEntryTy(M), Entry,
404                                  {ConstantInt::get(getSizeTTy(M), 0),
405                                   ConstantInt::get(Type::getInt32Ty(C), 1)});
406    auto *Name = Builder.CreateLoad(Int8PtrTy, NamePtr, "name");
407    auto *SizePtr =
408        Builder.CreateInBoundsGEP(offloading::getEntryTy(M), Entry,
409                                  {ConstantInt::get(getSizeTTy(M), 0),
410                                   ConstantInt::get(Type::getInt32Ty(C), 2)});
411    auto *Size = Builder.CreateLoad(getSizeTTy(M), SizePtr, "size");
412    auto *FlagsPtr =
413        Builder.CreateInBoundsGEP(offloading::getEntryTy(M), Entry,
414                                  {ConstantInt::get(getSizeTTy(M), 0),
415                                   ConstantInt::get(Type::getInt32Ty(C), 3)});
416    auto *Flags = Builder.CreateLoad(Type::getInt32Ty(C), FlagsPtr, "flags");
417    auto *DataPtr =
418        Builder.CreateInBoundsGEP(offloading::getEntryTy(M), Entry,
419                                  {ConstantInt::get(getSizeTTy(M), 0),
420                                   ConstantInt::get(Type::getInt32Ty(C), 4)});
421    auto *Data = Builder.CreateLoad(Type::getInt32Ty(C), DataPtr, "textype");
422    auto *Kind = Builder.CreateAnd(
423        Flags, ConstantInt::get(Type::getInt32Ty(C), 0x7), "type");
424  
425    // Extract the flags stored in the bit-field and convert them to C booleans.
426    auto *ExternBit = Builder.CreateAnd(
427        Flags, ConstantInt::get(Type::getInt32Ty(C),
428                                llvm::offloading::OffloadGlobalExtern));
429    auto *Extern = Builder.CreateLShr(
430        ExternBit, ConstantInt::get(Type::getInt32Ty(C), 3), "extern");
431    auto *ConstantBit = Builder.CreateAnd(
432        Flags, ConstantInt::get(Type::getInt32Ty(C),
433                                llvm::offloading::OffloadGlobalConstant));
434    auto *Const = Builder.CreateLShr(
435        ConstantBit, ConstantInt::get(Type::getInt32Ty(C), 4), "constant");
436    auto *NormalizedBit = Builder.CreateAnd(
437        Flags, ConstantInt::get(Type::getInt32Ty(C),
438                                llvm::offloading::OffloadGlobalNormalized));
439    auto *Normalized = Builder.CreateLShr(
440        NormalizedBit, ConstantInt::get(Type::getInt32Ty(C), 5), "normalized");
441    auto *FnCond =
442        Builder.CreateICmpEQ(Size, ConstantInt::getNullValue(getSizeTTy(M)));
443    Builder.CreateCondBr(FnCond, IfThenBB, IfElseBB);
444  
445    // Create kernel registration code.
446    Builder.SetInsertPoint(IfThenBB);
447    Builder.CreateCall(RegFunc, {RegGlobalsFn->arg_begin(), Addr, Name, Name,
448                                 ConstantInt::get(Type::getInt32Ty(C), -1),
449                                 ConstantPointerNull::get(Int8PtrTy),
450                                 ConstantPointerNull::get(Int8PtrTy),
451                                 ConstantPointerNull::get(Int8PtrTy),
452                                 ConstantPointerNull::get(Int8PtrTy),
453                                 ConstantPointerNull::get(Int32PtrTy)});
454    Builder.CreateBr(IfEndBB);
455    Builder.SetInsertPoint(IfElseBB);
456  
457    auto *Switch = Builder.CreateSwitch(Kind, IfEndBB);
458    // Create global variable registration code.
459    Builder.SetInsertPoint(SwGlobalBB);
460    Builder.CreateCall(RegVar,
461                       {RegGlobalsFn->arg_begin(), Addr, Name, Name, Extern, Size,
462                        Const, ConstantInt::get(Type::getInt32Ty(C), 0)});
463    Builder.CreateBr(IfEndBB);
464    Switch->addCase(Builder.getInt32(llvm::offloading::OffloadGlobalEntry),
465                    SwGlobalBB);
466  
467    // Create managed variable registration code.
468    Builder.SetInsertPoint(SwManagedBB);
469    Builder.CreateBr(IfEndBB);
470    Switch->addCase(Builder.getInt32(llvm::offloading::OffloadGlobalManagedEntry),
471                    SwManagedBB);
472    // Create surface variable registration code.
473    Builder.SetInsertPoint(SwSurfaceBB);
474    if (EmitSurfacesAndTextures)
475      Builder.CreateCall(RegSurface, {RegGlobalsFn->arg_begin(), Addr, Name, Name,
476                                      Data, Extern});
477    Builder.CreateBr(IfEndBB);
478    Switch->addCase(Builder.getInt32(llvm::offloading::OffloadGlobalSurfaceEntry),
479                    SwSurfaceBB);
480  
481    // Create texture variable registration code.
482    Builder.SetInsertPoint(SwTextureBB);
483    if (EmitSurfacesAndTextures)
484      Builder.CreateCall(RegTexture, {RegGlobalsFn->arg_begin(), Addr, Name, Name,
485                                      Data, Normalized, Extern});
486    Builder.CreateBr(IfEndBB);
487    Switch->addCase(Builder.getInt32(llvm::offloading::OffloadGlobalTextureEntry),
488                    SwTextureBB);
489  
490    Builder.SetInsertPoint(IfEndBB);
491    auto *NewEntry = Builder.CreateInBoundsGEP(
492        offloading::getEntryTy(M), Entry, ConstantInt::get(getSizeTTy(M), 1));
493    auto *Cmp = Builder.CreateICmpEQ(
494        NewEntry,
495        ConstantExpr::getInBoundsGetElementPtr(
496            ArrayType::get(offloading::getEntryTy(M), 0), EntriesE,
497            ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0),
498                                  ConstantInt::get(getSizeTTy(M), 0)})));
499    Entry->addIncoming(
500        ConstantExpr::getInBoundsGetElementPtr(
501            ArrayType::get(offloading::getEntryTy(M), 0), EntriesB,
502            ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0),
503                                  ConstantInt::get(getSizeTTy(M), 0)})),
504        &RegGlobalsFn->getEntryBlock());
505    Entry->addIncoming(NewEntry, IfEndBB);
506    Builder.CreateCondBr(Cmp, ExitBB, EntryBB);
507    Builder.SetInsertPoint(ExitBB);
508    Builder.CreateRetVoid();
509  
510    return RegGlobalsFn;
511  }
512  
513  // Create the constructor and destructor to register the fatbinary with the CUDA
514  // runtime.
createRegisterFatbinFunction(Module & M,GlobalVariable * FatbinDesc,bool IsHIP,EntryArrayTy EntryArray,StringRef Suffix,bool EmitSurfacesAndTextures)515  void createRegisterFatbinFunction(Module &M, GlobalVariable *FatbinDesc,
516                                    bool IsHIP, EntryArrayTy EntryArray,
517                                    StringRef Suffix,
518                                    bool EmitSurfacesAndTextures) {
519    LLVMContext &C = M.getContext();
520    auto *CtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
521    auto *CtorFunc = Function::Create(
522        CtorFuncTy, GlobalValue::InternalLinkage,
523        (IsHIP ? ".hip.fatbin_reg" : ".cuda.fatbin_reg") + Suffix, &M);
524    CtorFunc->setSection(".text.startup");
525  
526    auto *DtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
527    auto *DtorFunc = Function::Create(
528        DtorFuncTy, GlobalValue::InternalLinkage,
529        (IsHIP ? ".hip.fatbin_unreg" : ".cuda.fatbin_unreg") + Suffix, &M);
530    DtorFunc->setSection(".text.startup");
531  
532    auto *PtrTy = PointerType::getUnqual(C);
533  
534    // Get the __cudaRegisterFatBinary function declaration.
535    auto *RegFatTy = FunctionType::get(PtrTy, PtrTy, /*isVarArg=*/false);
536    FunctionCallee RegFatbin = M.getOrInsertFunction(
537        IsHIP ? "__hipRegisterFatBinary" : "__cudaRegisterFatBinary", RegFatTy);
538    // Get the __cudaRegisterFatBinaryEnd function declaration.
539    auto *RegFatEndTy =
540        FunctionType::get(Type::getVoidTy(C), PtrTy, /*isVarArg=*/false);
541    FunctionCallee RegFatbinEnd =
542        M.getOrInsertFunction("__cudaRegisterFatBinaryEnd", RegFatEndTy);
543    // Get the __cudaUnregisterFatBinary function declaration.
544    auto *UnregFatTy =
545        FunctionType::get(Type::getVoidTy(C), PtrTy, /*isVarArg=*/false);
546    FunctionCallee UnregFatbin = M.getOrInsertFunction(
547        IsHIP ? "__hipUnregisterFatBinary" : "__cudaUnregisterFatBinary",
548        UnregFatTy);
549  
550    auto *AtExitTy =
551        FunctionType::get(Type::getInt32Ty(C), PtrTy, /*isVarArg=*/false);
552    FunctionCallee AtExit = M.getOrInsertFunction("atexit", AtExitTy);
553  
554    auto *BinaryHandleGlobal = new llvm::GlobalVariable(
555        M, PtrTy, false, llvm::GlobalValue::InternalLinkage,
556        llvm::ConstantPointerNull::get(PtrTy),
557        (IsHIP ? ".hip.binary_handle" : ".cuda.binary_handle") + Suffix);
558  
559    // Create the constructor to register this image with the runtime.
560    IRBuilder<> CtorBuilder(BasicBlock::Create(C, "entry", CtorFunc));
561    CallInst *Handle = CtorBuilder.CreateCall(
562        RegFatbin,
563        ConstantExpr::getPointerBitCastOrAddrSpaceCast(FatbinDesc, PtrTy));
564    CtorBuilder.CreateAlignedStore(
565        Handle, BinaryHandleGlobal,
566        Align(M.getDataLayout().getPointerTypeSize(PtrTy)));
567    CtorBuilder.CreateCall(createRegisterGlobalsFunction(M, IsHIP, EntryArray,
568                                                         Suffix,
569                                                         EmitSurfacesAndTextures),
570                           Handle);
571    if (!IsHIP)
572      CtorBuilder.CreateCall(RegFatbinEnd, Handle);
573    CtorBuilder.CreateCall(AtExit, DtorFunc);
574    CtorBuilder.CreateRetVoid();
575  
576    // Create the destructor to unregister the image with the runtime. We cannot
577    // use a standard global destructor after CUDA 9.2 so this must be called by
578    // `atexit()` intead.
579    IRBuilder<> DtorBuilder(BasicBlock::Create(C, "entry", DtorFunc));
580    LoadInst *BinaryHandle = DtorBuilder.CreateAlignedLoad(
581        PtrTy, BinaryHandleGlobal,
582        Align(M.getDataLayout().getPointerTypeSize(PtrTy)));
583    DtorBuilder.CreateCall(UnregFatbin, BinaryHandle);
584    DtorBuilder.CreateRetVoid();
585  
586    // Add this function to constructors.
587    appendToGlobalCtors(M, CtorFunc, /*Priority=*/101);
588  }
589  } // namespace
590  
wrapOpenMPBinaries(Module & M,ArrayRef<ArrayRef<char>> Images,EntryArrayTy EntryArray,llvm::StringRef Suffix,bool Relocatable)591  Error offloading::wrapOpenMPBinaries(Module &M, ArrayRef<ArrayRef<char>> Images,
592                                       EntryArrayTy EntryArray,
593                                       llvm::StringRef Suffix, bool Relocatable) {
594    GlobalVariable *Desc =
595        createBinDesc(M, Images, EntryArray, Suffix, Relocatable);
596    if (!Desc)
597      return createStringError(inconvertibleErrorCode(),
598                               "No binary descriptors created.");
599    createRegisterFunction(M, Desc, Suffix);
600    return Error::success();
601  }
602  
wrapCudaBinary(Module & M,ArrayRef<char> Image,EntryArrayTy EntryArray,llvm::StringRef Suffix,bool EmitSurfacesAndTextures)603  Error offloading::wrapCudaBinary(Module &M, ArrayRef<char> Image,
604                                   EntryArrayTy EntryArray,
605                                   llvm::StringRef Suffix,
606                                   bool EmitSurfacesAndTextures) {
607    GlobalVariable *Desc = createFatbinDesc(M, Image, /*IsHip=*/false, Suffix);
608    if (!Desc)
609      return createStringError(inconvertibleErrorCode(),
610                               "No fatbin section created.");
611  
612    createRegisterFatbinFunction(M, Desc, /*IsHip=*/false, EntryArray, Suffix,
613                                 EmitSurfacesAndTextures);
614    return Error::success();
615  }
616  
wrapHIPBinary(Module & M,ArrayRef<char> Image,EntryArrayTy EntryArray,llvm::StringRef Suffix,bool EmitSurfacesAndTextures)617  Error offloading::wrapHIPBinary(Module &M, ArrayRef<char> Image,
618                                  EntryArrayTy EntryArray, llvm::StringRef Suffix,
619                                  bool EmitSurfacesAndTextures) {
620    GlobalVariable *Desc = createFatbinDesc(M, Image, /*IsHip=*/true, Suffix);
621    if (!Desc)
622      return createStringError(inconvertibleErrorCode(),
623                               "No fatbin section created.");
624  
625    createRegisterFatbinFunction(M, Desc, /*IsHip=*/true, EntryArray, Suffix,
626                                 EmitSurfacesAndTextures);
627    return Error::success();
628  }
629