xref: /freebsd/contrib/llvm-project/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (revision 7fdf597e96a02165cfe22ff357b857d5fa15ed8a)
1 //===- OpenMPIRBuilder.cpp - Builder for LLVM-IR for OpenMP directives ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 ///
10 /// This file implements the OpenMPIRBuilder class, which is used as a
11 /// convenient way to create LLVM instructions for OpenMP directives.
12 ///
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
16 #include "llvm/ADT/SmallSet.h"
17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/StringRef.h"
19 #include "llvm/Analysis/AssumptionCache.h"
20 #include "llvm/Analysis/CodeMetrics.h"
21 #include "llvm/Analysis/LoopInfo.h"
22 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
23 #include "llvm/Analysis/ScalarEvolution.h"
24 #include "llvm/Analysis/TargetLibraryInfo.h"
25 #include "llvm/Bitcode/BitcodeReader.h"
26 #include "llvm/Frontend/Offloading/Utility.h"
27 #include "llvm/Frontend/OpenMP/OMPGridValues.h"
28 #include "llvm/IR/Attributes.h"
29 #include "llvm/IR/BasicBlock.h"
30 #include "llvm/IR/CFG.h"
31 #include "llvm/IR/CallingConv.h"
32 #include "llvm/IR/Constant.h"
33 #include "llvm/IR/Constants.h"
34 #include "llvm/IR/DebugInfoMetadata.h"
35 #include "llvm/IR/DerivedTypes.h"
36 #include "llvm/IR/Function.h"
37 #include "llvm/IR/GlobalVariable.h"
38 #include "llvm/IR/IRBuilder.h"
39 #include "llvm/IR/LLVMContext.h"
40 #include "llvm/IR/MDBuilder.h"
41 #include "llvm/IR/Metadata.h"
42 #include "llvm/IR/PassManager.h"
43 #include "llvm/IR/PassInstrumentation.h"
44 #include "llvm/IR/ReplaceConstant.h"
45 #include "llvm/IR/Value.h"
46 #include "llvm/MC/TargetRegistry.h"
47 #include "llvm/Support/CommandLine.h"
48 #include "llvm/Support/ErrorHandling.h"
49 #include "llvm/Support/FileSystem.h"
50 #include "llvm/Target/TargetMachine.h"
51 #include "llvm/Target/TargetOptions.h"
52 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
53 #include "llvm/Transforms/Utils/Cloning.h"
54 #include "llvm/Transforms/Utils/CodeExtractor.h"
55 #include "llvm/Transforms/Utils/LoopPeel.h"
56 #include "llvm/Transforms/Utils/UnrollLoop.h"
57 
58 #include <cstdint>
59 #include <optional>
60 #include <stack>
61 
62 #define DEBUG_TYPE "openmp-ir-builder"
63 
64 using namespace llvm;
65 using namespace omp;
66 
/// Hidden command-line switch (off by default): "Use optimistic attributes
/// describing 'as-if' properties of runtime calls."
/// NOTE(review): not read in this part of the file; presumably consulted by the
/// attribute sets pulled in from OMPKinds.def — confirm at the use site.
static cl::opt<bool>
    OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
                         cl::desc("Use optimistic attributes describing "
                                  "'as-if' properties of runtime calls."),
                         cl::init(false));

/// Hidden command-line knob (default 1.5): multiplier for the unroll threshold
/// "to account for code simplifications still taking place".
/// NOTE(review): the use site is outside this part of the file — confirm how
/// the factor is applied there.
static cl::opt<double> UnrollThresholdFactor(
    "openmp-ir-builder-unroll-threshold-factor", cl::Hidden,
    cl::desc("Factor for the unroll threshold to account for code "
             "simplifications still taking place"),
    cl::init(1.5));
78 
79 #ifndef NDEBUG
80 /// Return whether IP1 and IP2 are ambiguous, i.e. that inserting instructions
81 /// at position IP1 may change the meaning of IP2 or vice-versa. This is because
82 /// an InsertPoint stores the instruction before something is inserted. For
83 /// instance, if both point to the same instruction, two IRBuilders alternating
84 /// creating instruction will cause the instructions to be interleaved.
85 static bool isConflictIP(IRBuilder<>::InsertPoint IP1,
86                          IRBuilder<>::InsertPoint IP2) {
87   if (!IP1.isSet() || !IP2.isSet())
88     return false;
89   return IP1.getBlock() == IP2.getBlock() && IP1.getPoint() == IP2.getPoint();
90 }
91 
92 static bool isValidWorkshareLoopScheduleType(OMPScheduleType SchedType) {
93   // Valid ordered/unordered and base algorithm combinations.
94   switch (SchedType & ~OMPScheduleType::MonotonicityMask) {
95   case OMPScheduleType::UnorderedStaticChunked:
96   case OMPScheduleType::UnorderedStatic:
97   case OMPScheduleType::UnorderedDynamicChunked:
98   case OMPScheduleType::UnorderedGuidedChunked:
99   case OMPScheduleType::UnorderedRuntime:
100   case OMPScheduleType::UnorderedAuto:
101   case OMPScheduleType::UnorderedTrapezoidal:
102   case OMPScheduleType::UnorderedGreedy:
103   case OMPScheduleType::UnorderedBalanced:
104   case OMPScheduleType::UnorderedGuidedIterativeChunked:
105   case OMPScheduleType::UnorderedGuidedAnalyticalChunked:
106   case OMPScheduleType::UnorderedSteal:
107   case OMPScheduleType::UnorderedStaticBalancedChunked:
108   case OMPScheduleType::UnorderedGuidedSimd:
109   case OMPScheduleType::UnorderedRuntimeSimd:
110   case OMPScheduleType::OrderedStaticChunked:
111   case OMPScheduleType::OrderedStatic:
112   case OMPScheduleType::OrderedDynamicChunked:
113   case OMPScheduleType::OrderedGuidedChunked:
114   case OMPScheduleType::OrderedRuntime:
115   case OMPScheduleType::OrderedAuto:
116   case OMPScheduleType::OrderdTrapezoidal:
117   case OMPScheduleType::NomergeUnorderedStaticChunked:
118   case OMPScheduleType::NomergeUnorderedStatic:
119   case OMPScheduleType::NomergeUnorderedDynamicChunked:
120   case OMPScheduleType::NomergeUnorderedGuidedChunked:
121   case OMPScheduleType::NomergeUnorderedRuntime:
122   case OMPScheduleType::NomergeUnorderedAuto:
123   case OMPScheduleType::NomergeUnorderedTrapezoidal:
124   case OMPScheduleType::NomergeUnorderedGreedy:
125   case OMPScheduleType::NomergeUnorderedBalanced:
126   case OMPScheduleType::NomergeUnorderedGuidedIterativeChunked:
127   case OMPScheduleType::NomergeUnorderedGuidedAnalyticalChunked:
128   case OMPScheduleType::NomergeUnorderedSteal:
129   case OMPScheduleType::NomergeOrderedStaticChunked:
130   case OMPScheduleType::NomergeOrderedStatic:
131   case OMPScheduleType::NomergeOrderedDynamicChunked:
132   case OMPScheduleType::NomergeOrderedGuidedChunked:
133   case OMPScheduleType::NomergeOrderedRuntime:
134   case OMPScheduleType::NomergeOrderedAuto:
135   case OMPScheduleType::NomergeOrderedTrapezoidal:
136     break;
137   default:
138     return false;
139   }
140 
141   // Must not set both monotonicity modifiers at the same time.
142   OMPScheduleType MonotonicityFlags =
143       SchedType & OMPScheduleType::MonotonicityMask;
144   if (MonotonicityFlags == OMPScheduleType::MonotonicityMask)
145     return false;
146 
147   return true;
148 }
149 #endif
150 
151 static const omp::GV &getGridValue(const Triple &T, Function *Kernel) {
152   if (T.isAMDGPU()) {
153     StringRef Features =
154         Kernel->getFnAttribute("target-features").getValueAsString();
155     if (Features.count("+wavefrontsize64"))
156       return omp::getAMDGPUGridValues<64>();
157     return omp::getAMDGPUGridValues<32>();
158   }
159   if (T.isNVPTX())
160     return omp::NVPTXGridValues;
161   llvm_unreachable("No grid value available for this architecture!");
162 }
163 
164 /// Determine which scheduling algorithm to use, determined from schedule clause
165 /// arguments.
166 static OMPScheduleType
167 getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
168                           bool HasSimdModifier) {
169   // Currently, the default schedule it static.
170   switch (ClauseKind) {
171   case OMP_SCHEDULE_Default:
172   case OMP_SCHEDULE_Static:
173     return HasChunks ? OMPScheduleType::BaseStaticChunked
174                      : OMPScheduleType::BaseStatic;
175   case OMP_SCHEDULE_Dynamic:
176     return OMPScheduleType::BaseDynamicChunked;
177   case OMP_SCHEDULE_Guided:
178     return HasSimdModifier ? OMPScheduleType::BaseGuidedSimd
179                            : OMPScheduleType::BaseGuidedChunked;
180   case OMP_SCHEDULE_Auto:
181     return llvm::omp::OMPScheduleType::BaseAuto;
182   case OMP_SCHEDULE_Runtime:
183     return HasSimdModifier ? OMPScheduleType::BaseRuntimeSimd
184                            : OMPScheduleType::BaseRuntime;
185   }
186   llvm_unreachable("unhandled schedule clause argument");
187 }
188 
189 /// Adds ordering modifier flags to schedule type.
190 static OMPScheduleType
191 getOpenMPOrderingScheduleType(OMPScheduleType BaseScheduleType,
192                               bool HasOrderedClause) {
193   assert((BaseScheduleType & OMPScheduleType::ModifierMask) ==
194              OMPScheduleType::None &&
195          "Must not have ordering nor monotonicity flags already set");
196 
197   OMPScheduleType OrderingModifier = HasOrderedClause
198                                          ? OMPScheduleType::ModifierOrdered
199                                          : OMPScheduleType::ModifierUnordered;
200   OMPScheduleType OrderingScheduleType = BaseScheduleType | OrderingModifier;
201 
202   // Unsupported combinations
203   if (OrderingScheduleType ==
204       (OMPScheduleType::BaseGuidedSimd | OMPScheduleType::ModifierOrdered))
205     return OMPScheduleType::OrderedGuidedChunked;
206   else if (OrderingScheduleType == (OMPScheduleType::BaseRuntimeSimd |
207                                     OMPScheduleType::ModifierOrdered))
208     return OMPScheduleType::OrderedRuntime;
209 
210   return OrderingScheduleType;
211 }
212 
213 /// Adds monotonicity modifier flags to schedule type.
214 static OMPScheduleType
215 getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,
216                                   bool HasSimdModifier, bool HasMonotonic,
217                                   bool HasNonmonotonic, bool HasOrderedClause) {
218   assert((ScheduleType & OMPScheduleType::MonotonicityMask) ==
219              OMPScheduleType::None &&
220          "Must not have monotonicity flags already set");
221   assert((!HasMonotonic || !HasNonmonotonic) &&
222          "Monotonic and Nonmonotonic are contradicting each other");
223 
224   if (HasMonotonic) {
225     return ScheduleType | OMPScheduleType::ModifierMonotonic;
226   } else if (HasNonmonotonic) {
227     return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
228   } else {
229     // OpenMP 5.1, 2.11.4 Worksharing-Loop Construct, Description.
230     // If the static schedule kind is specified or if the ordered clause is
231     // specified, and if the nonmonotonic modifier is not specified, the
232     // effect is as if the monotonic modifier is specified. Otherwise, unless
233     // the monotonic modifier is specified, the effect is as if the
234     // nonmonotonic modifier is specified.
235     OMPScheduleType BaseScheduleType =
236         ScheduleType & ~OMPScheduleType::ModifierMask;
237     if ((BaseScheduleType == OMPScheduleType::BaseStatic) ||
238         (BaseScheduleType == OMPScheduleType::BaseStaticChunked) ||
239         HasOrderedClause) {
240       // The monotonic is used by default in openmp runtime library, so no need
241       // to set it.
242       return ScheduleType;
243     } else {
244       return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
245     }
246   }
247 }
248 
249 /// Determine the schedule type using schedule and ordering clause arguments.
250 static OMPScheduleType
251 computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
252                           bool HasSimdModifier, bool HasMonotonicModifier,
253                           bool HasNonmonotonicModifier, bool HasOrderedClause) {
254   OMPScheduleType BaseSchedule =
255       getOpenMPBaseScheduleType(ClauseKind, HasChunks, HasSimdModifier);
256   OMPScheduleType OrderedSchedule =
257       getOpenMPOrderingScheduleType(BaseSchedule, HasOrderedClause);
258   OMPScheduleType Result = getOpenMPMonotonicityScheduleType(
259       OrderedSchedule, HasSimdModifier, HasMonotonicModifier,
260       HasNonmonotonicModifier, HasOrderedClause);
261 
262   assert(isValidWorkshareLoopScheduleType(Result));
263   return Result;
264 }
265 
266 /// Make \p Source branch to \p Target.
267 ///
268 /// Handles two situations:
269 /// * \p Source already has an unconditional branch.
270 /// * \p Source is a degenerate block (no terminator because the BB is
271 ///             the current head of the IR construction).
272 static void redirectTo(BasicBlock *Source, BasicBlock *Target, DebugLoc DL) {
273   if (Instruction *Term = Source->getTerminator()) {
274     auto *Br = cast<BranchInst>(Term);
275     assert(!Br->isConditional() &&
276            "BB's terminator must be an unconditional branch (or degenerate)");
277     BasicBlock *Succ = Br->getSuccessor(0);
278     Succ->removePredecessor(Source, /*KeepOneInputPHIs=*/true);
279     Br->setSuccessor(0, Target);
280     return;
281   }
282 
283   auto *NewBr = BranchInst::Create(Target, Source);
284   NewBr->setDebugLoc(DL);
285 }
286 
287 void llvm::spliceBB(IRBuilderBase::InsertPoint IP, BasicBlock *New,
288                     bool CreateBranch) {
289   assert(New->getFirstInsertionPt() == New->begin() &&
290          "Target BB must not have PHI nodes");
291 
292   // Move instructions to new block.
293   BasicBlock *Old = IP.getBlock();
294   New->splice(New->begin(), Old, IP.getPoint(), Old->end());
295 
296   if (CreateBranch)
297     BranchInst::Create(New, Old);
298 }
299 
300 void llvm::spliceBB(IRBuilder<> &Builder, BasicBlock *New, bool CreateBranch) {
301   DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
302   BasicBlock *Old = Builder.GetInsertBlock();
303 
304   spliceBB(Builder.saveIP(), New, CreateBranch);
305   if (CreateBranch)
306     Builder.SetInsertPoint(Old->getTerminator());
307   else
308     Builder.SetInsertPoint(Old);
309 
310   // SetInsertPoint also updates the Builder's debug location, but we want to
311   // keep the one the Builder was configured to use.
312   Builder.SetCurrentDebugLocation(DebugLoc);
313 }
314 
315 BasicBlock *llvm::splitBB(IRBuilderBase::InsertPoint IP, bool CreateBranch,
316                           llvm::Twine Name) {
317   BasicBlock *Old = IP.getBlock();
318   BasicBlock *New = BasicBlock::Create(
319       Old->getContext(), Name.isTriviallyEmpty() ? Old->getName() : Name,
320       Old->getParent(), Old->getNextNode());
321   spliceBB(IP, New, CreateBranch);
322   New->replaceSuccessorsPhiUsesWith(Old, New);
323   return New;
324 }
325 
326 BasicBlock *llvm::splitBB(IRBuilderBase &Builder, bool CreateBranch,
327                           llvm::Twine Name) {
328   DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
329   BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, Name);
330   if (CreateBranch)
331     Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
332   else
333     Builder.SetInsertPoint(Builder.GetInsertBlock());
334   // SetInsertPoint also updates the Builder's debug location, but we want to
335   // keep the one the Builder was configured to use.
336   Builder.SetCurrentDebugLocation(DebugLoc);
337   return New;
338 }
339 
340 BasicBlock *llvm::splitBB(IRBuilder<> &Builder, bool CreateBranch,
341                           llvm::Twine Name) {
342   DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
343   BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, Name);
344   if (CreateBranch)
345     Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
346   else
347     Builder.SetInsertPoint(Builder.GetInsertBlock());
348   // SetInsertPoint also updates the Builder's debug location, but we want to
349   // keep the one the Builder was configured to use.
350   Builder.SetCurrentDebugLocation(DebugLoc);
351   return New;
352 }
353 
354 BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch,
355                                     llvm::Twine Suffix) {
356   BasicBlock *Old = Builder.GetInsertBlock();
357   return splitBB(Builder, CreateBranch, Old->getName() + Suffix);
358 }
359 
360 // This function creates a fake integer value and a fake use for the integer
361 // value. It returns the fake value created. This is useful in modeling the
362 // extra arguments to the outlined functions.
363 Value *createFakeIntVal(IRBuilderBase &Builder,
364                         OpenMPIRBuilder::InsertPointTy OuterAllocaIP,
365                         llvm::SmallVectorImpl<Instruction *> &ToBeDeleted,
366                         OpenMPIRBuilder::InsertPointTy InnerAllocaIP,
367                         const Twine &Name = "", bool AsPtr = true) {
368   Builder.restoreIP(OuterAllocaIP);
369   Instruction *FakeVal;
370   AllocaInst *FakeValAddr =
371       Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, Name + ".addr");
372   ToBeDeleted.push_back(FakeValAddr);
373 
374   if (AsPtr) {
375     FakeVal = FakeValAddr;
376   } else {
377     FakeVal =
378         Builder.CreateLoad(Builder.getInt32Ty(), FakeValAddr, Name + ".val");
379     ToBeDeleted.push_back(FakeVal);
380   }
381 
382   // Generate a fake use of this value
383   Builder.restoreIP(InnerAllocaIP);
384   Instruction *UseFakeVal;
385   if (AsPtr) {
386     UseFakeVal =
387         Builder.CreateLoad(Builder.getInt32Ty(), FakeVal, Name + ".use");
388   } else {
389     UseFakeVal =
390         cast<BinaryOperator>(Builder.CreateAdd(FakeVal, Builder.getInt32(10)));
391   }
392   ToBeDeleted.push_back(UseFakeVal);
393   return FakeVal;
394 }
395 
396 //===----------------------------------------------------------------------===//
397 // OpenMPIRBuilderConfig
398 //===----------------------------------------------------------------------===//
399 
namespace {
// Enable LLVM's bitmask-enum operators (|, &, ~, ...) for enums declared in
// this anonymous namespace, so the flags below can be set and cleared with
// |= and &= ~ on OpenMPIRBuilderConfig::RequiresFlags.
LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
/// Values for bit flags for marking which requires clauses have been used.
/// Each clause owns a distinct bit; OMP_REQ_UNDEFINED (0) means no bit set.
enum OpenMPOffloadingRequiresDirFlags {
  /// flag undefined.
  OMP_REQ_UNDEFINED = 0x000,
  /// no requires directive present.
  OMP_REQ_NONE = 0x001,
  /// reverse_offload clause.
  OMP_REQ_REVERSE_OFFLOAD = 0x002,
  /// unified_address clause.
  OMP_REQ_UNIFIED_ADDRESS = 0x004,
  /// unified_shared_memory clause.
  OMP_REQ_UNIFIED_SHARED_MEMORY = 0x008,
  /// dynamic_allocators clause.
  OMP_REQ_DYNAMIC_ALLOCATORS = 0x010,
  // Marks this enum as a bitmask; LargestValue tells the bitmask machinery the
  // highest flag in use. Update it when adding a new clause bit.
  LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/OMP_REQ_DYNAMIC_ALLOCATORS)
};

} // anonymous namespace
420 
421 OpenMPIRBuilderConfig::OpenMPIRBuilderConfig()
422     : RequiresFlags(OMP_REQ_UNDEFINED) {}
423 
424 OpenMPIRBuilderConfig::OpenMPIRBuilderConfig(
425     bool IsTargetDevice, bool IsGPU, bool OpenMPOffloadMandatory,
426     bool HasRequiresReverseOffload, bool HasRequiresUnifiedAddress,
427     bool HasRequiresUnifiedSharedMemory, bool HasRequiresDynamicAllocators)
428     : IsTargetDevice(IsTargetDevice), IsGPU(IsGPU),
429       OpenMPOffloadMandatory(OpenMPOffloadMandatory),
430       RequiresFlags(OMP_REQ_UNDEFINED) {
431   if (HasRequiresReverseOffload)
432     RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
433   if (HasRequiresUnifiedAddress)
434     RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
435   if (HasRequiresUnifiedSharedMemory)
436     RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
437   if (HasRequiresDynamicAllocators)
438     RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
439 }
440 
441 bool OpenMPIRBuilderConfig::hasRequiresReverseOffload() const {
442   return RequiresFlags & OMP_REQ_REVERSE_OFFLOAD;
443 }
444 
445 bool OpenMPIRBuilderConfig::hasRequiresUnifiedAddress() const {
446   return RequiresFlags & OMP_REQ_UNIFIED_ADDRESS;
447 }
448 
449 bool OpenMPIRBuilderConfig::hasRequiresUnifiedSharedMemory() const {
450   return RequiresFlags & OMP_REQ_UNIFIED_SHARED_MEMORY;
451 }
452 
453 bool OpenMPIRBuilderConfig::hasRequiresDynamicAllocators() const {
454   return RequiresFlags & OMP_REQ_DYNAMIC_ALLOCATORS;
455 }
456 
457 int64_t OpenMPIRBuilderConfig::getRequiresFlags() const {
458   return hasRequiresFlags() ? RequiresFlags
459                             : static_cast<int64_t>(OMP_REQ_NONE);
460 }
461 
462 void OpenMPIRBuilderConfig::setHasRequiresReverseOffload(bool Value) {
463   if (Value)
464     RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
465   else
466     RequiresFlags &= ~OMP_REQ_REVERSE_OFFLOAD;
467 }
468 
469 void OpenMPIRBuilderConfig::setHasRequiresUnifiedAddress(bool Value) {
470   if (Value)
471     RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
472   else
473     RequiresFlags &= ~OMP_REQ_UNIFIED_ADDRESS;
474 }
475 
476 void OpenMPIRBuilderConfig::setHasRequiresUnifiedSharedMemory(bool Value) {
477   if (Value)
478     RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
479   else
480     RequiresFlags &= ~OMP_REQ_UNIFIED_SHARED_MEMORY;
481 }
482 
483 void OpenMPIRBuilderConfig::setHasRequiresDynamicAllocators(bool Value) {
484   if (Value)
485     RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
486   else
487     RequiresFlags &= ~OMP_REQ_DYNAMIC_ALLOCATORS;
488 }
489 
490 //===----------------------------------------------------------------------===//
491 // OpenMPIRBuilder
492 //===----------------------------------------------------------------------===//
493 
494 void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
495                                           IRBuilderBase &Builder,
496                                           SmallVector<Value *> &ArgsVector) {
497   Value *Version = Builder.getInt32(OMP_KERNEL_ARG_VERSION);
498   Value *PointerNum = Builder.getInt32(KernelArgs.NumTargetItems);
499   auto Int32Ty = Type::getInt32Ty(Builder.getContext());
500   Value *ZeroArray = Constant::getNullValue(ArrayType::get(Int32Ty, 3));
501   Value *Flags = Builder.getInt64(KernelArgs.HasNoWait);
502 
503   Value *NumTeams3D =
504       Builder.CreateInsertValue(ZeroArray, KernelArgs.NumTeams, {0});
505   Value *NumThreads3D =
506       Builder.CreateInsertValue(ZeroArray, KernelArgs.NumThreads, {0});
507 
508   ArgsVector = {Version,
509                 PointerNum,
510                 KernelArgs.RTArgs.BasePointersArray,
511                 KernelArgs.RTArgs.PointersArray,
512                 KernelArgs.RTArgs.SizesArray,
513                 KernelArgs.RTArgs.MapTypesArray,
514                 KernelArgs.RTArgs.MapNamesArray,
515                 KernelArgs.RTArgs.MappersArray,
516                 KernelArgs.NumIterations,
517                 Flags,
518                 NumTeams3D,
519                 NumThreads3D,
520                 KernelArgs.DynCGGroupMem};
521 }
522 
/// Add the attributes listed for runtime function \p FnID in OMPKinds.def
/// (function, return and per-argument attribute sets) to \p Fn, merging them
/// with the attributes \p Fn already has. Runtime functions without an
/// OMP_RTL_ATTRS entry are left unchanged.
void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
  LLVMContext &Ctx = Fn.getContext();

  // Get the function's current attributes.
  auto Attrs = Fn.getAttributes();
  auto FnAttrs = Attrs.getFnAttrs();
  auto RetAttrs = Attrs.getRetAttrs();
  SmallVector<AttributeSet, 4> ArgAttrs;
  for (size_t ArgNo = 0; ArgNo < Fn.arg_size(); ++ArgNo)
    ArgAttrs.emplace_back(Attrs.getParamAttrs(ArgNo));

  // Add AS to FnAS while taking special care with integer extensions.
  // A SExt/ZExt attribute is not copied verbatim. Instead, TargetLibraryInfo is
  // asked for the i32 extension attribute that the target triple T wants for a
  // parameter (Param) or return value (!Param). That may be none, in which case
  // nothing is added. Every other attribute set is merged as-is.
  auto addAttrSet = [&](AttributeSet &FnAS, const AttributeSet &AS,
                        bool Param = true) -> void {
    bool HasSignExt = AS.hasAttribute(Attribute::SExt);
    bool HasZeroExt = AS.hasAttribute(Attribute::ZExt);
    if (HasSignExt || HasZeroExt) {
      assert(AS.getNumAttributes() == 1 &&
             "Currently not handling extension attr combined with others.");
      if (Param) {
        if (auto AK = TargetLibraryInfo::getExtAttrForI32Param(T, HasSignExt))
          FnAS = FnAS.addAttribute(Ctx, AK);
      } else if (auto AK =
                     TargetLibraryInfo::getExtAttrForI32Return(T, HasSignExt))
        FnAS = FnAS.addAttribute(Ctx, AK);
    } else {
      FnAS = FnAS.addAttributes(Ctx, AS);
    }
  };

  // Materialize every named attribute set from OMPKinds.def as a local
  // AttributeSet variable; the OMP_RTL_ATTRS entries below refer to them.
#define OMP_ATTRS_SET(VarName, AttrSet) AttributeSet VarName = AttrSet;
#include "llvm/Frontend/OpenMP/OMPKinds.def"

  // Add attributes to the function declaration.
  // Each OMP_RTL_ATTRS entry expands to one case that merges the function,
  // return and argument attribute sets, then installs the combined list on Fn.
  // NOTE(review): ArgAttrs is indexed by the .def entry's argument-set count;
  // this assumes an entry never lists more argument sets than Fn has
  // parameters.
  switch (FnID) {
#define OMP_RTL_ATTRS(Enum, FnAttrSet, RetAttrSet, ArgAttrSets)                \
  case Enum:                                                                   \
    FnAttrs = FnAttrs.addAttributes(Ctx, FnAttrSet);                           \
    addAttrSet(RetAttrs, RetAttrSet, /*Param*/ false);                         \
    for (size_t ArgNo = 0; ArgNo < ArgAttrSets.size(); ++ArgNo)                \
      addAttrSet(ArgAttrs[ArgNo], ArgAttrSets[ArgNo]);                         \
    Fn.setAttributes(AttributeList::get(Ctx, FnAttrs, RetAttrs, ArgAttrs));    \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    // Attributes are optional.
    break;
  }
}
572 
/// Return a callee for OpenMP runtime function \p FnID in module \p M.
/// The function type always comes from OMPKinds.def. If \p M already contains
/// a function with the runtime name, it is reused untouched: no attributes or
/// metadata are added. Otherwise a new external declaration is created. It gets
/// !callback metadata for __kmpc_fork_call / __kmpc_fork_teams, and attributes
/// via addAttributes().
/// NOTE(review): a pre-existing declaration is not checked against the
/// OMPKinds.def type; the returned FunctionCallee pairs the .def type with it.
FunctionCallee
OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
  FunctionType *FnTy = nullptr;
  Function *Fn = nullptr;

  // Try to find the declaration in the module first. Each OMP_RTL entry in
  // OMPKinds.def expands to one case that builds the expected function type
  // and looks up the runtime name.
  switch (FnID) {
#define OMP_RTL(Enum, Str, IsVarArg, ReturnType, ...)                          \
  case Enum:                                                                   \
    FnTy = FunctionType::get(ReturnType, ArrayRef<Type *>{__VA_ARGS__},        \
                             IsVarArg);                                        \
    Fn = M.getFunction(Str);                                                   \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  if (!Fn) {
    // Create a new declaration if we need one.
    switch (FnID) {
#define OMP_RTL(Enum, Str, ...)                                                \
  case Enum:                                                                   \
    Fn = Function::Create(FnTy, GlobalValue::ExternalLinkage, Str, M);         \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
    }

    // Add information if the runtime function takes a callback function
    if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) {
      if (!Fn->hasMetadata(LLVMContext::MD_callback)) {
        LLVMContext &Ctx = Fn->getContext();
        MDBuilder MDB(Ctx);
        // Annotate the callback behavior of the runtime function:
        //  - The callback callee is argument number 2 (microtask).
        //  - The first two arguments of the callback callee are unknown (-1).
        //  - All variadic arguments to the runtime function are passed to the
        //    callback callee.
        Fn->addMetadata(
            LLVMContext::MD_callback,
            *MDNode::get(Ctx, {MDB.createCallbackEncoding(
                                  2, {-1, -1}, /* VarArgsArePassed */ true)}));
      }
    }

    LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
    addAttributes(FnID, *Fn);

  } else {
    LLVM_DEBUG(dbgs() << "Found OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
  }

  assert(Fn && "Failed to create OpenMP runtime function");

  return {FnTy, Fn};
}
629 
630 Function *OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID) {
631   FunctionCallee RTLFn = getOrCreateRuntimeFunction(M, FnID);
632   auto *Fn = dyn_cast<llvm::Function>(RTLFn.getCallee());
633   assert(Fn && "Failed to create OpenMP runtime function pointer");
634   return Fn;
635 }
636 
637 void OpenMPIRBuilder::initialize() { initializeTypes(M); }
638 
639 static void raiseUserConstantDataAllocasToEntryBlock(IRBuilderBase &Builder,
640                                                      Function *Function) {
641   BasicBlock &EntryBlock = Function->getEntryBlock();
642   Instruction *MoveLocInst = EntryBlock.getFirstNonPHI();
643 
644   // Loop over blocks looking for constant allocas, skipping the entry block
645   // as any allocas there are already in the desired location.
646   for (auto Block = std::next(Function->begin(), 1); Block != Function->end();
647        Block++) {
648     for (auto Inst = Block->getReverseIterator()->begin();
649          Inst != Block->getReverseIterator()->end();) {
650       if (auto *AllocaInst = dyn_cast_if_present<llvm::AllocaInst>(Inst)) {
651         Inst++;
652         if (!isa<ConstantData>(AllocaInst->getArraySize()))
653           continue;
654         AllocaInst->moveBeforePreserving(MoveLocInst);
655       } else {
656         Inst++;
657       }
658     }
659   }
660 }
661 
662 void OpenMPIRBuilder::finalize(Function *Fn) {
663   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
664   SmallVector<BasicBlock *, 32> Blocks;
665   SmallVector<OutlineInfo, 16> DeferredOutlines;
666   for (OutlineInfo &OI : OutlineInfos) {
667     // Skip functions that have not finalized yet; may happen with nested
668     // function generation.
669     if (Fn && OI.getFunction() != Fn) {
670       DeferredOutlines.push_back(OI);
671       continue;
672     }
673 
674     ParallelRegionBlockSet.clear();
675     Blocks.clear();
676     OI.collectBlocks(ParallelRegionBlockSet, Blocks);
677 
678     Function *OuterFn = OI.getFunction();
679     CodeExtractorAnalysisCache CEAC(*OuterFn);
680     // If we generate code for the target device, we need to allocate
681     // struct for aggregate params in the device default alloca address space.
682     // OpenMP runtime requires that the params of the extracted functions are
683     // passed as zero address space pointers. This flag ensures that
684     // CodeExtractor generates correct code for extracted functions
685     // which are used by OpenMP runtime.
686     bool ArgsInZeroAddressSpace = Config.isTargetDevice();
687     CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
688                             /* AggregateArgs */ true,
689                             /* BlockFrequencyInfo */ nullptr,
690                             /* BranchProbabilityInfo */ nullptr,
691                             /* AssumptionCache */ nullptr,
692                             /* AllowVarArgs */ true,
693                             /* AllowAlloca */ true,
694                             /* AllocaBlock*/ OI.OuterAllocaBB,
695                             /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
696 
697     LLVM_DEBUG(dbgs() << "Before     outlining: " << *OuterFn << "\n");
698     LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
699                       << " Exit: " << OI.ExitBB->getName() << "\n");
700     assert(Extractor.isEligible() &&
701            "Expected OpenMP outlining to be possible!");
702 
703     for (auto *V : OI.ExcludeArgsFromAggregate)
704       Extractor.excludeArgFromAggregate(V);
705 
706     Function *OutlinedFn = Extractor.extractCodeRegion(CEAC);
707 
708     // Forward target-cpu, target-features attributes to the outlined function.
709     auto TargetCpuAttr = OuterFn->getFnAttribute("target-cpu");
710     if (TargetCpuAttr.isStringAttribute())
711       OutlinedFn->addFnAttr(TargetCpuAttr);
712 
713     auto TargetFeaturesAttr = OuterFn->getFnAttribute("target-features");
714     if (TargetFeaturesAttr.isStringAttribute())
715       OutlinedFn->addFnAttr(TargetFeaturesAttr);
716 
717     LLVM_DEBUG(dbgs() << "After      outlining: " << *OuterFn << "\n");
718     LLVM_DEBUG(dbgs() << "   Outlined function: " << *OutlinedFn << "\n");
719     assert(OutlinedFn->getReturnType()->isVoidTy() &&
720            "OpenMP outlined functions should not return a value!");
721 
722     // For compability with the clang CG we move the outlined function after the
723     // one with the parallel region.
724     OutlinedFn->removeFromParent();
725     M.getFunctionList().insertAfter(OuterFn->getIterator(), OutlinedFn);
726 
727     // Remove the artificial entry introduced by the extractor right away, we
728     // made our own entry block after all.
729     {
730       BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
731       assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
732       assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
733       // Move instructions from the to-be-deleted ArtificialEntry to the entry
734       // basic block of the parallel region. CodeExtractor generates
735       // instructions to unwrap the aggregate argument and may sink
736       // allocas/bitcasts for values that are solely used in the outlined region
737       // and do not escape.
738       assert(!ArtificialEntry.empty() &&
739              "Expected instructions to add in the outlined region entry");
740       for (BasicBlock::reverse_iterator It = ArtificialEntry.rbegin(),
741                                         End = ArtificialEntry.rend();
742            It != End;) {
743         Instruction &I = *It;
744         It++;
745 
746         if (I.isTerminator())
747           continue;
748 
749         I.moveBeforePreserving(*OI.EntryBB, OI.EntryBB->getFirstInsertionPt());
750       }
751 
752       OI.EntryBB->moveBefore(&ArtificialEntry);
753       ArtificialEntry.eraseFromParent();
754     }
755     assert(&OutlinedFn->getEntryBlock() == OI.EntryBB);
756     assert(OutlinedFn && OutlinedFn->getNumUses() == 1);
757 
758     // Run a user callback, e.g. to add attributes.
759     if (OI.PostOutlineCB)
760       OI.PostOutlineCB(*OutlinedFn);
761   }
762 
763   // Remove work items that have been completed.
764   OutlineInfos = std::move(DeferredOutlines);
765 
766   // The createTarget functions embeds user written code into
767   // the target region which may inject allocas which need to
768   // be moved to the entry block of our target or risk malformed
769   // optimisations by later passes, this is only relevant for
770   // the device pass which appears to be a little more delicate
771   // when it comes to optimisations (however, we do not block on
772   // that here, it's up to the inserter to the list to do so).
773   // This notbaly has to occur after the OutlinedInfo candidates
774   // have been extracted so we have an end product that will not
775   // be implicitly adversely affected by any raises unless
776   // intentionally appended to the list.
777   // NOTE: This only does so for ConstantData, it could be extended
778   // to ConstantExpr's with further effort, however, they should
779   // largely be folded when they get here. Extending it to runtime
780   // defined/read+writeable allocation sizes would be non-trivial
781   // (need to factor in movement of any stores to variables the
782   // allocation size depends on, as well as the usual loads,
783   // otherwise it'll yield the wrong result after movement) and
784   // likely be more suitable as an LLVM optimisation pass.
785   for (Function *F : ConstantAllocaRaiseCandidates)
786     raiseUserConstantDataAllocasToEntryBlock(Builder, F);
787 
788   EmitMetadataErrorReportFunctionTy &&ErrorReportFn =
789       [](EmitMetadataErrorKind Kind,
790          const TargetRegionEntryInfo &EntryInfo) -> void {
791     errs() << "Error of kind: " << Kind
792            << " when emitting offload entries and metadata during "
793               "OMPIRBuilder finalization \n";
794   };
795 
796   if (!OffloadInfoManager.empty())
797     createOffloadEntriesAndInfoMetadata(ErrorReportFn);
798 
799   if (Config.EmitLLVMUsedMetaInfo.value_or(false)) {
800     std::vector<WeakTrackingVH> LLVMCompilerUsed = {
801         M.getGlobalVariable("__openmp_nvptx_data_transfer_temporary_storage")};
802     emitUsed("llvm.compiler.used", LLVMCompilerUsed);
803   }
804 }
805 
806 OpenMPIRBuilder::~OpenMPIRBuilder() {
807   assert(OutlineInfos.empty() && "There must be no outstanding outlinings");
808 }
809 
810 GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
811   IntegerType *I32Ty = Type::getInt32Ty(M.getContext());
812   auto *GV =
813       new GlobalVariable(M, I32Ty,
814                          /* isConstant = */ true, GlobalValue::WeakODRLinkage,
815                          ConstantInt::get(I32Ty, Value), Name);
816   GV->setVisibility(GlobalValue::HiddenVisibility);
817 
818   return GV;
819 }
820 
821 Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
822                                             uint32_t SrcLocStrSize,
823                                             IdentFlag LocFlags,
824                                             unsigned Reserve2Flags) {
825   // Enable "C-mode".
826   LocFlags |= OMP_IDENT_FLAG_KMPC;
827 
828   Constant *&Ident =
829       IdentMap[{SrcLocStr, uint64_t(LocFlags) << 31 | Reserve2Flags}];
830   if (!Ident) {
831     Constant *I32Null = ConstantInt::getNullValue(Int32);
832     Constant *IdentData[] = {I32Null,
833                              ConstantInt::get(Int32, uint32_t(LocFlags)),
834                              ConstantInt::get(Int32, Reserve2Flags),
835                              ConstantInt::get(Int32, SrcLocStrSize), SrcLocStr};
836     Constant *Initializer =
837         ConstantStruct::get(OpenMPIRBuilder::Ident, IdentData);
838 
839     // Look for existing encoding of the location + flags, not needed but
840     // minimizes the difference to the existing solution while we transition.
841     for (GlobalVariable &GV : M.globals())
842       if (GV.getValueType() == OpenMPIRBuilder::Ident && GV.hasInitializer())
843         if (GV.getInitializer() == Initializer)
844           Ident = &GV;
845 
846     if (!Ident) {
847       auto *GV = new GlobalVariable(
848           M, OpenMPIRBuilder::Ident,
849           /* isConstant = */ true, GlobalValue::PrivateLinkage, Initializer, "",
850           nullptr, GlobalValue::NotThreadLocal,
851           M.getDataLayout().getDefaultGlobalsAddressSpace());
852       GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
853       GV->setAlignment(Align(8));
854       Ident = GV;
855     }
856   }
857 
858   return ConstantExpr::getPointerBitCastOrAddrSpaceCast(Ident, IdentPtr);
859 }
860 
861 Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef LocStr,
862                                                 uint32_t &SrcLocStrSize) {
863   SrcLocStrSize = LocStr.size();
864   Constant *&SrcLocStr = SrcLocStrMap[LocStr];
865   if (!SrcLocStr) {
866     Constant *Initializer =
867         ConstantDataArray::getString(M.getContext(), LocStr);
868 
869     // Look for existing encoding of the location, not needed but minimizes the
870     // difference to the existing solution while we transition.
871     for (GlobalVariable &GV : M.globals())
872       if (GV.isConstant() && GV.hasInitializer() &&
873           GV.getInitializer() == Initializer)
874         return SrcLocStr = ConstantExpr::getPointerCast(&GV, Int8Ptr);
875 
876     SrcLocStr = Builder.CreateGlobalStringPtr(LocStr, /* Name */ "",
877                                               /* AddressSpace */ 0, &M);
878   }
879   return SrcLocStr;
880 }
881 
882 Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef FunctionName,
883                                                 StringRef FileName,
884                                                 unsigned Line, unsigned Column,
885                                                 uint32_t &SrcLocStrSize) {
886   SmallString<128> Buffer;
887   Buffer.push_back(';');
888   Buffer.append(FileName);
889   Buffer.push_back(';');
890   Buffer.append(FunctionName);
891   Buffer.push_back(';');
892   Buffer.append(std::to_string(Line));
893   Buffer.push_back(';');
894   Buffer.append(std::to_string(Column));
895   Buffer.push_back(';');
896   Buffer.push_back(';');
897   return getOrCreateSrcLocStr(Buffer.str(), SrcLocStrSize);
898 }
899 
900 Constant *
901 OpenMPIRBuilder::getOrCreateDefaultSrcLocStr(uint32_t &SrcLocStrSize) {
902   StringRef UnknownLoc = ";unknown;unknown;0;0;;";
903   return getOrCreateSrcLocStr(UnknownLoc, SrcLocStrSize);
904 }
905 
906 Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(DebugLoc DL,
907                                                 uint32_t &SrcLocStrSize,
908                                                 Function *F) {
909   DILocation *DIL = DL.get();
910   if (!DIL)
911     return getOrCreateDefaultSrcLocStr(SrcLocStrSize);
912   StringRef FileName = M.getName();
913   if (DIFile *DIF = DIL->getFile())
914     if (std::optional<StringRef> Source = DIF->getSource())
915       FileName = *Source;
916   StringRef Function = DIL->getScope()->getSubprogram()->getName();
917   if (Function.empty() && F)
918     Function = F->getName();
919   return getOrCreateSrcLocStr(Function, FileName, DIL->getLine(),
920                               DIL->getColumn(), SrcLocStrSize);
921 }
922 
923 Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(const LocationDescription &Loc,
924                                                 uint32_t &SrcLocStrSize) {
925   return getOrCreateSrcLocStr(Loc.DL, SrcLocStrSize,
926                               Loc.IP.getBlock()->getParent());
927 }
928 
929 Value *OpenMPIRBuilder::getOrCreateThreadID(Value *Ident) {
930   return Builder.CreateCall(
931       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num), Ident,
932       "omp_global_thread_num");
933 }
934 
935 OpenMPIRBuilder::InsertPointTy
936 OpenMPIRBuilder::createBarrier(const LocationDescription &Loc, Directive Kind,
937                                bool ForceSimpleCall, bool CheckCancelFlag) {
938   if (!updateToLocation(Loc))
939     return Loc.IP;
940 
941   // Build call __kmpc_cancel_barrier(loc, thread_id) or
942   //            __kmpc_barrier(loc, thread_id);
943 
944   IdentFlag BarrierLocFlags;
945   switch (Kind) {
946   case OMPD_for:
947     BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_FOR;
948     break;
949   case OMPD_sections:
950     BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SECTIONS;
951     break;
952   case OMPD_single:
953     BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SINGLE;
954     break;
955   case OMPD_barrier:
956     BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_EXPL;
957     break;
958   default:
959     BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL;
960     break;
961   }
962 
963   uint32_t SrcLocStrSize;
964   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
965   Value *Args[] = {
966       getOrCreateIdent(SrcLocStr, SrcLocStrSize, BarrierLocFlags),
967       getOrCreateThreadID(getOrCreateIdent(SrcLocStr, SrcLocStrSize))};
968 
969   // If we are in a cancellable parallel region, barriers are cancellation
970   // points.
971   // TODO: Check why we would force simple calls or to ignore the cancel flag.
972   bool UseCancelBarrier =
973       !ForceSimpleCall && isLastFinalizationInfoCancellable(OMPD_parallel);
974 
975   Value *Result =
976       Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
977                              UseCancelBarrier ? OMPRTL___kmpc_cancel_barrier
978                                               : OMPRTL___kmpc_barrier),
979                          Args);
980 
981   if (UseCancelBarrier && CheckCancelFlag)
982     emitCancelationCheckImpl(Result, OMPD_parallel);
983 
984   return Builder.saveIP();
985 }
986 
/// Emit an OpenMP `cancel` for \p CanceledDirective at \p Loc.
///
/// Calls __kmpc_cancel(ident, thread_id, kind) and branches on its result via
/// emitCancelationCheckImpl. If \p IfCondition is given, the cancel call is
/// only executed on the "then" path of a conditional split. For a cancelled
/// parallel region, a barrier is emitted on the cancellation path before
/// finalization. Returns the insertion point where code generation continues.
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createCancel(const LocationDescription &Loc,
                              Value *IfCondition,
                              omp::Directive CanceledDirective) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // LLVM utilities like blocks with terminators.
  // UI is a temporary terminator; it is erased before returning.
  auto *UI = Builder.CreateUnreachable();

  // Without an if-clause the cancel code goes right before UI. With one, the
  // block is split so the cancel code is emitted only on the "then" edge; UI
  // ends up in the join block after the split.
  Instruction *ThenTI = UI, *ElseTI = nullptr;
  if (IfCondition)
    SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI);
  Builder.SetInsertPoint(ThenTI);

  // Map the directive to the runtime's cancel-kind constant; the cases are
  // generated from OMP_CANCEL_KIND entries in OMPKinds.def.
  Value *CancelKind = nullptr;
  switch (CanceledDirective) {
#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value)                       \
  case DirectiveEnum:                                                          \
    CancelKind = Builder.getInt32(Value);                                      \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    llvm_unreachable("Unknown cancel kind!");
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
  Value *Result = Builder.CreateCall(
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_cancel), Args);
  // Extra work on the cancellation path: for a cancelled parallel region,
  // emit a plain barrier (no nested cancel check) before finalization runs.
  auto ExitCB = [this, CanceledDirective, Loc](InsertPointTy IP) {
    if (CanceledDirective == OMPD_parallel) {
      IRBuilder<>::InsertPointGuard IPG(Builder);
      Builder.restoreIP(IP);
      createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                    omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
                    /* CheckCancelFlag */ false);
    }
  };

  // The actual cancel logic is shared with others, e.g., cancel_barriers.
  emitCancelationCheckImpl(Result, CanceledDirective, ExitCB);

  // Update the insertion point and remove the terminator we introduced.
  // Code generation continues at the end of the block that held UI.
  Builder.SetInsertPoint(UI->getParent());
  UI->eraseFromParent();

  return Builder.saveIP();
}
1038 
1039 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
1040     const LocationDescription &Loc, InsertPointTy AllocaIP, Value *&Return,
1041     Value *Ident, Value *DeviceID, Value *NumTeams, Value *NumThreads,
1042     Value *HostPtr, ArrayRef<Value *> KernelArgs) {
1043   if (!updateToLocation(Loc))
1044     return Loc.IP;
1045 
1046   Builder.restoreIP(AllocaIP);
1047   auto *KernelArgsPtr =
1048       Builder.CreateAlloca(OpenMPIRBuilder::KernelArgs, nullptr, "kernel_args");
1049   Builder.restoreIP(Loc.IP);
1050 
1051   for (unsigned I = 0, Size = KernelArgs.size(); I != Size; ++I) {
1052     llvm::Value *Arg =
1053         Builder.CreateStructGEP(OpenMPIRBuilder::KernelArgs, KernelArgsPtr, I);
1054     Builder.CreateAlignedStore(
1055         KernelArgs[I], Arg,
1056         M.getDataLayout().getPrefTypeAlign(KernelArgs[I]->getType()));
1057   }
1058 
1059   SmallVector<Value *> OffloadingArgs{Ident,      DeviceID, NumTeams,
1060                                       NumThreads, HostPtr,  KernelArgsPtr};
1061 
1062   Return = Builder.CreateCall(
1063       getOrCreateRuntimeFunction(M, OMPRTL___tgt_target_kernel),
1064       OffloadingArgs);
1065 
1066   return Builder.saveIP();
1067 }
1068 
1069 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(
1070     const LocationDescription &Loc, Function *OutlinedFn, Value *OutlinedFnID,
1071     EmitFallbackCallbackTy emitTargetCallFallbackCB, TargetKernelArgs &Args,
1072     Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP) {
1073 
1074   if (!updateToLocation(Loc))
1075     return Loc.IP;
1076 
1077   Builder.restoreIP(Loc.IP);
1078   // On top of the arrays that were filled up, the target offloading call
1079   // takes as arguments the device id as well as the host pointer. The host
1080   // pointer is used by the runtime library to identify the current target
1081   // region, so it only has to be unique and not necessarily point to
1082   // anything. It could be the pointer to the outlined function that
1083   // implements the target region, but we aren't using that so that the
1084   // compiler doesn't need to keep that, and could therefore inline the host
1085   // function if proven worthwhile during optimization.
1086 
1087   // From this point on, we need to have an ID of the target region defined.
1088   assert(OutlinedFnID && "Invalid outlined function ID!");
1089   (void)OutlinedFnID;
1090 
1091   // Return value of the runtime offloading call.
1092   Value *Return = nullptr;
1093 
1094   // Arguments for the target kernel.
1095   SmallVector<Value *> ArgsVector;
1096   getKernelArgsVector(Args, Builder, ArgsVector);
1097 
1098   // The target region is an outlined function launched by the runtime
1099   // via calls to __tgt_target_kernel().
1100   //
1101   // Note that on the host and CPU targets, the runtime implementation of
1102   // these calls simply call the outlined function without forking threads.
1103   // The outlined functions themselves have runtime calls to
1104   // __kmpc_fork_teams() and __kmpc_fork() for this purpose, codegen'd by
1105   // the compiler in emitTeamsCall() and emitParallelCall().
1106   //
1107   // In contrast, on the NVPTX target, the implementation of
1108   // __tgt_target_teams() launches a GPU kernel with the requested number
1109   // of teams and threads so no additional calls to the runtime are required.
1110   // Check the error code and execute the host version if required.
1111   Builder.restoreIP(emitTargetKernel(Builder, AllocaIP, Return, RTLoc, DeviceID,
1112                                      Args.NumTeams, Args.NumThreads,
1113                                      OutlinedFnID, ArgsVector));
1114 
1115   BasicBlock *OffloadFailedBlock =
1116       BasicBlock::Create(Builder.getContext(), "omp_offload.failed");
1117   BasicBlock *OffloadContBlock =
1118       BasicBlock::Create(Builder.getContext(), "omp_offload.cont");
1119   Value *Failed = Builder.CreateIsNotNull(Return);
1120   Builder.CreateCondBr(Failed, OffloadFailedBlock, OffloadContBlock);
1121 
1122   auto CurFn = Builder.GetInsertBlock()->getParent();
1123   emitBlock(OffloadFailedBlock, CurFn);
1124   Builder.restoreIP(emitTargetCallFallbackCB(Builder.saveIP()));
1125   emitBranch(OffloadContBlock);
1126   emitBlock(OffloadContBlock, CurFn, /*IsFinished=*/true);
1127   return Builder.saveIP();
1128 }
1129 
1130 void OpenMPIRBuilder::emitCancelationCheckImpl(Value *CancelFlag,
1131                                                omp::Directive CanceledDirective,
1132                                                FinalizeCallbackTy ExitCB) {
1133   assert(isLastFinalizationInfoCancellable(CanceledDirective) &&
1134          "Unexpected cancellation!");
1135 
1136   // For a cancel barrier we create two new blocks.
1137   BasicBlock *BB = Builder.GetInsertBlock();
1138   BasicBlock *NonCancellationBlock;
1139   if (Builder.GetInsertPoint() == BB->end()) {
1140     // TODO: This branch will not be needed once we moved to the
1141     // OpenMPIRBuilder codegen completely.
1142     NonCancellationBlock = BasicBlock::Create(
1143         BB->getContext(), BB->getName() + ".cont", BB->getParent());
1144   } else {
1145     NonCancellationBlock = SplitBlock(BB, &*Builder.GetInsertPoint());
1146     BB->getTerminator()->eraseFromParent();
1147     Builder.SetInsertPoint(BB);
1148   }
1149   BasicBlock *CancellationBlock = BasicBlock::Create(
1150       BB->getContext(), BB->getName() + ".cncl", BB->getParent());
1151 
1152   // Jump to them based on the return value.
1153   Value *Cmp = Builder.CreateIsNull(CancelFlag);
1154   Builder.CreateCondBr(Cmp, NonCancellationBlock, CancellationBlock,
1155                        /* TODO weight */ nullptr, nullptr);
1156 
1157   // From the cancellation block we finalize all variables and go to the
1158   // post finalization block that is known to the FiniCB callback.
1159   Builder.SetInsertPoint(CancellationBlock);
1160   if (ExitCB)
1161     ExitCB(Builder.saveIP());
1162   auto &FI = FinalizationStack.back();
1163   FI.FiniCB(Builder.saveIP());
1164 
1165   // The continuation block is where code generation continues.
1166   Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
1167 }
1168 
1169 // Callback used to create OpenMP runtime calls to support
1170 // omp parallel clause for the device.
1171 // We need to use this callback to replace call to the OutlinedFn in OuterFn
1172 // by the call to the OpenMP DeviceRTL runtime function (kmpc_parallel_51)
1173 static void targetParallelCallback(
1174     OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn, Function *OuterFn,
1175     BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
1176     Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
1177     Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
1178   // Add some known attributes.
1179   IRBuilder<> &Builder = OMPIRBuilder->Builder;
1180   OutlinedFn.addParamAttr(0, Attribute::NoAlias);
1181   OutlinedFn.addParamAttr(1, Attribute::NoAlias);
1182   OutlinedFn.addParamAttr(0, Attribute::NoUndef);
1183   OutlinedFn.addParamAttr(1, Attribute::NoUndef);
1184   OutlinedFn.addFnAttr(Attribute::NoUnwind);
1185 
1186   assert(OutlinedFn.arg_size() >= 2 &&
1187          "Expected at least tid and bounded tid as arguments");
1188   unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1189 
1190   CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
1191   assert(CI && "Expected call instruction to outlined function");
1192   CI->getParent()->setName("omp_parallel");
1193 
1194   Builder.SetInsertPoint(CI);
1195   Type *PtrTy = OMPIRBuilder->VoidPtr;
1196   Value *NullPtrValue = Constant::getNullValue(PtrTy);
1197 
1198   // Add alloca for kernel args
1199   OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP();
1200   Builder.SetInsertPoint(OuterAllocaBB, OuterAllocaBB->getFirstInsertionPt());
1201   AllocaInst *ArgsAlloca =
1202       Builder.CreateAlloca(ArrayType::get(PtrTy, NumCapturedVars));
1203   Value *Args = ArgsAlloca;
1204   // Add address space cast if array for storing arguments is not allocated
1205   // in address space 0
1206   if (ArgsAlloca->getAddressSpace())
1207     Args = Builder.CreatePointerCast(ArgsAlloca, PtrTy);
1208   Builder.restoreIP(CurrentIP);
1209 
1210   // Store captured vars which are used by kmpc_parallel_51
1211   for (unsigned Idx = 0; Idx < NumCapturedVars; Idx++) {
1212     Value *V = *(CI->arg_begin() + 2 + Idx);
1213     Value *StoreAddress = Builder.CreateConstInBoundsGEP2_64(
1214         ArrayType::get(PtrTy, NumCapturedVars), Args, 0, Idx);
1215     Builder.CreateStore(V, StoreAddress);
1216   }
1217 
1218   Value *Cond =
1219       IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32)
1220                   : Builder.getInt32(1);
1221 
1222   // Build kmpc_parallel_51 call
1223   Value *Parallel51CallArgs[] = {
1224       /* identifier*/ Ident,
1225       /* global thread num*/ ThreadID,
1226       /* if expression */ Cond,
1227       /* number of threads */ NumThreads ? NumThreads : Builder.getInt32(-1),
1228       /* Proc bind */ Builder.getInt32(-1),
1229       /* outlined function */
1230       Builder.CreateBitCast(&OutlinedFn, OMPIRBuilder->ParallelTaskPtr),
1231       /* wrapper function */ NullPtrValue,
1232       /* arguments of the outlined funciton*/ Args,
1233       /* number of arguments */ Builder.getInt64(NumCapturedVars)};
1234 
1235   FunctionCallee RTLFn =
1236       OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_parallel_51);
1237 
1238   Builder.CreateCall(RTLFn, Parallel51CallArgs);
1239 
1240   LLVM_DEBUG(dbgs() << "With kmpc_parallel_51 placed: "
1241                     << *Builder.GetInsertBlock()->getParent() << "\n");
1242 
1243   // Initialize the local TID stack location with the argument value.
1244   Builder.SetInsertPoint(PrivTID);
1245   Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
1246   Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
1247                       PrivTIDAddr);
1248 
1249   // Remove redundant call to the outlined function.
1250   CI->eraseFromParent();
1251 
1252   for (Instruction *I : ToBeDeleted) {
1253     I->eraseFromParent();
1254   }
1255 }
1256 
// Callback used to create OpenMP runtime calls to support
// omp parallel clause for the host.
// We need to use this callback to replace call to the OutlinedFn in OuterFn
// by the call to the OpenMP host runtime function ( __kmpc_fork_call[_if])
//
// The single call to OutlinedFn (its only user) is replaced by
//   __kmpc_fork_call(Ident, n, OutlinedFn, captured...)            or
//   __kmpc_fork_call_if(Ident, n, OutlinedFn, cond, captured...)
// depending on whether IfCondition is given. The runtime declaration gets
// !callback metadata once. At PrivTID, a store of the value loaded from
// OutlinedFn's first argument into PrivTIDAddr is inserted. Finally the
// original call and every instruction in ToBeDeleted are erased. OuterFn is
// not used here.
static void
hostParallelCallback(OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn,
                     Function *OuterFn, Value *Ident, Value *IfCondition,
                     Instruction *PrivTID, AllocaInst *PrivTIDAddr,
                     const SmallVector<Instruction *, 4> &ToBeDeleted) {
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  // Pick the fork entry point: the _if variant takes an extra i32 condition.
  FunctionCallee RTLFn;
  if (IfCondition) {
    RTLFn =
        OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call_if);
  } else {
    RTLFn =
        OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call);
  }
  // Attach the callback metadata only once per runtime declaration.
  if (auto *F = dyn_cast<Function>(RTLFn.getCallee())) {
    if (!F->hasMetadata(LLVMContext::MD_callback)) {
      LLVMContext &Ctx = F->getContext();
      MDBuilder MDB(Ctx);
      // Annotate the callback behavior of the __kmpc_fork_call:
      //  - The callback callee is argument number 2 (microtask).
      //  - The first two arguments of the callback callee are unknown (-1).
      //  - All variadic arguments to the __kmpc_fork_call are passed to the
      //    callback callee.
      F->addMetadata(LLVMContext::MD_callback,
                     *MDNode::get(Ctx, {MDB.createCallbackEncoding(
                                           2, {-1, -1},
                                           /* VarArgsArePassed */ true)}));
    }
  }
  // Add some known attributes.
  OutlinedFn.addParamAttr(0, Attribute::NoAlias);
  OutlinedFn.addParamAttr(1, Attribute::NoAlias);
  OutlinedFn.addFnAttr(Attribute::NoUnwind);

  assert(OutlinedFn.arg_size() >= 2 &&
         "Expected at least tid and bounded tid as arguments");
  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;

  // The fork call is inserted right before the existing call to OutlinedFn.
  CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
  CI->getParent()->setName("omp_parallel");
  Builder.SetInsertPoint(CI);

  // Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
  Value *ForkCallArgs[] = {
      Ident, Builder.getInt32(NumCapturedVars),
      Builder.CreateBitCast(&OutlinedFn, OMPIRBuilder->ParallelTaskPtr)};

  SmallVector<Value *, 16> RealArgs;
  RealArgs.append(std::begin(ForkCallArgs), std::end(ForkCallArgs));
  if (IfCondition) {
    Value *Cond = Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32);
    RealArgs.push_back(Cond);
  }
  // Forward the captured values (every call operand after tid & bound tid).
  // NOTE(review): with IfCondition set, all captured values are appended even
  // though only the last one is coerced to a pointer below; presumably the
  // caller guarantees at most one captured value in that case — confirm.
  RealArgs.append(CI->arg_begin() + /* tid & bound tid */ 2, CI->arg_end());

  // __kmpc_fork_call_if always expects a void ptr as the last argument
  // If there are no arguments, pass a null pointer.
  auto PtrTy = OMPIRBuilder->VoidPtr;
  if (IfCondition && NumCapturedVars == 0) {
    Value *NullPtrValue = Constant::getNullValue(PtrTy);
    RealArgs.push_back(NullPtrValue);
  }
  if (IfCondition && RealArgs.back()->getType() != PtrTy)
    RealArgs.back() = Builder.CreateBitCast(RealArgs.back(), PtrTy);

  Builder.CreateCall(RTLFn, RealArgs);

  LLVM_DEBUG(dbgs() << "With fork_call placed: "
                    << *Builder.GetInsertBlock()->getParent() << "\n");

  // Initialize the local TID stack location with the argument value.
  Builder.SetInsertPoint(PrivTID);
  Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
  Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
                      PrivTIDAddr);

  // Remove redundant call to the outlined function.
  CI->eraseFromParent();

  // Drop the modeling-only instructions collected by the caller.
  for (Instruction *I : ToBeDeleted) {
    I->eraseFromParent();
  }
}
1344 
1345 IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
1346     const LocationDescription &Loc, InsertPointTy OuterAllocaIP,
1347     BodyGenCallbackTy BodyGenCB, PrivatizeCallbackTy PrivCB,
1348     FinalizeCallbackTy FiniCB, Value *IfCondition, Value *NumThreads,
1349     omp::ProcBindKind ProcBind, bool IsCancellable) {
1350   assert(!isConflictIP(Loc.IP, OuterAllocaIP) && "IPs must not be ambiguous");
1351 
1352   if (!updateToLocation(Loc))
1353     return Loc.IP;
1354 
1355   uint32_t SrcLocStrSize;
1356   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1357   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1358   Value *ThreadID = getOrCreateThreadID(Ident);
1359   // If we generate code for the target device, we need to allocate
1360   // struct for aggregate params in the device default alloca address space.
1361   // OpenMP runtime requires that the params of the extracted functions are
1362   // passed as zero address space pointers. This flag ensures that extracted
1363   // function arguments are declared in zero address space
1364   bool ArgsInZeroAddressSpace = Config.isTargetDevice();
1365 
1366   // Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
1367   // only if we compile for host side.
1368   if (NumThreads && !Config.isTargetDevice()) {
1369     Value *Args[] = {
1370         Ident, ThreadID,
1371         Builder.CreateIntCast(NumThreads, Int32, /*isSigned*/ false)};
1372     Builder.CreateCall(
1373         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_threads), Args);
1374   }
1375 
1376   if (ProcBind != OMP_PROC_BIND_default) {
1377     // Build call __kmpc_push_proc_bind(&Ident, global_tid, proc_bind)
1378     Value *Args[] = {
1379         Ident, ThreadID,
1380         ConstantInt::get(Int32, unsigned(ProcBind), /*isSigned=*/true)};
1381     Builder.CreateCall(
1382         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_proc_bind), Args);
1383   }
1384 
1385   BasicBlock *InsertBB = Builder.GetInsertBlock();
1386   Function *OuterFn = InsertBB->getParent();
1387 
1388   // Save the outer alloca block because the insertion iterator may get
1389   // invalidated and we still need this later.
1390   BasicBlock *OuterAllocaBlock = OuterAllocaIP.getBlock();
1391 
1392   // Vector to remember instructions we used only during the modeling but which
1393   // we want to delete at the end.
1394   SmallVector<Instruction *, 4> ToBeDeleted;
1395 
1396   // Change the location to the outer alloca insertion point to create and
1397   // initialize the allocas we pass into the parallel region.
1398   InsertPointTy NewOuter(OuterAllocaBlock, OuterAllocaBlock->begin());
1399   Builder.restoreIP(NewOuter);
1400   AllocaInst *TIDAddrAlloca = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
1401   AllocaInst *ZeroAddrAlloca =
1402       Builder.CreateAlloca(Int32, nullptr, "zero.addr");
1403   Instruction *TIDAddr = TIDAddrAlloca;
1404   Instruction *ZeroAddr = ZeroAddrAlloca;
1405   if (ArgsInZeroAddressSpace && M.getDataLayout().getAllocaAddrSpace() != 0) {
1406     // Add additional casts to enforce pointers in zero address space
1407     TIDAddr = new AddrSpaceCastInst(
1408         TIDAddrAlloca, PointerType ::get(M.getContext(), 0), "tid.addr.ascast");
1409     TIDAddr->insertAfter(TIDAddrAlloca);
1410     ToBeDeleted.push_back(TIDAddr);
1411     ZeroAddr = new AddrSpaceCastInst(ZeroAddrAlloca,
1412                                      PointerType ::get(M.getContext(), 0),
1413                                      "zero.addr.ascast");
1414     ZeroAddr->insertAfter(ZeroAddrAlloca);
1415     ToBeDeleted.push_back(ZeroAddr);
1416   }
1417 
1418   // We only need TIDAddr and ZeroAddr for modeling purposes to get the
1419   // associated arguments in the outlined function, so we delete them later.
1420   ToBeDeleted.push_back(TIDAddrAlloca);
1421   ToBeDeleted.push_back(ZeroAddrAlloca);
1422 
1423   // Create an artificial insertion point that will also ensure the blocks we
1424   // are about to split are not degenerated.
1425   auto *UI = new UnreachableInst(Builder.getContext(), InsertBB);
1426 
1427   BasicBlock *EntryBB = UI->getParent();
1428   BasicBlock *PRegEntryBB = EntryBB->splitBasicBlock(UI, "omp.par.entry");
1429   BasicBlock *PRegBodyBB = PRegEntryBB->splitBasicBlock(UI, "omp.par.region");
1430   BasicBlock *PRegPreFiniBB =
1431       PRegBodyBB->splitBasicBlock(UI, "omp.par.pre_finalize");
1432   BasicBlock *PRegExitBB = PRegPreFiniBB->splitBasicBlock(UI, "omp.par.exit");
1433 
1434   auto FiniCBWrapper = [&](InsertPointTy IP) {
1435     // Hide "open-ended" blocks from the given FiniCB by setting the right jump
1436     // target to the region exit block.
1437     if (IP.getBlock()->end() == IP.getPoint()) {
1438       IRBuilder<>::InsertPointGuard IPG(Builder);
1439       Builder.restoreIP(IP);
1440       Instruction *I = Builder.CreateBr(PRegExitBB);
1441       IP = InsertPointTy(I->getParent(), I->getIterator());
1442     }
1443     assert(IP.getBlock()->getTerminator()->getNumSuccessors() == 1 &&
1444            IP.getBlock()->getTerminator()->getSuccessor(0) == PRegExitBB &&
1445            "Unexpected insertion point for finalization call!");
1446     return FiniCB(IP);
1447   };
1448 
1449   FinalizationStack.push_back({FiniCBWrapper, OMPD_parallel, IsCancellable});
1450 
1451   // Generate the privatization allocas in the block that will become the entry
1452   // of the outlined function.
1453   Builder.SetInsertPoint(PRegEntryBB->getTerminator());
1454   InsertPointTy InnerAllocaIP = Builder.saveIP();
1455 
1456   AllocaInst *PrivTIDAddr =
1457       Builder.CreateAlloca(Int32, nullptr, "tid.addr.local");
1458   Instruction *PrivTID = Builder.CreateLoad(Int32, PrivTIDAddr, "tid");
1459 
1460   // Add some fake uses for OpenMP provided arguments.
1461   ToBeDeleted.push_back(Builder.CreateLoad(Int32, TIDAddr, "tid.addr.use"));
1462   Instruction *ZeroAddrUse =
1463       Builder.CreateLoad(Int32, ZeroAddr, "zero.addr.use");
1464   ToBeDeleted.push_back(ZeroAddrUse);
1465 
1466   // EntryBB
1467   //   |
1468   //   V
1469   // PRegionEntryBB         <- Privatization allocas are placed here.
1470   //   |
1471   //   V
1472   // PRegionBodyBB          <- BodeGen is invoked here.
1473   //   |
1474   //   V
1475   // PRegPreFiniBB          <- The block we will start finalization from.
1476   //   |
1477   //   V
1478   // PRegionExitBB          <- A common exit to simplify block collection.
1479   //
1480 
1481   LLVM_DEBUG(dbgs() << "Before body codegen: " << *OuterFn << "\n");
1482 
1483   // Let the caller create the body.
1484   assert(BodyGenCB && "Expected body generation callback!");
1485   InsertPointTy CodeGenIP(PRegBodyBB, PRegBodyBB->begin());
1486   BodyGenCB(InnerAllocaIP, CodeGenIP);
1487 
1488   LLVM_DEBUG(dbgs() << "After  body codegen: " << *OuterFn << "\n");
1489 
1490   OutlineInfo OI;
1491   if (Config.isTargetDevice()) {
1492     // Generate OpenMP target specific runtime call
1493     OI.PostOutlineCB = [=, ToBeDeletedVec =
1494                                std::move(ToBeDeleted)](Function &OutlinedFn) {
1495       targetParallelCallback(this, OutlinedFn, OuterFn, OuterAllocaBlock, Ident,
1496                              IfCondition, NumThreads, PrivTID, PrivTIDAddr,
1497                              ThreadID, ToBeDeletedVec);
1498     };
1499   } else {
1500     // Generate OpenMP host runtime call
1501     OI.PostOutlineCB = [=, ToBeDeletedVec =
1502                                std::move(ToBeDeleted)](Function &OutlinedFn) {
1503       hostParallelCallback(this, OutlinedFn, OuterFn, Ident, IfCondition,
1504                            PrivTID, PrivTIDAddr, ToBeDeletedVec);
1505     };
1506   }
1507 
1508   OI.OuterAllocaBB = OuterAllocaBlock;
1509   OI.EntryBB = PRegEntryBB;
1510   OI.ExitBB = PRegExitBB;
1511 
1512   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
1513   SmallVector<BasicBlock *, 32> Blocks;
1514   OI.collectBlocks(ParallelRegionBlockSet, Blocks);
1515 
1516   // Ensure a single exit node for the outlined region by creating one.
1517   // We might have multiple incoming edges to the exit now due to finalizations,
1518   // e.g., cancel calls that cause the control flow to leave the region.
1519   BasicBlock *PRegOutlinedExitBB = PRegExitBB;
1520   PRegExitBB = SplitBlock(PRegExitBB, &*PRegExitBB->getFirstInsertionPt());
1521   PRegOutlinedExitBB->setName("omp.par.outlined.exit");
1522   Blocks.push_back(PRegOutlinedExitBB);
1523 
1524   CodeExtractorAnalysisCache CEAC(*OuterFn);
1525   CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
1526                           /* AggregateArgs */ false,
1527                           /* BlockFrequencyInfo */ nullptr,
1528                           /* BranchProbabilityInfo */ nullptr,
1529                           /* AssumptionCache */ nullptr,
1530                           /* AllowVarArgs */ true,
1531                           /* AllowAlloca */ true,
1532                           /* AllocationBlock */ OuterAllocaBlock,
1533                           /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
1534 
1535   // Find inputs to, outputs from the code region.
1536   BasicBlock *CommonExit = nullptr;
1537   SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
1538   Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
1539   Extractor.findInputsOutputs(Inputs, Outputs, SinkingCands);
1540 
1541   LLVM_DEBUG(dbgs() << "Before privatization: " << *OuterFn << "\n");
1542 
1543   FunctionCallee TIDRTLFn =
1544       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num);
1545 
1546   auto PrivHelper = [&](Value &V) {
1547     if (&V == TIDAddr || &V == ZeroAddr) {
1548       OI.ExcludeArgsFromAggregate.push_back(&V);
1549       return;
1550     }
1551 
1552     SetVector<Use *> Uses;
1553     for (Use &U : V.uses())
1554       if (auto *UserI = dyn_cast<Instruction>(U.getUser()))
1555         if (ParallelRegionBlockSet.count(UserI->getParent()))
1556           Uses.insert(&U);
1557 
1558     // __kmpc_fork_call expects extra arguments as pointers. If the input
1559     // already has a pointer type, everything is fine. Otherwise, store the
1560     // value onto stack and load it back inside the to-be-outlined region. This
1561     // will ensure only the pointer will be passed to the function.
1562     // FIXME: if there are more than 15 trailing arguments, they must be
1563     // additionally packed in a struct.
1564     Value *Inner = &V;
1565     if (!V.getType()->isPointerTy()) {
1566       IRBuilder<>::InsertPointGuard Guard(Builder);
1567       LLVM_DEBUG(llvm::dbgs() << "Forwarding input as pointer: " << V << "\n");
1568 
1569       Builder.restoreIP(OuterAllocaIP);
1570       Value *Ptr =
1571           Builder.CreateAlloca(V.getType(), nullptr, V.getName() + ".reloaded");
1572 
1573       // Store to stack at end of the block that currently branches to the entry
1574       // block of the to-be-outlined region.
1575       Builder.SetInsertPoint(InsertBB,
1576                              InsertBB->getTerminator()->getIterator());
1577       Builder.CreateStore(&V, Ptr);
1578 
1579       // Load back next to allocations in the to-be-outlined region.
1580       Builder.restoreIP(InnerAllocaIP);
1581       Inner = Builder.CreateLoad(V.getType(), Ptr);
1582     }
1583 
1584     Value *ReplacementValue = nullptr;
1585     CallInst *CI = dyn_cast<CallInst>(&V);
1586     if (CI && CI->getCalledFunction() == TIDRTLFn.getCallee()) {
1587       ReplacementValue = PrivTID;
1588     } else {
1589       Builder.restoreIP(
1590           PrivCB(InnerAllocaIP, Builder.saveIP(), V, *Inner, ReplacementValue));
1591       InnerAllocaIP = {
1592           InnerAllocaIP.getBlock(),
1593           InnerAllocaIP.getBlock()->getTerminator()->getIterator()};
1594 
1595       assert(ReplacementValue &&
1596              "Expected copy/create callback to set replacement value!");
1597       if (ReplacementValue == &V)
1598         return;
1599     }
1600 
1601     for (Use *UPtr : Uses)
1602       UPtr->set(ReplacementValue);
1603   };
1604 
1605   // Reset the inner alloca insertion as it will be used for loading the values
1606   // wrapped into pointers before passing them into the to-be-outlined region.
1607   // Configure it to insert immediately after the fake use of zero address so
1608   // that they are available in the generated body and so that the
1609   // OpenMP-related values (thread ID and zero address pointers) remain leading
1610   // in the argument list.
1611   InnerAllocaIP = IRBuilder<>::InsertPoint(
1612       ZeroAddrUse->getParent(), ZeroAddrUse->getNextNode()->getIterator());
1613 
1614   // Reset the outer alloca insertion point to the entry of the relevant block
1615   // in case it was invalidated.
1616   OuterAllocaIP = IRBuilder<>::InsertPoint(
1617       OuterAllocaBlock, OuterAllocaBlock->getFirstInsertionPt());
1618 
1619   for (Value *Input : Inputs) {
1620     LLVM_DEBUG(dbgs() << "Captured input: " << *Input << "\n");
1621     PrivHelper(*Input);
1622   }
1623   LLVM_DEBUG({
1624     for (Value *Output : Outputs)
1625       LLVM_DEBUG(dbgs() << "Captured output: " << *Output << "\n");
1626   });
1627   assert(Outputs.empty() &&
1628          "OpenMP outlining should not produce live-out values!");
1629 
1630   LLVM_DEBUG(dbgs() << "After  privatization: " << *OuterFn << "\n");
1631   LLVM_DEBUG({
1632     for (auto *BB : Blocks)
1633       dbgs() << " PBR: " << BB->getName() << "\n";
1634   });
1635 
1636   // Adjust the finalization stack, verify the adjustment, and call the
1637   // finalize function a last time to finalize values between the pre-fini
1638   // block and the exit block if we left the parallel "the normal way".
1639   auto FiniInfo = FinalizationStack.pop_back_val();
1640   (void)FiniInfo;
1641   assert(FiniInfo.DK == OMPD_parallel &&
1642          "Unexpected finalization stack state!");
1643 
1644   Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator();
1645 
1646   InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
1647   FiniCB(PreFiniIP);
1648 
1649   // Register the outlined info.
1650   addOutlineInfo(std::move(OI));
1651 
1652   InsertPointTy AfterIP(UI->getParent(), UI->getParent()->end());
1653   UI->eraseFromParent();
1654 
1655   return AfterIP;
1656 }
1657 
1658 void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) {
1659   // Build call void __kmpc_flush(ident_t *loc)
1660   uint32_t SrcLocStrSize;
1661   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1662   Value *Args[] = {getOrCreateIdent(SrcLocStr, SrcLocStrSize)};
1663 
1664   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_flush), Args);
1665 }
1666 
1667 void OpenMPIRBuilder::createFlush(const LocationDescription &Loc) {
1668   if (!updateToLocation(Loc))
1669     return;
1670   emitFlush(Loc);
1671 }
1672 
1673 void OpenMPIRBuilder::emitTaskwaitImpl(const LocationDescription &Loc) {
1674   // Build call kmp_int32 __kmpc_omp_taskwait(ident_t *loc, kmp_int32
1675   // global_tid);
1676   uint32_t SrcLocStrSize;
1677   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1678   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1679   Value *Args[] = {Ident, getOrCreateThreadID(Ident)};
1680 
1681   // Ignore return result until untied tasks are supported.
1682   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_taskwait),
1683                      Args);
1684 }
1685 
1686 void OpenMPIRBuilder::createTaskwait(const LocationDescription &Loc) {
1687   if (!updateToLocation(Loc))
1688     return;
1689   emitTaskwaitImpl(Loc);
1690 }
1691 
1692 void OpenMPIRBuilder::emitTaskyieldImpl(const LocationDescription &Loc) {
1693   // Build call __kmpc_omp_taskyield(loc, thread_id, 0);
1694   uint32_t SrcLocStrSize;
1695   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1696   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1697   Constant *I32Null = ConstantInt::getNullValue(Int32);
1698   Value *Args[] = {Ident, getOrCreateThreadID(Ident), I32Null};
1699 
1700   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_taskyield),
1701                      Args);
1702 }
1703 
1704 void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
1705   if (!updateToLocation(Loc))
1706     return;
1707   emitTaskyieldImpl(Loc);
1708 }
1709 
1710 // Processes the dependencies in Dependencies and does the following
1711 // - Allocates space on the stack of an array of DependInfo objects
1712 // - Populates each DependInfo object with relevant information of
1713 //   the corresponding dependence.
1714 // - All code is inserted in the entry block of the current function.
1715 static Value *emitTaskDependencies(
1716     OpenMPIRBuilder &OMPBuilder,
1717     SmallVectorImpl<OpenMPIRBuilder::DependData> &Dependencies) {
1718   // Early return if we have no dependencies to process
1719   if (Dependencies.empty())
1720     return nullptr;
1721 
1722   // Given a vector of DependData objects, in this function we create an
1723   // array on the stack that holds kmp_dep_info objects corresponding
1724   // to each dependency. This is then passed to the OpenMP runtime.
1725   // For example, if there are 'n' dependencies then the following psedo
1726   // code is generated. Assume the first dependence is on a variable 'a'
1727   //
1728   // \code{c}
1729   // DepArray = alloc(n x sizeof(kmp_depend_info);
1730   // idx = 0;
1731   // DepArray[idx].base_addr = ptrtoint(&a);
1732   // DepArray[idx].len = 8;
1733   // DepArray[idx].flags = Dep.DepKind; /*(See OMPContants.h for DepKind)*/
1734   // ++idx;
1735   // DepArray[idx].base_addr = ...;
1736   // \endcode
1737 
1738   IRBuilderBase &Builder = OMPBuilder.Builder;
1739   Type *DependInfo = OMPBuilder.DependInfo;
1740   Module &M = OMPBuilder.M;
1741 
1742   Value *DepArray = nullptr;
1743   OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
1744   Builder.SetInsertPoint(
1745       OldIP.getBlock()->getParent()->getEntryBlock().getTerminator());
1746 
1747   Type *DepArrayTy = ArrayType::get(DependInfo, Dependencies.size());
1748   DepArray = Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");
1749 
1750   for (const auto &[DepIdx, Dep] : enumerate(Dependencies)) {
1751     Value *Base =
1752         Builder.CreateConstInBoundsGEP2_64(DepArrayTy, DepArray, 0, DepIdx);
1753     // Store the pointer to the variable
1754     Value *Addr = Builder.CreateStructGEP(
1755         DependInfo, Base,
1756         static_cast<unsigned int>(RTLDependInfoFields::BaseAddr));
1757     Value *DepValPtr = Builder.CreatePtrToInt(Dep.DepVal, Builder.getInt64Ty());
1758     Builder.CreateStore(DepValPtr, Addr);
1759     // Store the size of the variable
1760     Value *Size = Builder.CreateStructGEP(
1761         DependInfo, Base, static_cast<unsigned int>(RTLDependInfoFields::Len));
1762     Builder.CreateStore(
1763         Builder.getInt64(M.getDataLayout().getTypeStoreSize(Dep.DepValueType)),
1764         Size);
1765     // Store the dependency kind
1766     Value *Flags = Builder.CreateStructGEP(
1767         DependInfo, Base,
1768         static_cast<unsigned int>(RTLDependInfoFields::Flags));
1769     Builder.CreateStore(
1770         ConstantInt::get(Builder.getInt8Ty(),
1771                          static_cast<unsigned int>(Dep.DepKind)),
1772         Flags);
1773   }
1774   Builder.restoreIP(OldIP);
1775   return DepArray;
1776 }
1777 
1778 OpenMPIRBuilder::InsertPointTy
1779 OpenMPIRBuilder::createTask(const LocationDescription &Loc,
1780                             InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
1781                             bool Tied, Value *Final, Value *IfCondition,
1782                             SmallVector<DependData> Dependencies) {
1783 
1784   if (!updateToLocation(Loc))
1785     return InsertPointTy();
1786 
1787   uint32_t SrcLocStrSize;
1788   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1789   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1790   // The current basic block is split into four basic blocks. After outlining,
1791   // they will be mapped as follows:
1792   // ```
1793   // def current_fn() {
1794   //   current_basic_block:
1795   //     br label %task.exit
1796   //   task.exit:
1797   //     ; instructions after task
1798   // }
1799   // def outlined_fn() {
1800   //   task.alloca:
1801   //     br label %task.body
1802   //   task.body:
1803   //     ret void
1804   // }
1805   // ```
1806   BasicBlock *TaskExitBB = splitBB(Builder, /*CreateBranch=*/true, "task.exit");
1807   BasicBlock *TaskBodyBB = splitBB(Builder, /*CreateBranch=*/true, "task.body");
1808   BasicBlock *TaskAllocaBB =
1809       splitBB(Builder, /*CreateBranch=*/true, "task.alloca");
1810 
1811   InsertPointTy TaskAllocaIP =
1812       InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
1813   InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
1814   BodyGenCB(TaskAllocaIP, TaskBodyIP);
1815 
1816   OutlineInfo OI;
1817   OI.EntryBB = TaskAllocaBB;
1818   OI.OuterAllocaBB = AllocaIP.getBlock();
1819   OI.ExitBB = TaskExitBB;
1820 
1821   // Add the thread ID argument.
1822   SmallVector<Instruction *, 4> ToBeDeleted;
1823   OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
1824       Builder, AllocaIP, ToBeDeleted, TaskAllocaIP, "global.tid", false));
1825 
1826   OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
1827                       TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) mutable {
1828     // Replace the Stale CI by appropriate RTL function call.
1829     assert(OutlinedFn.getNumUses() == 1 &&
1830            "there must be a single user for the outlined function");
1831     CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
1832 
1833     // HasShareds is true if any variables are captured in the outlined region,
1834     // false otherwise.
1835     bool HasShareds = StaleCI->arg_size() > 1;
1836     Builder.SetInsertPoint(StaleCI);
1837 
1838     // Gather the arguments for emitting the runtime call for
1839     // @__kmpc_omp_task_alloc
1840     Function *TaskAllocFn =
1841         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);
1842 
1843     // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID)
1844     // call.
1845     Value *ThreadID = getOrCreateThreadID(Ident);
1846 
1847     // Argument - `flags`
1848     // Task is tied iff (Flags & 1) == 1.
1849     // Task is untied iff (Flags & 1) == 0.
1850     // Task is final iff (Flags & 2) == 2.
1851     // Task is not final iff (Flags & 2) == 0.
1852     // TODO: Handle the other flags.
1853     Value *Flags = Builder.getInt32(Tied);
1854     if (Final) {
1855       Value *FinalFlag =
1856           Builder.CreateSelect(Final, Builder.getInt32(2), Builder.getInt32(0));
1857       Flags = Builder.CreateOr(FinalFlag, Flags);
1858     }
1859 
1860     // Argument - `sizeof_kmp_task_t` (TaskSize)
1861     // Tasksize refers to the size in bytes of kmp_task_t data structure
1862     // including private vars accessed in task.
1863     // TODO: add kmp_task_t_with_privates (privates)
1864     Value *TaskSize = Builder.getInt64(
1865         divideCeil(M.getDataLayout().getTypeSizeInBits(Task), 8));
1866 
1867     // Argument - `sizeof_shareds` (SharedsSize)
1868     // SharedsSize refers to the shareds array size in the kmp_task_t data
1869     // structure.
1870     Value *SharedsSize = Builder.getInt64(0);
1871     if (HasShareds) {
1872       AllocaInst *ArgStructAlloca =
1873           dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
1874       assert(ArgStructAlloca &&
1875              "Unable to find the alloca instruction corresponding to arguments "
1876              "for extracted function");
1877       StructType *ArgStructType =
1878           dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
1879       assert(ArgStructType && "Unable to find struct type corresponding to "
1880                               "arguments for extracted function");
1881       SharedsSize =
1882           Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
1883     }
1884     // Emit the @__kmpc_omp_task_alloc runtime call
1885     // The runtime call returns a pointer to an area where the task captured
1886     // variables must be copied before the task is run (TaskData)
1887     CallInst *TaskData = Builder.CreateCall(
1888         TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
1889                       /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
1890                       /*task_func=*/&OutlinedFn});
1891 
1892     // Copy the arguments for outlined function
1893     if (HasShareds) {
1894       Value *Shareds = StaleCI->getArgOperand(1);
1895       Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
1896       Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData);
1897       Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
1898                            SharedsSize);
1899     }
1900 
1901     Value *DepArray = nullptr;
1902     if (Dependencies.size()) {
1903       InsertPointTy OldIP = Builder.saveIP();
1904       Builder.SetInsertPoint(
1905           &OldIP.getBlock()->getParent()->getEntryBlock().back());
1906 
1907       Type *DepArrayTy = ArrayType::get(DependInfo, Dependencies.size());
1908       DepArray = Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");
1909 
1910       unsigned P = 0;
1911       for (const DependData &Dep : Dependencies) {
1912         Value *Base =
1913             Builder.CreateConstInBoundsGEP2_64(DepArrayTy, DepArray, 0, P);
1914         // Store the pointer to the variable
1915         Value *Addr = Builder.CreateStructGEP(
1916             DependInfo, Base,
1917             static_cast<unsigned int>(RTLDependInfoFields::BaseAddr));
1918         Value *DepValPtr =
1919             Builder.CreatePtrToInt(Dep.DepVal, Builder.getInt64Ty());
1920         Builder.CreateStore(DepValPtr, Addr);
1921         // Store the size of the variable
1922         Value *Size = Builder.CreateStructGEP(
1923             DependInfo, Base,
1924             static_cast<unsigned int>(RTLDependInfoFields::Len));
1925         Builder.CreateStore(Builder.getInt64(M.getDataLayout().getTypeStoreSize(
1926                                 Dep.DepValueType)),
1927                             Size);
1928         // Store the dependency kind
1929         Value *Flags = Builder.CreateStructGEP(
1930             DependInfo, Base,
1931             static_cast<unsigned int>(RTLDependInfoFields::Flags));
1932         Builder.CreateStore(
1933             ConstantInt::get(Builder.getInt8Ty(),
1934                              static_cast<unsigned int>(Dep.DepKind)),
1935             Flags);
1936         ++P;
1937       }
1938 
1939       Builder.restoreIP(OldIP);
1940     }
1941 
1942     // In the presence of the `if` clause, the following IR is generated:
1943     //    ...
1944     //    %data = call @__kmpc_omp_task_alloc(...)
1945     //    br i1 %if_condition, label %then, label %else
1946     //  then:
1947     //    call @__kmpc_omp_task(...)
1948     //    br label %exit
1949     //  else:
1950     //    ;; Wait for resolution of dependencies, if any, before
1951     //    ;; beginning the task
1952     //    call @__kmpc_omp_wait_deps(...)
1953     //    call @__kmpc_omp_task_begin_if0(...)
1954     //    call @outlined_fn(...)
1955     //    call @__kmpc_omp_task_complete_if0(...)
1956     //    br label %exit
1957     //  exit:
1958     //    ...
1959     if (IfCondition) {
1960       // `SplitBlockAndInsertIfThenElse` requires the block to have a
1961       // terminator.
1962       splitBB(Builder, /*CreateBranch=*/true, "if.end");
1963       Instruction *IfTerminator =
1964           Builder.GetInsertPoint()->getParent()->getTerminator();
1965       Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
1966       Builder.SetInsertPoint(IfTerminator);
1967       SplitBlockAndInsertIfThenElse(IfCondition, IfTerminator, &ThenTI,
1968                                     &ElseTI);
1969       Builder.SetInsertPoint(ElseTI);
1970 
1971       if (Dependencies.size()) {
1972         Function *TaskWaitFn =
1973             getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_wait_deps);
1974         Builder.CreateCall(
1975             TaskWaitFn,
1976             {Ident, ThreadID, Builder.getInt32(Dependencies.size()), DepArray,
1977              ConstantInt::get(Builder.getInt32Ty(), 0),
1978              ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
1979       }
1980       Function *TaskBeginFn =
1981           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0);
1982       Function *TaskCompleteFn =
1983           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
1984       Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
1985       CallInst *CI = nullptr;
1986       if (HasShareds)
1987         CI = Builder.CreateCall(&OutlinedFn, {ThreadID, TaskData});
1988       else
1989         CI = Builder.CreateCall(&OutlinedFn, {ThreadID});
1990       CI->setDebugLoc(StaleCI->getDebugLoc());
1991       Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
1992       Builder.SetInsertPoint(ThenTI);
1993     }
1994 
1995     if (Dependencies.size()) {
1996       Function *TaskFn =
1997           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
1998       Builder.CreateCall(
1999           TaskFn,
2000           {Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()),
2001            DepArray, ConstantInt::get(Builder.getInt32Ty(), 0),
2002            ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
2003 
2004     } else {
2005       // Emit the @__kmpc_omp_task runtime call to spawn the task
2006       Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
2007       Builder.CreateCall(TaskFn, {Ident, ThreadID, TaskData});
2008     }
2009 
2010     StaleCI->eraseFromParent();
2011 
2012     Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin());
2013     if (HasShareds) {
2014       LoadInst *Shareds = Builder.CreateLoad(VoidPtr, OutlinedFn.getArg(1));
2015       OutlinedFn.getArg(1)->replaceUsesWithIf(
2016           Shareds, [Shareds](Use &U) { return U.getUser() != Shareds; });
2017     }
2018 
2019     llvm::for_each(llvm::reverse(ToBeDeleted),
2020                    [](Instruction *I) { I->eraseFromParent(); });
2021   };
2022 
2023   addOutlineInfo(std::move(OI));
2024   Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin());
2025 
2026   return Builder.saveIP();
2027 }
2028 
2029 OpenMPIRBuilder::InsertPointTy
2030 OpenMPIRBuilder::createTaskgroup(const LocationDescription &Loc,
2031                                  InsertPointTy AllocaIP,
2032                                  BodyGenCallbackTy BodyGenCB) {
2033   if (!updateToLocation(Loc))
2034     return InsertPointTy();
2035 
2036   uint32_t SrcLocStrSize;
2037   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2038   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2039   Value *ThreadID = getOrCreateThreadID(Ident);
2040 
2041   // Emit the @__kmpc_taskgroup runtime call to start the taskgroup
2042   Function *TaskgroupFn =
2043       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_taskgroup);
2044   Builder.CreateCall(TaskgroupFn, {Ident, ThreadID});
2045 
2046   BasicBlock *TaskgroupExitBB = splitBB(Builder, true, "taskgroup.exit");
2047   BodyGenCB(AllocaIP, Builder.saveIP());
2048 
2049   Builder.SetInsertPoint(TaskgroupExitBB);
2050   // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup
2051   Function *EndTaskgroupFn =
2052       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_taskgroup);
2053   Builder.CreateCall(EndTaskgroupFn, {Ident, ThreadID});
2054 
2055   return Builder.saveIP();
2056 }
2057 
2058 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSections(
2059     const LocationDescription &Loc, InsertPointTy AllocaIP,
2060     ArrayRef<StorableBodyGenCallbackTy> SectionCBs, PrivatizeCallbackTy PrivCB,
2061     FinalizeCallbackTy FiniCB, bool IsCancellable, bool IsNowait) {
2062   assert(!isConflictIP(AllocaIP, Loc.IP) && "Dedicated IP allocas required");
2063 
2064   if (!updateToLocation(Loc))
2065     return Loc.IP;
2066 
2067   auto FiniCBWrapper = [&](InsertPointTy IP) {
2068     if (IP.getBlock()->end() != IP.getPoint())
2069       return FiniCB(IP);
2070     // This must be done otherwise any nested constructs using FinalizeOMPRegion
2071     // will fail because that function requires the Finalization Basic Block to
2072     // have a terminator, which is already removed by EmitOMPRegionBody.
2073     // IP is currently at cancelation block.
2074     // We need to backtrack to the condition block to fetch
2075     // the exit block and create a branch from cancelation
2076     // to exit block.
2077     IRBuilder<>::InsertPointGuard IPG(Builder);
2078     Builder.restoreIP(IP);
2079     auto *CaseBB = IP.getBlock()->getSinglePredecessor();
2080     auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
2081     auto *ExitBB = CondBB->getTerminator()->getSuccessor(1);
2082     Instruction *I = Builder.CreateBr(ExitBB);
2083     IP = InsertPointTy(I->getParent(), I->getIterator());
2084     return FiniCB(IP);
2085   };
2086 
2087   FinalizationStack.push_back({FiniCBWrapper, OMPD_sections, IsCancellable});
2088 
2089   // Each section is emitted as a switch case
2090   // Each finalization callback is handled from clang.EmitOMPSectionDirective()
2091   // -> OMP.createSection() which generates the IR for each section
2092   // Iterate through all sections and emit a switch construct:
2093   // switch (IV) {
2094   //   case 0:
2095   //     <SectionStmt[0]>;
2096   //     break;
2097   // ...
2098   //   case <NumSection> - 1:
2099   //     <SectionStmt[<NumSection> - 1]>;
2100   //     break;
2101   // }
2102   // ...
2103   // section_loop.after:
2104   // <FiniCB>;
2105   auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, Value *IndVar) {
2106     Builder.restoreIP(CodeGenIP);
2107     BasicBlock *Continue =
2108         splitBBWithSuffix(Builder, /*CreateBranch=*/false, ".sections.after");
2109     Function *CurFn = Continue->getParent();
2110     SwitchInst *SwitchStmt = Builder.CreateSwitch(IndVar, Continue);
2111 
2112     unsigned CaseNumber = 0;
2113     for (auto SectionCB : SectionCBs) {
2114       BasicBlock *CaseBB = BasicBlock::Create(
2115           M.getContext(), "omp_section_loop.body.case", CurFn, Continue);
2116       SwitchStmt->addCase(Builder.getInt32(CaseNumber), CaseBB);
2117       Builder.SetInsertPoint(CaseBB);
2118       BranchInst *CaseEndBr = Builder.CreateBr(Continue);
2119       SectionCB(InsertPointTy(),
2120                 {CaseEndBr->getParent(), CaseEndBr->getIterator()});
2121       CaseNumber++;
2122     }
2123     // remove the existing terminator from body BB since there can be no
2124     // terminators after switch/case
2125   };
2126   // Loop body ends here
2127   // LowerBound, UpperBound, and STride for createCanonicalLoop
2128   Type *I32Ty = Type::getInt32Ty(M.getContext());
2129   Value *LB = ConstantInt::get(I32Ty, 0);
2130   Value *UB = ConstantInt::get(I32Ty, SectionCBs.size());
2131   Value *ST = ConstantInt::get(I32Ty, 1);
2132   llvm::CanonicalLoopInfo *LoopInfo = createCanonicalLoop(
2133       Loc, LoopBodyGenCB, LB, UB, ST, true, false, AllocaIP, "section_loop");
2134   InsertPointTy AfterIP =
2135       applyStaticWorkshareLoop(Loc.DL, LoopInfo, AllocaIP, !IsNowait);
2136 
2137   // Apply the finalization callback in LoopAfterBB
2138   auto FiniInfo = FinalizationStack.pop_back_val();
2139   assert(FiniInfo.DK == OMPD_sections &&
2140          "Unexpected finalization stack state!");
2141   if (FinalizeCallbackTy &CB = FiniInfo.FiniCB) {
2142     Builder.restoreIP(AfterIP);
2143     BasicBlock *FiniBB =
2144         splitBBWithSuffix(Builder, /*CreateBranch=*/true, "sections.fini");
2145     CB(Builder.saveIP());
2146     AfterIP = {FiniBB, FiniBB->begin()};
2147   }
2148 
2149   return AfterIP;
2150 }
2151 
2152 OpenMPIRBuilder::InsertPointTy
2153 OpenMPIRBuilder::createSection(const LocationDescription &Loc,
2154                                BodyGenCallbackTy BodyGenCB,
2155                                FinalizeCallbackTy FiniCB) {
2156   if (!updateToLocation(Loc))
2157     return Loc.IP;
2158 
2159   auto FiniCBWrapper = [&](InsertPointTy IP) {
2160     if (IP.getBlock()->end() != IP.getPoint())
2161       return FiniCB(IP);
2162     // This must be done otherwise any nested constructs using FinalizeOMPRegion
2163     // will fail because that function requires the Finalization Basic Block to
2164     // have a terminator, which is already removed by EmitOMPRegionBody.
2165     // IP is currently at cancelation block.
2166     // We need to backtrack to the condition block to fetch
2167     // the exit block and create a branch from cancelation
2168     // to exit block.
2169     IRBuilder<>::InsertPointGuard IPG(Builder);
2170     Builder.restoreIP(IP);
2171     auto *CaseBB = Loc.IP.getBlock();
2172     auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
2173     auto *ExitBB = CondBB->getTerminator()->getSuccessor(1);
2174     Instruction *I = Builder.CreateBr(ExitBB);
2175     IP = InsertPointTy(I->getParent(), I->getIterator());
2176     return FiniCB(IP);
2177   };
2178 
2179   Directive OMPD = Directive::OMPD_sections;
2180   // Since we are using Finalization Callback here, HasFinalize
2181   // and IsCancellable have to be true
2182   return EmitOMPInlinedRegion(OMPD, nullptr, nullptr, BodyGenCB, FiniCBWrapper,
2183                               /*Conditional*/ false, /*hasFinalize*/ true,
2184                               /*IsCancellable*/ true);
2185 }
2186 
2187 static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
2188   BasicBlock::iterator IT(I);
2189   IT++;
2190   return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT);
2191 }
2192 
2193 void OpenMPIRBuilder::emitUsed(StringRef Name,
2194                                std::vector<WeakTrackingVH> &List) {
2195   if (List.empty())
2196     return;
2197 
2198   // Convert List to what ConstantArray needs.
2199   SmallVector<Constant *, 8> UsedArray;
2200   UsedArray.resize(List.size());
2201   for (unsigned I = 0, E = List.size(); I != E; ++I)
2202     UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2203         cast<Constant>(&*List[I]), Builder.getPtrTy());
2204 
2205   if (UsedArray.empty())
2206     return;
2207   ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());
2208 
2209   auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
2210                                 ConstantArray::get(ATy, UsedArray), Name);
2211 
2212   GV->setSection("llvm.metadata");
2213 }
2214 
2215 Value *OpenMPIRBuilder::getGPUThreadID() {
2216   return Builder.CreateCall(
2217       getOrCreateRuntimeFunction(M,
2218                                  OMPRTL___kmpc_get_hardware_thread_id_in_block),
2219       {});
2220 }
2221 
2222 Value *OpenMPIRBuilder::getGPUWarpSize() {
2223   return Builder.CreateCall(
2224       getOrCreateRuntimeFunction(M, OMPRTL___kmpc_get_warp_size), {});
2225 }
2226 
2227 Value *OpenMPIRBuilder::getNVPTXWarpID() {
2228   unsigned LaneIDBits = Log2_32(Config.getGridValue().GV_Warp_Size);
2229   return Builder.CreateAShr(getGPUThreadID(), LaneIDBits, "nvptx_warp_id");
2230 }
2231 
2232 Value *OpenMPIRBuilder::getNVPTXLaneID() {
2233   unsigned LaneIDBits = Log2_32(Config.getGridValue().GV_Warp_Size);
2234   assert(LaneIDBits < 32 && "Invalid LaneIDBits size in NVPTX device.");
2235   unsigned LaneIDMask = ~0u >> (32u - LaneIDBits);
2236   return Builder.CreateAnd(getGPUThreadID(), Builder.getInt32(LaneIDMask),
2237                            "nvptx_lane_id");
2238 }
2239 
2240 Value *OpenMPIRBuilder::castValueToType(InsertPointTy AllocaIP, Value *From,
2241                                         Type *ToType) {
2242   Type *FromType = From->getType();
2243   uint64_t FromSize = M.getDataLayout().getTypeStoreSize(FromType);
2244   uint64_t ToSize = M.getDataLayout().getTypeStoreSize(ToType);
2245   assert(FromSize > 0 && "From size must be greater than zero");
2246   assert(ToSize > 0 && "To size must be greater than zero");
2247   if (FromType == ToType)
2248     return From;
2249   if (FromSize == ToSize)
2250     return Builder.CreateBitCast(From, ToType);
2251   if (ToType->isIntegerTy() && FromType->isIntegerTy())
2252     return Builder.CreateIntCast(From, ToType, /*isSigned*/ true);
2253   InsertPointTy SaveIP = Builder.saveIP();
2254   Builder.restoreIP(AllocaIP);
2255   Value *CastItem = Builder.CreateAlloca(ToType);
2256   Builder.restoreIP(SaveIP);
2257 
2258   Value *ValCastItem = Builder.CreatePointerBitCastOrAddrSpaceCast(
2259       CastItem, FromType->getPointerTo());
2260   Builder.CreateStore(From, ValCastItem);
2261   return Builder.CreateLoad(ToType, CastItem);
2262 }
2263 
2264 Value *OpenMPIRBuilder::createRuntimeShuffleFunction(InsertPointTy AllocaIP,
2265                                                      Value *Element,
2266                                                      Type *ElementType,
2267                                                      Value *Offset) {
2268   uint64_t Size = M.getDataLayout().getTypeStoreSize(ElementType);
2269   assert(Size <= 8 && "Unsupported bitwidth in shuffle instruction");
2270 
2271   // Cast all types to 32- or 64-bit values before calling shuffle routines.
2272   Type *CastTy = Builder.getIntNTy(Size <= 4 ? 32 : 64);
2273   Value *ElemCast = castValueToType(AllocaIP, Element, CastTy);
2274   Value *WarpSize =
2275       Builder.CreateIntCast(getGPUWarpSize(), Builder.getInt16Ty(), true);
2276   Function *ShuffleFunc = getOrCreateRuntimeFunctionPtr(
2277       Size <= 4 ? RuntimeFunction::OMPRTL___kmpc_shuffle_int32
2278                 : RuntimeFunction::OMPRTL___kmpc_shuffle_int64);
2279   Value *WarpSizeCast =
2280       Builder.CreateIntCast(WarpSize, Builder.getInt16Ty(), /*isSigned=*/true);
2281   Value *ShuffleCall =
2282       Builder.CreateCall(ShuffleFunc, {ElemCast, Offset, WarpSizeCast});
2283   return castValueToType(AllocaIP, ShuffleCall, CastTy);
2284 }
2285 
/// Shuffle the element at \p SrcAddr, of type \p ElemType, in from the lane
/// \p Offset away, and store the result to \p DstAddr.
///
/// The element is transferred in integer chunks of 8, 4, 2 and then 1 bytes.
/// Each chunk goes through createRuntimeShuffleFunction.
/// - When more than one chunk of a width is needed, a runtime loop
///   (".shuffle.pre_cond" / ".shuffle.then" / ".shuffle.exit") advances the
///   source and destination pointers in lockstep. It runs while at least
///   IntSize bytes remain before the end of the source element.
/// - Otherwise the single chunk is shuffled inline.
///
/// \p AllocaIP is forwarded to createRuntimeShuffleFunction for any stack
/// temporaries it needs. \p ReductionArrayTy is not used by this function.
void OpenMPIRBuilder::shuffleAndStore(InsertPointTy AllocaIP, Value *SrcAddr,
                                      Value *DstAddr, Type *ElemType,
                                      Value *Offset, Type *ReductionArrayTy) {
  uint64_t Size = M.getDataLayout().getTypeStoreSize(ElemType);
  // Create the loop over the big sized data.
  // ptr = (void*)Elem;
  // ptrEnd = (void*) Elem + 1;
  // Step = 8;
  // while (ptr + Step < ptrEnd)
  //   shuffle((int64_t)*ptr);
  // Step = 4;
  // while (ptr + Step < ptrEnd)
  //   shuffle((int32_t)*ptr);
  // ...
  Type *IndexTy = Builder.getIndexTy(
      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
  // ElemPtr / Ptr track the current destination / source position.
  Value *ElemPtr = DstAddr;
  Value *Ptr = SrcAddr;
  // Widest chunk first; Size holds the number of bytes not yet transferred.
  for (unsigned IntSize = 8; IntSize >= 1; IntSize /= 2) {
    // No full chunk of this width remains.
    if (Size < IntSize)
      continue;
    Type *IntType = Builder.getIntNTy(IntSize * 8);
    Ptr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        Ptr, IntType->getPointerTo(), Ptr->getName() + ".ascast");
    // One past the end of the source element. It is emitted on every pass but
    // only consumed by the loop path below.
    Value *SrcAddrGEP =
        Builder.CreateGEP(ElemType, SrcAddr, {ConstantInt::get(IndexTy, 1)});
    ElemPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        ElemPtr, IntType->getPointerTo(), ElemPtr->getName() + ".ascast");

    Function *CurFunc = Builder.GetInsertBlock()->getParent();
    // Several chunks of this width: emit a runtime loop.
    if ((Size / IntSize) > 1) {
      Value *PtrEnd = Builder.CreatePointerBitCastOrAddrSpaceCast(
          SrcAddrGEP, Builder.getPtrTy());
      BasicBlock *PreCondBB =
          BasicBlock::Create(M.getContext(), ".shuffle.pre_cond");
      BasicBlock *ThenBB = BasicBlock::Create(M.getContext(), ".shuffle.then");
      BasicBlock *ExitBB = BasicBlock::Create(M.getContext(), ".shuffle.exit");
      BasicBlock *CurrentBB = Builder.GetInsertBlock();
      emitBlock(PreCondBB, CurFunc);
      // Loop-carried source / destination positions; the loop-back edges are
      // added once ThenBB has been emitted.
      PHINode *PhiSrc =
          Builder.CreatePHI(Ptr->getType(), /*NumReservedValues=*/2);
      PhiSrc->addIncoming(Ptr, CurrentBB);
      PHINode *PhiDest =
          Builder.CreatePHI(ElemPtr->getType(), /*NumReservedValues=*/2);
      PhiDest->addIncoming(ElemPtr, CurrentBB);
      Ptr = PhiSrc;
      ElemPtr = PhiDest;
      // Continue while (PtrEnd - Ptr) > IntSize - 1, i.e. while at least one
      // more full IntSize chunk remains.
      Value *PtrDiff = Builder.CreatePtrDiff(
          Builder.getInt8Ty(), PtrEnd,
          Builder.CreatePointerBitCastOrAddrSpaceCast(Ptr, Builder.getPtrTy()));
      Builder.CreateCondBr(
          Builder.CreateICmpSGT(PtrDiff, Builder.getInt64(IntSize - 1)), ThenBB,
          ExitBB);
      emitBlock(ThenBB, CurFunc);
      // NOTE(review): every chunk in this loop is loaded/stored with the
      // preferred alignment of ElemType, even though chunks after the first
      // lie at multiples of IntSize bytes. Confirm this cannot overstate the
      // alignment for over-aligned element types.
      Value *Res = createRuntimeShuffleFunction(
          AllocaIP,
          Builder.CreateAlignedLoad(
              IntType, Ptr, M.getDataLayout().getPrefTypeAlign(ElemType)),
          IntType, Offset);
      Builder.CreateAlignedStore(Res, ElemPtr,
                                 M.getDataLayout().getPrefTypeAlign(ElemType));
      // Advance both positions by one chunk and branch back to the condition.
      Value *LocalPtr =
          Builder.CreateGEP(IntType, Ptr, {ConstantInt::get(IndexTy, 1)});
      Value *LocalElemPtr =
          Builder.CreateGEP(IntType, ElemPtr, {ConstantInt::get(IndexTy, 1)});
      PhiSrc->addIncoming(LocalPtr, ThenBB);
      PhiDest->addIncoming(LocalElemPtr, ThenBB);
      emitBranch(PreCondBB);
      emitBlock(ExitBB, CurFunc);
    } else {
      // Exactly one chunk of this width: shuffle it inline, then advance.
      Value *Res = createRuntimeShuffleFunction(
          AllocaIP, Builder.CreateLoad(IntType, Ptr), IntType, Offset);
      // Narrow the shuffle result if it is wider than an integer ElemType.
      if (ElemType->isIntegerTy() && ElemType->getScalarSizeInBits() <
                                         Res->getType()->getScalarSizeInBits())
        Res = Builder.CreateTrunc(Res, ElemType);
      Builder.CreateStore(Res, ElemPtr);
      Ptr = Builder.CreateGEP(IntType, Ptr, {ConstantInt::get(IndexTy, 1)});
      ElemPtr =
          Builder.CreateGEP(IntType, ElemPtr, {ConstantInt::get(IndexTy, 1)});
    }
    // Bytes left for the narrower chunk widths.
    Size = Size % IntSize;
  }
}
2369 
/// Copy every element of the reduce list at \p SrcBase into the reduce list
/// at \p DestBase.
///
/// Both lists are addressed as \p ReductionArrayTy with indices {0, i}, and
/// each entry is loaded as a pointer to the i-th element. \p Action selects
/// the copy strategy:
///  - RemoteLaneToThread: a fresh stack slot of RI.ElementType is allocated
///    at \p AllocaIP. The source element is shuffled into it from the lane
///    CopyOptions.RemoteLaneOffset away (via shuffleAndStore). The
///    destination list entry is then overwritten to point at the new slot.
///  - ThreadCopy: the destination list entry is loaded as the destination
///    element pointer. The element is then copied according to
///    RI.EvaluationKind: scalar load/store, complex real/imag pair, or memcpy
///    for aggregates.
///
/// Only the RemoteLaneOffset field of \p CopyOptions is read here.
void OpenMPIRBuilder::emitReductionListCopy(
    InsertPointTy AllocaIP, CopyAction Action, Type *ReductionArrayTy,
    ArrayRef<ReductionInfo> ReductionInfos, Value *SrcBase, Value *DestBase,
    CopyOptionsTy CopyOptions) {
  Type *IndexTy = Builder.getIndexTy(
      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
  Value *RemoteLaneOffset = CopyOptions.RemoteLaneOffset;

  // Iterates, element-by-element, through the source Reduce list and
  // make a copy.
  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *SrcElementAddr = nullptr;
    Value *DestElementAddr = nullptr;
    Value *DestElementPtrAddr = nullptr;
    // Should we shuffle in an element from a remote lane?
    bool ShuffleInElement = false;
    // Set to true to update the pointer in the dest Reduce list to a
    // newly created element.
    bool UpdateDestListPtr = false;

    // Step 1.1: Get the address for the src element in the Reduce list.
    Value *SrcElementPtrAddr = Builder.CreateInBoundsGEP(
        ReductionArrayTy, SrcBase,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    SrcElementAddr = Builder.CreateLoad(Builder.getPtrTy(), SrcElementPtrAddr);

    // Step 1.2: Create a temporary to store the element in the destination
    // Reduce list.
    DestElementPtrAddr = Builder.CreateInBoundsGEP(
        ReductionArrayTy, DestBase,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    switch (Action) {
    case CopyAction::RemoteLaneToThread: {
      // Allocate the receiving slot (and its generic-pointer cast) at
      // AllocaIP, then return to the current code position.
      InsertPointTy CurIP = Builder.saveIP();
      Builder.restoreIP(AllocaIP);
      AllocaInst *DestAlloca = Builder.CreateAlloca(RI.ElementType, nullptr,
                                                    ".omp.reduction.element");
      DestAlloca->setAlignment(
          M.getDataLayout().getPrefTypeAlign(RI.ElementType));
      DestElementAddr = DestAlloca;
      DestElementAddr =
          Builder.CreateAddrSpaceCast(DestElementAddr, Builder.getPtrTy(),
                                      DestElementAddr->getName() + ".ascast");
      Builder.restoreIP(CurIP);
      ShuffleInElement = true;
      UpdateDestListPtr = true;
      break;
    }
    case CopyAction::ThreadCopy: {
      // The destination list already points at the destination element.
      DestElementAddr =
          Builder.CreateLoad(Builder.getPtrTy(), DestElementPtrAddr);
      break;
    }
    }

    // Now that all active lanes have read the element in the
    // Reduce list, shuffle over the value from the remote lane.
    if (ShuffleInElement) {
      shuffleAndStore(AllocaIP, SrcElementAddr, DestElementAddr, RI.ElementType,
                      RemoteLaneOffset, ReductionArrayTy);
    } else {
      switch (RI.EvaluationKind) {
      case EvalKind::Scalar: {
        Value *Elem = Builder.CreateLoad(RI.ElementType, SrcElementAddr);
        // Store the source element value to the dest element address.
        Builder.CreateStore(Elem, DestElementAddr);
        break;
      }
      case EvalKind::Complex: {
        // Copy the two struct fields (real part = field 0, imaginary part =
        // field 1) individually.
        Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
            RI.ElementType, SrcElementAddr, 0, 0, ".realp");
        Value *SrcReal = Builder.CreateLoad(
            RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
        Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
            RI.ElementType, SrcElementAddr, 0, 1, ".imagp");
        Value *SrcImg = Builder.CreateLoad(
            RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");

        Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
            RI.ElementType, DestElementAddr, 0, 0, ".realp");
        Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
            RI.ElementType, DestElementAddr, 0, 1, ".imagp");
        Builder.CreateStore(SrcReal, DestRealPtr);
        Builder.CreateStore(SrcImg, DestImgPtr);
        break;
      }
      case EvalKind::Aggregate: {
        // Byte-wise copy of the whole element (store size of ElementType).
        Value *SizeVal = Builder.getInt64(
            M.getDataLayout().getTypeStoreSize(RI.ElementType));
        Builder.CreateMemCpy(
            DestElementAddr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
            SrcElementAddr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
            SizeVal, false);
        break;
      }
      };
    }

    // Step 3.1: Modify reference in dest Reduce list as needed.
    // Modifying the reference in Reduce list to point to the newly
    // created element.  The element is live in the current function
    // scope and that of functions it invokes (i.e., reduce_function).
    // RemoteReduceData[i] = (void*)&RemoteElem
    if (UpdateDestListPtr) {
      Value *CastDestAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
          DestElementAddr, Builder.getPtrTy(),
          DestElementAddr->getName() + ".ascast");
      Builder.CreateStore(CastDestAddr, DestElementPtrAddr);
    }
  }
}
2482 
2483 Function *OpenMPIRBuilder::emitInterWarpCopyFunction(
2484     const LocationDescription &Loc, ArrayRef<ReductionInfo> ReductionInfos,
2485     AttributeList FuncAttrs) {
2486   InsertPointTy SavedIP = Builder.saveIP();
2487   LLVMContext &Ctx = M.getContext();
2488   FunctionType *FuncTy = FunctionType::get(
2489       Builder.getVoidTy(), {Builder.getPtrTy(), Builder.getInt32Ty()},
2490       /* IsVarArg */ false);
2491   Function *WcFunc =
2492       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
2493                        "_omp_reduction_inter_warp_copy_func", &M);
2494   WcFunc->setAttributes(FuncAttrs);
2495   WcFunc->addParamAttr(0, Attribute::NoUndef);
2496   WcFunc->addParamAttr(1, Attribute::NoUndef);
2497   BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", WcFunc);
2498   Builder.SetInsertPoint(EntryBB);
2499 
2500   // ReduceList: thread local Reduce list.
2501   // At the stage of the computation when this function is called, partially
2502   // aggregated values reside in the first lane of every active warp.
2503   Argument *ReduceListArg = WcFunc->getArg(0);
2504   // NumWarps: number of warps active in the parallel region.  This could
2505   // be smaller than 32 (max warps in a CTA) for partial block reduction.
2506   Argument *NumWarpsArg = WcFunc->getArg(1);
2507 
2508   // This array is used as a medium to transfer, one reduce element at a time,
2509   // the data from the first lane of every warp to lanes in the first warp
2510   // in order to perform the final step of a reduction in a parallel region
2511   // (reduction across warps).  The array is placed in NVPTX __shared__ memory
2512   // for reduced latency, as well as to have a distinct copy for concurrently
2513   // executing target regions.  The array is declared with common linkage so
2514   // as to be shared across compilation units.
2515   StringRef TransferMediumName =
2516       "__openmp_nvptx_data_transfer_temporary_storage";
2517   GlobalVariable *TransferMedium = M.getGlobalVariable(TransferMediumName);
2518   unsigned WarpSize = Config.getGridValue().GV_Warp_Size;
2519   ArrayType *ArrayTy = ArrayType::get(Builder.getInt32Ty(), WarpSize);
2520   if (!TransferMedium) {
2521     TransferMedium = new GlobalVariable(
2522         M, ArrayTy, /*isConstant=*/false, GlobalVariable::WeakAnyLinkage,
2523         UndefValue::get(ArrayTy), TransferMediumName,
2524         /*InsertBefore=*/nullptr, GlobalVariable::NotThreadLocal,
2525         /*AddressSpace=*/3);
2526   }
2527 
2528   // Get the CUDA thread id of the current OpenMP thread on the GPU.
2529   Value *GPUThreadID = getGPUThreadID();
2530   // nvptx_lane_id = nvptx_id % warpsize
2531   Value *LaneID = getNVPTXLaneID();
2532   // nvptx_warp_id = nvptx_id / warpsize
2533   Value *WarpID = getNVPTXWarpID();
2534 
2535   InsertPointTy AllocaIP =
2536       InsertPointTy(Builder.GetInsertBlock(),
2537                     Builder.GetInsertBlock()->getFirstInsertionPt());
2538   Type *Arg0Type = ReduceListArg->getType();
2539   Type *Arg1Type = NumWarpsArg->getType();
2540   Builder.restoreIP(AllocaIP);
2541   AllocaInst *ReduceListAlloca = Builder.CreateAlloca(
2542       Arg0Type, nullptr, ReduceListArg->getName() + ".addr");
2543   AllocaInst *NumWarpsAlloca =
2544       Builder.CreateAlloca(Arg1Type, nullptr, NumWarpsArg->getName() + ".addr");
2545   Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2546       ReduceListAlloca, Arg0Type, ReduceListAlloca->getName() + ".ascast");
2547   Value *NumWarpsAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2548       NumWarpsAlloca, Arg1Type->getPointerTo(),
2549       NumWarpsAlloca->getName() + ".ascast");
2550   Builder.CreateStore(ReduceListArg, ReduceListAddrCast);
2551   Builder.CreateStore(NumWarpsArg, NumWarpsAddrCast);
2552   AllocaIP = getInsertPointAfterInstr(NumWarpsAlloca);
2553   InsertPointTy CodeGenIP =
2554       getInsertPointAfterInstr(&Builder.GetInsertBlock()->back());
2555   Builder.restoreIP(CodeGenIP);
2556 
2557   Value *ReduceList =
2558       Builder.CreateLoad(Builder.getPtrTy(), ReduceListAddrCast);
2559 
2560   for (auto En : enumerate(ReductionInfos)) {
2561     //
2562     // Warp master copies reduce element to transfer medium in __shared__
2563     // memory.
2564     //
2565     const ReductionInfo &RI = En.value();
2566     unsigned RealTySize = M.getDataLayout().getTypeAllocSize(RI.ElementType);
2567     for (unsigned TySize = 4; TySize > 0 && RealTySize > 0; TySize /= 2) {
2568       Type *CType = Builder.getIntNTy(TySize * 8);
2569 
2570       unsigned NumIters = RealTySize / TySize;
2571       if (NumIters == 0)
2572         continue;
2573       Value *Cnt = nullptr;
2574       Value *CntAddr = nullptr;
2575       BasicBlock *PrecondBB = nullptr;
2576       BasicBlock *ExitBB = nullptr;
2577       if (NumIters > 1) {
2578         CodeGenIP = Builder.saveIP();
2579         Builder.restoreIP(AllocaIP);
2580         CntAddr =
2581             Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, ".cnt.addr");
2582 
2583         CntAddr = Builder.CreateAddrSpaceCast(CntAddr, Builder.getPtrTy(),
2584                                               CntAddr->getName() + ".ascast");
2585         Builder.restoreIP(CodeGenIP);
2586         Builder.CreateStore(Constant::getNullValue(Builder.getInt32Ty()),
2587                             CntAddr,
2588                             /*Volatile=*/false);
2589         PrecondBB = BasicBlock::Create(Ctx, "precond");
2590         ExitBB = BasicBlock::Create(Ctx, "exit");
2591         BasicBlock *BodyBB = BasicBlock::Create(Ctx, "body");
2592         emitBlock(PrecondBB, Builder.GetInsertBlock()->getParent());
2593         Cnt = Builder.CreateLoad(Builder.getInt32Ty(), CntAddr,
2594                                  /*Volatile=*/false);
2595         Value *Cmp = Builder.CreateICmpULT(
2596             Cnt, ConstantInt::get(Builder.getInt32Ty(), NumIters));
2597         Builder.CreateCondBr(Cmp, BodyBB, ExitBB);
2598         emitBlock(BodyBB, Builder.GetInsertBlock()->getParent());
2599       }
2600 
2601       // kmpc_barrier.
2602       createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
2603                     omp::Directive::OMPD_unknown,
2604                     /* ForceSimpleCall */ false,
2605                     /* CheckCancelFlag */ true);
2606       BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then");
2607       BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else");
2608       BasicBlock *MergeBB = BasicBlock::Create(Ctx, "ifcont");
2609 
2610       // if (lane_id  == 0)
2611       Value *IsWarpMaster = Builder.CreateIsNull(LaneID, "warp_master");
2612       Builder.CreateCondBr(IsWarpMaster, ThenBB, ElseBB);
2613       emitBlock(ThenBB, Builder.GetInsertBlock()->getParent());
2614 
2615       // Reduce element = LocalReduceList[i]
2616       auto *RedListArrayTy =
2617           ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
2618       Type *IndexTy = Builder.getIndexTy(
2619           M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
2620       Value *ElemPtrPtr =
2621           Builder.CreateInBoundsGEP(RedListArrayTy, ReduceList,
2622                                     {ConstantInt::get(IndexTy, 0),
2623                                      ConstantInt::get(IndexTy, En.index())});
2624       // elemptr = ((CopyType*)(elemptrptr)) + I
2625       Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
2626       if (NumIters > 1)
2627         ElemPtr = Builder.CreateGEP(Builder.getInt32Ty(), ElemPtr, Cnt);
2628 
2629       // Get pointer to location in transfer medium.
2630       // MediumPtr = &medium[warp_id]
2631       Value *MediumPtr = Builder.CreateInBoundsGEP(
2632           ArrayTy, TransferMedium, {Builder.getInt64(0), WarpID});
2633       // elem = *elemptr
2634       //*MediumPtr = elem
2635       Value *Elem = Builder.CreateLoad(CType, ElemPtr);
2636       // Store the source element value to the dest element address.
2637       Builder.CreateStore(Elem, MediumPtr,
2638                           /*IsVolatile*/ true);
2639       Builder.CreateBr(MergeBB);
2640 
2641       // else
2642       emitBlock(ElseBB, Builder.GetInsertBlock()->getParent());
2643       Builder.CreateBr(MergeBB);
2644 
2645       // endif
2646       emitBlock(MergeBB, Builder.GetInsertBlock()->getParent());
2647       createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
2648                     omp::Directive::OMPD_unknown,
2649                     /* ForceSimpleCall */ false,
2650                     /* CheckCancelFlag */ true);
2651 
2652       // Warp 0 copies reduce element from transfer medium
2653       BasicBlock *W0ThenBB = BasicBlock::Create(Ctx, "then");
2654       BasicBlock *W0ElseBB = BasicBlock::Create(Ctx, "else");
2655       BasicBlock *W0MergeBB = BasicBlock::Create(Ctx, "ifcont");
2656 
2657       Value *NumWarpsVal =
2658           Builder.CreateLoad(Builder.getInt32Ty(), NumWarpsAddrCast);
2659       // Up to 32 threads in warp 0 are active.
2660       Value *IsActiveThread =
2661           Builder.CreateICmpULT(GPUThreadID, NumWarpsVal, "is_active_thread");
2662       Builder.CreateCondBr(IsActiveThread, W0ThenBB, W0ElseBB);
2663 
2664       emitBlock(W0ThenBB, Builder.GetInsertBlock()->getParent());
2665 
2666       // SecMediumPtr = &medium[tid]
2667       // SrcMediumVal = *SrcMediumPtr
2668       Value *SrcMediumPtrVal = Builder.CreateInBoundsGEP(
2669           ArrayTy, TransferMedium, {Builder.getInt64(0), GPUThreadID});
2670       // TargetElemPtr = (CopyType*)(SrcDataAddr[i]) + I
2671       Value *TargetElemPtrPtr =
2672           Builder.CreateInBoundsGEP(RedListArrayTy, ReduceList,
2673                                     {ConstantInt::get(IndexTy, 0),
2674                                      ConstantInt::get(IndexTy, En.index())});
2675       Value *TargetElemPtrVal =
2676           Builder.CreateLoad(Builder.getPtrTy(), TargetElemPtrPtr);
2677       Value *TargetElemPtr = TargetElemPtrVal;
2678       if (NumIters > 1)
2679         TargetElemPtr =
2680             Builder.CreateGEP(Builder.getInt32Ty(), TargetElemPtr, Cnt);
2681 
2682       // *TargetElemPtr = SrcMediumVal;
2683       Value *SrcMediumValue =
2684           Builder.CreateLoad(CType, SrcMediumPtrVal, /*IsVolatile*/ true);
2685       Builder.CreateStore(SrcMediumValue, TargetElemPtr);
2686       Builder.CreateBr(W0MergeBB);
2687 
2688       emitBlock(W0ElseBB, Builder.GetInsertBlock()->getParent());
2689       Builder.CreateBr(W0MergeBB);
2690 
2691       emitBlock(W0MergeBB, Builder.GetInsertBlock()->getParent());
2692 
2693       if (NumIters > 1) {
2694         Cnt = Builder.CreateNSWAdd(
2695             Cnt, ConstantInt::get(Builder.getInt32Ty(), /*V=*/1));
2696         Builder.CreateStore(Cnt, CntAddr, /*Volatile=*/false);
2697 
2698         auto *CurFn = Builder.GetInsertBlock()->getParent();
2699         emitBranch(PrecondBB);
2700         emitBlock(ExitBB, CurFn);
2701       }
2702       RealTySize %= TySize;
2703     }
2704   }
2705 
2706   Builder.CreateRetVoid();
2707   Builder.restoreIP(SavedIP);
2708 
2709   return WcFunc;
2710 }
2711 
2712 Function *OpenMPIRBuilder::emitShuffleAndReduceFunction(
2713     ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
2714     AttributeList FuncAttrs) {
2715   LLVMContext &Ctx = M.getContext();
2716   FunctionType *FuncTy =
2717       FunctionType::get(Builder.getVoidTy(),
2718                         {Builder.getPtrTy(), Builder.getInt16Ty(),
2719                          Builder.getInt16Ty(), Builder.getInt16Ty()},
2720                         /* IsVarArg */ false);
2721   Function *SarFunc =
2722       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
2723                        "_omp_reduction_shuffle_and_reduce_func", &M);
2724   SarFunc->setAttributes(FuncAttrs);
2725   SarFunc->addParamAttr(0, Attribute::NoUndef);
2726   SarFunc->addParamAttr(1, Attribute::NoUndef);
2727   SarFunc->addParamAttr(2, Attribute::NoUndef);
2728   SarFunc->addParamAttr(3, Attribute::NoUndef);
2729   SarFunc->addParamAttr(1, Attribute::SExt);
2730   SarFunc->addParamAttr(2, Attribute::SExt);
2731   SarFunc->addParamAttr(3, Attribute::SExt);
2732   BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", SarFunc);
2733   Builder.SetInsertPoint(EntryBB);
2734 
2735   // Thread local Reduce list used to host the values of data to be reduced.
2736   Argument *ReduceListArg = SarFunc->getArg(0);
2737   // Current lane id; could be logical.
2738   Argument *LaneIDArg = SarFunc->getArg(1);
2739   // Offset of the remote source lane relative to the current lane.
2740   Argument *RemoteLaneOffsetArg = SarFunc->getArg(2);
2741   // Algorithm version.  This is expected to be known at compile time.
2742   Argument *AlgoVerArg = SarFunc->getArg(3);
2743 
2744   Type *ReduceListArgType = ReduceListArg->getType();
2745   Type *LaneIDArgType = LaneIDArg->getType();
2746   Type *LaneIDArgPtrType = LaneIDArg->getType()->getPointerTo();
2747   Value *ReduceListAlloca = Builder.CreateAlloca(
2748       ReduceListArgType, nullptr, ReduceListArg->getName() + ".addr");
2749   Value *LaneIdAlloca = Builder.CreateAlloca(LaneIDArgType, nullptr,
2750                                              LaneIDArg->getName() + ".addr");
2751   Value *RemoteLaneOffsetAlloca = Builder.CreateAlloca(
2752       LaneIDArgType, nullptr, RemoteLaneOffsetArg->getName() + ".addr");
2753   Value *AlgoVerAlloca = Builder.CreateAlloca(LaneIDArgType, nullptr,
2754                                               AlgoVerArg->getName() + ".addr");
2755   ArrayType *RedListArrayTy =
2756       ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
2757 
2758   // Create a local thread-private variable to host the Reduce list
2759   // from a remote lane.
2760   Instruction *RemoteReductionListAlloca = Builder.CreateAlloca(
2761       RedListArrayTy, nullptr, ".omp.reduction.remote_reduce_list");
2762 
2763   Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2764       ReduceListAlloca, ReduceListArgType,
2765       ReduceListAlloca->getName() + ".ascast");
2766   Value *LaneIdAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2767       LaneIdAlloca, LaneIDArgPtrType, LaneIdAlloca->getName() + ".ascast");
2768   Value *RemoteLaneOffsetAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2769       RemoteLaneOffsetAlloca, LaneIDArgPtrType,
2770       RemoteLaneOffsetAlloca->getName() + ".ascast");
2771   Value *AlgoVerAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2772       AlgoVerAlloca, LaneIDArgPtrType, AlgoVerAlloca->getName() + ".ascast");
2773   Value *RemoteListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2774       RemoteReductionListAlloca, Builder.getPtrTy(),
2775       RemoteReductionListAlloca->getName() + ".ascast");
2776 
2777   Builder.CreateStore(ReduceListArg, ReduceListAddrCast);
2778   Builder.CreateStore(LaneIDArg, LaneIdAddrCast);
2779   Builder.CreateStore(RemoteLaneOffsetArg, RemoteLaneOffsetAddrCast);
2780   Builder.CreateStore(AlgoVerArg, AlgoVerAddrCast);
2781 
2782   Value *ReduceList = Builder.CreateLoad(ReduceListArgType, ReduceListAddrCast);
2783   Value *LaneId = Builder.CreateLoad(LaneIDArgType, LaneIdAddrCast);
2784   Value *RemoteLaneOffset =
2785       Builder.CreateLoad(LaneIDArgType, RemoteLaneOffsetAddrCast);
2786   Value *AlgoVer = Builder.CreateLoad(LaneIDArgType, AlgoVerAddrCast);
2787 
2788   InsertPointTy AllocaIP = getInsertPointAfterInstr(RemoteReductionListAlloca);
2789 
2790   // This loop iterates through the list of reduce elements and copies,
2791   // element by element, from a remote lane in the warp to RemoteReduceList,
2792   // hosted on the thread's stack.
2793   emitReductionListCopy(
2794       AllocaIP, CopyAction::RemoteLaneToThread, RedListArrayTy, ReductionInfos,
2795       ReduceList, RemoteListAddrCast, {RemoteLaneOffset, nullptr, nullptr});
2796 
2797   // The actions to be performed on the Remote Reduce list is dependent
2798   // on the algorithm version.
2799   //
2800   //  if (AlgoVer==0) || (AlgoVer==1 && (LaneId < Offset)) || (AlgoVer==2 &&
2801   //  LaneId % 2 == 0 && Offset > 0):
2802   //    do the reduction value aggregation
2803   //
2804   //  The thread local variable Reduce list is mutated in place to host the
2805   //  reduced data, which is the aggregated value produced from local and
2806   //  remote lanes.
2807   //
2808   //  Note that AlgoVer is expected to be a constant integer known at compile
2809   //  time.
2810   //  When AlgoVer==0, the first conjunction evaluates to true, making
2811   //    the entire predicate true during compile time.
2812   //  When AlgoVer==1, the second conjunction has only the second part to be
2813   //    evaluated during runtime.  Other conjunctions evaluates to false
2814   //    during compile time.
2815   //  When AlgoVer==2, the third conjunction has only the second part to be
2816   //    evaluated during runtime.  Other conjunctions evaluates to false
2817   //    during compile time.
2818   Value *CondAlgo0 = Builder.CreateIsNull(AlgoVer);
2819   Value *Algo1 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(1));
2820   Value *LaneComp = Builder.CreateICmpULT(LaneId, RemoteLaneOffset);
2821   Value *CondAlgo1 = Builder.CreateAnd(Algo1, LaneComp);
2822   Value *Algo2 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(2));
2823   Value *LaneIdAnd1 = Builder.CreateAnd(LaneId, Builder.getInt16(1));
2824   Value *LaneIdComp = Builder.CreateIsNull(LaneIdAnd1);
2825   Value *Algo2AndLaneIdComp = Builder.CreateAnd(Algo2, LaneIdComp);
2826   Value *RemoteOffsetComp =
2827       Builder.CreateICmpSGT(RemoteLaneOffset, Builder.getInt16(0));
2828   Value *CondAlgo2 = Builder.CreateAnd(Algo2AndLaneIdComp, RemoteOffsetComp);
2829   Value *CA0OrCA1 = Builder.CreateOr(CondAlgo0, CondAlgo1);
2830   Value *CondReduce = Builder.CreateOr(CA0OrCA1, CondAlgo2);
2831 
2832   BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then");
2833   BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else");
2834   BasicBlock *MergeBB = BasicBlock::Create(Ctx, "ifcont");
2835 
2836   Builder.CreateCondBr(CondReduce, ThenBB, ElseBB);
2837   emitBlock(ThenBB, Builder.GetInsertBlock()->getParent());
2838   Value *LocalReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2839       ReduceList, Builder.getPtrTy());
2840   Value *RemoteReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2841       RemoteListAddrCast, Builder.getPtrTy());
2842   Builder.CreateCall(ReduceFn, {LocalReduceListPtr, RemoteReduceListPtr})
2843       ->addFnAttr(Attribute::NoUnwind);
2844   Builder.CreateBr(MergeBB);
2845 
2846   emitBlock(ElseBB, Builder.GetInsertBlock()->getParent());
2847   Builder.CreateBr(MergeBB);
2848 
2849   emitBlock(MergeBB, Builder.GetInsertBlock()->getParent());
2850 
2851   // if (AlgoVer==1 && (LaneId >= Offset)) copy Remote Reduce list to local
2852   // Reduce list.
2853   Algo1 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(1));
2854   Value *LaneIdGtOffset = Builder.CreateICmpUGE(LaneId, RemoteLaneOffset);
2855   Value *CondCopy = Builder.CreateAnd(Algo1, LaneIdGtOffset);
2856 
2857   BasicBlock *CpyThenBB = BasicBlock::Create(Ctx, "then");
2858   BasicBlock *CpyElseBB = BasicBlock::Create(Ctx, "else");
2859   BasicBlock *CpyMergeBB = BasicBlock::Create(Ctx, "ifcont");
2860   Builder.CreateCondBr(CondCopy, CpyThenBB, CpyElseBB);
2861 
2862   emitBlock(CpyThenBB, Builder.GetInsertBlock()->getParent());
2863   emitReductionListCopy(AllocaIP, CopyAction::ThreadCopy, RedListArrayTy,
2864                         ReductionInfos, RemoteListAddrCast, ReduceList);
2865   Builder.CreateBr(CpyMergeBB);
2866 
2867   emitBlock(CpyElseBB, Builder.GetInsertBlock()->getParent());
2868   Builder.CreateBr(CpyMergeBB);
2869 
2870   emitBlock(CpyMergeBB, Builder.GetInsertBlock()->getParent());
2871 
2872   Builder.CreateRetVoid();
2873 
2874   return SarFunc;
2875 }
2876 
2877 Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
2878     ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
2879     AttributeList FuncAttrs) {
2880   OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
2881   LLVMContext &Ctx = M.getContext();
2882   FunctionType *FuncTy = FunctionType::get(
2883       Builder.getVoidTy(),
2884       {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
2885       /* IsVarArg */ false);
2886   Function *LtGCFunc =
2887       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
2888                        "_omp_reduction_list_to_global_copy_func", &M);
2889   LtGCFunc->setAttributes(FuncAttrs);
2890   LtGCFunc->addParamAttr(0, Attribute::NoUndef);
2891   LtGCFunc->addParamAttr(1, Attribute::NoUndef);
2892   LtGCFunc->addParamAttr(2, Attribute::NoUndef);
2893 
2894   BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
2895   Builder.SetInsertPoint(EntryBlock);
2896 
2897   // Buffer: global reduction buffer.
2898   Argument *BufferArg = LtGCFunc->getArg(0);
2899   // Idx: index of the buffer.
2900   Argument *IdxArg = LtGCFunc->getArg(1);
2901   // ReduceList: thread local Reduce list.
2902   Argument *ReduceListArg = LtGCFunc->getArg(2);
2903 
2904   Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
2905                                                 BufferArg->getName() + ".addr");
2906   Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
2907                                              IdxArg->getName() + ".addr");
2908   Value *ReduceListArgAlloca = Builder.CreateAlloca(
2909       Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
2910   Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2911       BufferArgAlloca, Builder.getPtrTy(),
2912       BufferArgAlloca->getName() + ".ascast");
2913   Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2914       IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
2915   Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2916       ReduceListArgAlloca, Builder.getPtrTy(),
2917       ReduceListArgAlloca->getName() + ".ascast");
2918 
2919   Builder.CreateStore(BufferArg, BufferArgAddrCast);
2920   Builder.CreateStore(IdxArg, IdxArgAddrCast);
2921   Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
2922 
2923   Value *LocalReduceList =
2924       Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
2925   Value *BufferArgVal =
2926       Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
2927   Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
2928   Type *IndexTy = Builder.getIndexTy(
2929       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
2930   for (auto En : enumerate(ReductionInfos)) {
2931     const ReductionInfo &RI = En.value();
2932     auto *RedListArrayTy =
2933         ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
2934     // Reduce element = LocalReduceList[i]
2935     Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
2936         RedListArrayTy, LocalReduceList,
2937         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
2938     // elemptr = ((CopyType*)(elemptrptr)) + I
2939     Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
2940 
2941     // Global = Buffer.VD[Idx];
2942     Value *BufferVD =
2943         Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferArgVal, Idxs);
2944     Value *GlobVal = Builder.CreateConstInBoundsGEP2_32(
2945         ReductionsBufferTy, BufferVD, 0, En.index());
2946 
2947     switch (RI.EvaluationKind) {
2948     case EvalKind::Scalar: {
2949       Value *TargetElement = Builder.CreateLoad(RI.ElementType, ElemPtr);
2950       Builder.CreateStore(TargetElement, GlobVal);
2951       break;
2952     }
2953     case EvalKind::Complex: {
2954       Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
2955           RI.ElementType, ElemPtr, 0, 0, ".realp");
2956       Value *SrcReal = Builder.CreateLoad(
2957           RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
2958       Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
2959           RI.ElementType, ElemPtr, 0, 1, ".imagp");
2960       Value *SrcImg = Builder.CreateLoad(
2961           RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");
2962 
2963       Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
2964           RI.ElementType, GlobVal, 0, 0, ".realp");
2965       Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
2966           RI.ElementType, GlobVal, 0, 1, ".imagp");
2967       Builder.CreateStore(SrcReal, DestRealPtr);
2968       Builder.CreateStore(SrcImg, DestImgPtr);
2969       break;
2970     }
2971     case EvalKind::Aggregate: {
2972       Value *SizeVal =
2973           Builder.getInt64(M.getDataLayout().getTypeStoreSize(RI.ElementType));
2974       Builder.CreateMemCpy(
2975           GlobVal, M.getDataLayout().getPrefTypeAlign(RI.ElementType), ElemPtr,
2976           M.getDataLayout().getPrefTypeAlign(RI.ElementType), SizeVal, false);
2977       break;
2978     }
2979     }
2980   }
2981 
2982   Builder.CreateRetVoid();
2983   Builder.restoreIP(OldIP);
2984   return LtGCFunc;
2985 }
2986 
2987 Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
2988     ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
2989     Type *ReductionsBufferTy, AttributeList FuncAttrs) {
2990   OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
2991   LLVMContext &Ctx = M.getContext();
2992   FunctionType *FuncTy = FunctionType::get(
2993       Builder.getVoidTy(),
2994       {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
2995       /* IsVarArg */ false);
2996   Function *LtGRFunc =
2997       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
2998                        "_omp_reduction_list_to_global_reduce_func", &M);
2999   LtGRFunc->setAttributes(FuncAttrs);
3000   LtGRFunc->addParamAttr(0, Attribute::NoUndef);
3001   LtGRFunc->addParamAttr(1, Attribute::NoUndef);
3002   LtGRFunc->addParamAttr(2, Attribute::NoUndef);
3003 
3004   BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGRFunc);
3005   Builder.SetInsertPoint(EntryBlock);
3006 
3007   // Buffer: global reduction buffer.
3008   Argument *BufferArg = LtGRFunc->getArg(0);
3009   // Idx: index of the buffer.
3010   Argument *IdxArg = LtGRFunc->getArg(1);
3011   // ReduceList: thread local Reduce list.
3012   Argument *ReduceListArg = LtGRFunc->getArg(2);
3013 
3014   Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
3015                                                 BufferArg->getName() + ".addr");
3016   Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
3017                                              IdxArg->getName() + ".addr");
3018   Value *ReduceListArgAlloca = Builder.CreateAlloca(
3019       Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
3020   auto *RedListArrayTy =
3021       ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
3022 
3023   // 1. Build a list of reduction variables.
3024   // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
3025   Value *LocalReduceList =
3026       Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");
3027 
3028   Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3029       BufferArgAlloca, Builder.getPtrTy(),
3030       BufferArgAlloca->getName() + ".ascast");
3031   Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3032       IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
3033   Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3034       ReduceListArgAlloca, Builder.getPtrTy(),
3035       ReduceListArgAlloca->getName() + ".ascast");
3036   Value *LocalReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3037       LocalReduceList, Builder.getPtrTy(),
3038       LocalReduceList->getName() + ".ascast");
3039 
3040   Builder.CreateStore(BufferArg, BufferArgAddrCast);
3041   Builder.CreateStore(IdxArg, IdxArgAddrCast);
3042   Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
3043 
3044   Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
3045   Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
3046   Type *IndexTy = Builder.getIndexTy(
3047       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3048   for (auto En : enumerate(ReductionInfos)) {
3049     Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
3050         RedListArrayTy, LocalReduceListAddrCast,
3051         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3052     Value *BufferVD =
3053         Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
3054     // Global = Buffer.VD[Idx];
3055     Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
3056         ReductionsBufferTy, BufferVD, 0, En.index());
3057     Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
3058   }
3059 
3060   // Call reduce_function(GlobalReduceList, ReduceList)
3061   Value *ReduceList =
3062       Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
3063   Builder.CreateCall(ReduceFn, {LocalReduceListAddrCast, ReduceList})
3064       ->addFnAttr(Attribute::NoUnwind);
3065   Builder.CreateRetVoid();
3066   Builder.restoreIP(OldIP);
3067   return LtGRFunc;
3068 }
3069 
3070 Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
3071     ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
3072     AttributeList FuncAttrs) {
3073   OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
3074   LLVMContext &Ctx = M.getContext();
3075   FunctionType *FuncTy = FunctionType::get(
3076       Builder.getVoidTy(),
3077       {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
3078       /* IsVarArg */ false);
3079   Function *LtGCFunc =
3080       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
3081                        "_omp_reduction_global_to_list_copy_func", &M);
3082   LtGCFunc->setAttributes(FuncAttrs);
3083   LtGCFunc->addParamAttr(0, Attribute::NoUndef);
3084   LtGCFunc->addParamAttr(1, Attribute::NoUndef);
3085   LtGCFunc->addParamAttr(2, Attribute::NoUndef);
3086 
3087   BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
3088   Builder.SetInsertPoint(EntryBlock);
3089 
3090   // Buffer: global reduction buffer.
3091   Argument *BufferArg = LtGCFunc->getArg(0);
3092   // Idx: index of the buffer.
3093   Argument *IdxArg = LtGCFunc->getArg(1);
3094   // ReduceList: thread local Reduce list.
3095   Argument *ReduceListArg = LtGCFunc->getArg(2);
3096 
3097   Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
3098                                                 BufferArg->getName() + ".addr");
3099   Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
3100                                              IdxArg->getName() + ".addr");
3101   Value *ReduceListArgAlloca = Builder.CreateAlloca(
3102       Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
3103   Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3104       BufferArgAlloca, Builder.getPtrTy(),
3105       BufferArgAlloca->getName() + ".ascast");
3106   Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3107       IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
3108   Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3109       ReduceListArgAlloca, Builder.getPtrTy(),
3110       ReduceListArgAlloca->getName() + ".ascast");
3111   Builder.CreateStore(BufferArg, BufferArgAddrCast);
3112   Builder.CreateStore(IdxArg, IdxArgAddrCast);
3113   Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
3114 
3115   Value *LocalReduceList =
3116       Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
3117   Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
3118   Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
3119   Type *IndexTy = Builder.getIndexTy(
3120       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3121   for (auto En : enumerate(ReductionInfos)) {
3122     const OpenMPIRBuilder::ReductionInfo &RI = En.value();
3123     auto *RedListArrayTy =
3124         ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
3125     // Reduce element = LocalReduceList[i]
3126     Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
3127         RedListArrayTy, LocalReduceList,
3128         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3129     // elemptr = ((CopyType*)(elemptrptr)) + I
3130     Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
3131     // Global = Buffer.VD[Idx];
3132     Value *BufferVD =
3133         Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
3134     Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
3135         ReductionsBufferTy, BufferVD, 0, En.index());
3136 
3137     switch (RI.EvaluationKind) {
3138     case EvalKind::Scalar: {
3139       Value *TargetElement = Builder.CreateLoad(RI.ElementType, GlobValPtr);
3140       Builder.CreateStore(TargetElement, ElemPtr);
3141       break;
3142     }
3143     case EvalKind::Complex: {
3144       Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
3145           RI.ElementType, GlobValPtr, 0, 0, ".realp");
3146       Value *SrcReal = Builder.CreateLoad(
3147           RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
3148       Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
3149           RI.ElementType, GlobValPtr, 0, 1, ".imagp");
3150       Value *SrcImg = Builder.CreateLoad(
3151           RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");
3152 
3153       Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
3154           RI.ElementType, ElemPtr, 0, 0, ".realp");
3155       Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
3156           RI.ElementType, ElemPtr, 0, 1, ".imagp");
3157       Builder.CreateStore(SrcReal, DestRealPtr);
3158       Builder.CreateStore(SrcImg, DestImgPtr);
3159       break;
3160     }
3161     case EvalKind::Aggregate: {
3162       Value *SizeVal =
3163           Builder.getInt64(M.getDataLayout().getTypeStoreSize(RI.ElementType));
3164       Builder.CreateMemCpy(
3165           ElemPtr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
3166           GlobValPtr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
3167           SizeVal, false);
3168       break;
3169     }
3170     }
3171   }
3172 
3173   Builder.CreateRetVoid();
3174   Builder.restoreIP(OldIP);
3175   return LtGCFunc;
3176 }
3177 
3178 Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
3179     ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
3180     Type *ReductionsBufferTy, AttributeList FuncAttrs) {
3181   OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
3182   LLVMContext &Ctx = M.getContext();
3183   auto *FuncTy = FunctionType::get(
3184       Builder.getVoidTy(),
3185       {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
3186       /* IsVarArg */ false);
3187   Function *LtGRFunc =
3188       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
3189                        "_omp_reduction_global_to_list_reduce_func", &M);
3190   LtGRFunc->setAttributes(FuncAttrs);
3191   LtGRFunc->addParamAttr(0, Attribute::NoUndef);
3192   LtGRFunc->addParamAttr(1, Attribute::NoUndef);
3193   LtGRFunc->addParamAttr(2, Attribute::NoUndef);
3194 
3195   BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGRFunc);
3196   Builder.SetInsertPoint(EntryBlock);
3197 
3198   // Buffer: global reduction buffer.
3199   Argument *BufferArg = LtGRFunc->getArg(0);
3200   // Idx: index of the buffer.
3201   Argument *IdxArg = LtGRFunc->getArg(1);
3202   // ReduceList: thread local Reduce list.
3203   Argument *ReduceListArg = LtGRFunc->getArg(2);
3204 
3205   Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
3206                                                 BufferArg->getName() + ".addr");
3207   Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
3208                                              IdxArg->getName() + ".addr");
3209   Value *ReduceListArgAlloca = Builder.CreateAlloca(
3210       Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
3211   ArrayType *RedListArrayTy =
3212       ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
3213 
3214   // 1. Build a list of reduction variables.
3215   // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
3216   Value *LocalReduceList =
3217       Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");
3218 
3219   Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3220       BufferArgAlloca, Builder.getPtrTy(),
3221       BufferArgAlloca->getName() + ".ascast");
3222   Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3223       IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
3224   Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3225       ReduceListArgAlloca, Builder.getPtrTy(),
3226       ReduceListArgAlloca->getName() + ".ascast");
3227   Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
3228       LocalReduceList, Builder.getPtrTy(),
3229       LocalReduceList->getName() + ".ascast");
3230 
3231   Builder.CreateStore(BufferArg, BufferArgAddrCast);
3232   Builder.CreateStore(IdxArg, IdxArgAddrCast);
3233   Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
3234 
3235   Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
3236   Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
3237   Type *IndexTy = Builder.getIndexTy(
3238       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3239   for (auto En : enumerate(ReductionInfos)) {
3240     Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
3241         RedListArrayTy, ReductionList,
3242         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3243     // Global = Buffer.VD[Idx];
3244     Value *BufferVD =
3245         Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
3246     Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
3247         ReductionsBufferTy, BufferVD, 0, En.index());
3248     Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
3249   }
3250 
3251   // Call reduce_function(ReduceList, GlobalReduceList)
3252   Value *ReduceList =
3253       Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
3254   Builder.CreateCall(ReduceFn, {ReduceList, ReductionList})
3255       ->addFnAttr(Attribute::NoUnwind);
3256   Builder.CreateRetVoid();
3257   Builder.restoreIP(OldIP);
3258   return LtGRFunc;
3259 }
3260 
3261 std::string OpenMPIRBuilder::getReductionFuncName(StringRef Name) const {
3262   std::string Suffix =
3263       createPlatformSpecificName({"omp", "reduction", "reduction_func"});
3264   return (Name + Suffix).str();
3265 }
3266 
3267 Function *OpenMPIRBuilder::createReductionFunction(
3268     StringRef ReducerName, ArrayRef<ReductionInfo> ReductionInfos,
3269     ReductionGenCBKind ReductionGenCBKind, AttributeList FuncAttrs) {
3270   auto *FuncTy = FunctionType::get(Builder.getVoidTy(),
3271                                    {Builder.getPtrTy(), Builder.getPtrTy()},
3272                                    /* IsVarArg */ false);
3273   std::string Name = getReductionFuncName(ReducerName);
3274   Function *ReductionFunc =
3275       Function::Create(FuncTy, GlobalVariable::InternalLinkage, Name, &M);
3276   ReductionFunc->setAttributes(FuncAttrs);
3277   ReductionFunc->addParamAttr(0, Attribute::NoUndef);
3278   ReductionFunc->addParamAttr(1, Attribute::NoUndef);
3279   BasicBlock *EntryBB =
3280       BasicBlock::Create(M.getContext(), "entry", ReductionFunc);
3281   Builder.SetInsertPoint(EntryBB);
3282 
3283   // Need to alloca memory here and deal with the pointers before getting
3284   // LHS/RHS pointers out
3285   Value *LHSArrayPtr = nullptr;
3286   Value *RHSArrayPtr = nullptr;
3287   Argument *Arg0 = ReductionFunc->getArg(0);
3288   Argument *Arg1 = ReductionFunc->getArg(1);
3289   Type *Arg0Type = Arg0->getType();
3290   Type *Arg1Type = Arg1->getType();
3291 
3292   Value *LHSAlloca =
3293       Builder.CreateAlloca(Arg0Type, nullptr, Arg0->getName() + ".addr");
3294   Value *RHSAlloca =
3295       Builder.CreateAlloca(Arg1Type, nullptr, Arg1->getName() + ".addr");
3296   Value *LHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3297       LHSAlloca, Arg0Type, LHSAlloca->getName() + ".ascast");
3298   Value *RHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3299       RHSAlloca, Arg1Type, RHSAlloca->getName() + ".ascast");
3300   Builder.CreateStore(Arg0, LHSAddrCast);
3301   Builder.CreateStore(Arg1, RHSAddrCast);
3302   LHSArrayPtr = Builder.CreateLoad(Arg0Type, LHSAddrCast);
3303   RHSArrayPtr = Builder.CreateLoad(Arg1Type, RHSAddrCast);
3304 
3305   Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
3306   Type *IndexTy = Builder.getIndexTy(
3307       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3308   SmallVector<Value *> LHSPtrs, RHSPtrs;
3309   for (auto En : enumerate(ReductionInfos)) {
3310     const ReductionInfo &RI = En.value();
3311     Value *RHSI8PtrPtr = Builder.CreateInBoundsGEP(
3312         RedArrayTy, RHSArrayPtr,
3313         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3314     Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
3315     Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3316         RHSI8Ptr, RI.PrivateVariable->getType(),
3317         RHSI8Ptr->getName() + ".ascast");
3318 
3319     Value *LHSI8PtrPtr = Builder.CreateInBoundsGEP(
3320         RedArrayTy, LHSArrayPtr,
3321         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3322     Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
3323     Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3324         LHSI8Ptr, RI.Variable->getType(), LHSI8Ptr->getName() + ".ascast");
3325 
3326     if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
3327       LHSPtrs.emplace_back(LHSPtr);
3328       RHSPtrs.emplace_back(RHSPtr);
3329     } else {
3330       Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
3331       Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
3332       Value *Reduced;
3333       RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
3334       if (!Builder.GetInsertBlock())
3335         return ReductionFunc;
3336       Builder.CreateStore(Reduced, LHSPtr);
3337     }
3338   }
3339 
3340   if (ReductionGenCBKind == ReductionGenCBKind::Clang)
3341     for (auto En : enumerate(ReductionInfos)) {
3342       unsigned Index = En.index();
3343       const ReductionInfo &RI = En.value();
3344       Value *LHSFixupPtr, *RHSFixupPtr;
3345       Builder.restoreIP(RI.ReductionGenClang(
3346           Builder.saveIP(), Index, &LHSFixupPtr, &RHSFixupPtr, ReductionFunc));
3347 
3348       // Fix the CallBack code genereated to use the correct Values for the LHS
3349       // and RHS
3350       LHSFixupPtr->replaceUsesWithIf(
3351           LHSPtrs[Index], [ReductionFunc](const Use &U) {
3352             return cast<Instruction>(U.getUser())->getParent()->getParent() ==
3353                    ReductionFunc;
3354           });
3355       RHSFixupPtr->replaceUsesWithIf(
3356           RHSPtrs[Index], [ReductionFunc](const Use &U) {
3357             return cast<Instruction>(U.getUser())->getParent()->getParent() ==
3358                    ReductionFunc;
3359           });
3360     }
3361 
3362   Builder.CreateRetVoid();
3363   return ReductionFunc;
3364 }
3365 
3366 static void
3367 checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
3368                     bool IsGPU) {
3369   for (const OpenMPIRBuilder::ReductionInfo &RI : ReductionInfos) {
3370     (void)RI;
3371     assert(RI.Variable && "expected non-null variable");
3372     assert(RI.PrivateVariable && "expected non-null private variable");
3373     assert((RI.ReductionGen || RI.ReductionGenClang) &&
3374            "expected non-null reduction generator callback");
3375     if (!IsGPU) {
3376       assert(
3377           RI.Variable->getType() == RI.PrivateVariable->getType() &&
3378           "expected variables and their private equivalents to have the same "
3379           "type");
3380     }
3381     assert(RI.Variable->getType()->isPointerTy() &&
3382            "expected variables to be pointers");
3383   }
3384 }
3385 
3386 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductionsGPU(
3387     const LocationDescription &Loc, InsertPointTy AllocaIP,
3388     InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
3389     bool IsNoWait, bool IsTeamsReduction, bool HasDistribute,
3390     ReductionGenCBKind ReductionGenCBKind, std::optional<omp::GV> GridValue,
3391     unsigned ReductionBufNum, Value *SrcLocInfo) {
3392   if (!updateToLocation(Loc))
3393     return InsertPointTy();
3394   Builder.restoreIP(CodeGenIP);
3395   checkReductionInfos(ReductionInfos, /*IsGPU*/ true);
3396   LLVMContext &Ctx = M.getContext();
3397 
3398   // Source location for the ident struct
3399   if (!SrcLocInfo) {
3400     uint32_t SrcLocStrSize;
3401     Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3402     SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3403   }
3404 
3405   if (ReductionInfos.size() == 0)
3406     return Builder.saveIP();
3407 
3408   Function *CurFunc = Builder.GetInsertBlock()->getParent();
3409   AttributeList FuncAttrs;
3410   AttrBuilder AttrBldr(Ctx);
3411   for (auto Attr : CurFunc->getAttributes().getFnAttrs())
3412     AttrBldr.addAttribute(Attr);
3413   AttrBldr.removeAttribute(Attribute::OptimizeNone);
3414   FuncAttrs = FuncAttrs.addFnAttributes(Ctx, AttrBldr);
3415 
3416   Function *ReductionFunc = nullptr;
3417   CodeGenIP = Builder.saveIP();
3418   ReductionFunc =
3419       createReductionFunction(Builder.GetInsertBlock()->getParent()->getName(),
3420                               ReductionInfos, ReductionGenCBKind, FuncAttrs);
3421   Builder.restoreIP(CodeGenIP);
3422 
3423   // Set the grid value in the config needed for lowering later on
3424   if (GridValue.has_value())
3425     Config.setGridValue(GridValue.value());
3426   else
3427     Config.setGridValue(getGridValue(T, ReductionFunc));
3428 
3429   uint32_t SrcLocStrSize;
3430   Constant *SrcLocStr = getOrCreateDefaultSrcLocStr(SrcLocStrSize);
3431   Value *RTLoc =
3432       getOrCreateIdent(SrcLocStr, SrcLocStrSize, omp::IdentFlag(0), 0);
3433 
3434   // Build res = __kmpc_reduce{_nowait}(<gtid>, <n>, sizeof(RedList),
3435   // RedList, shuffle_reduce_func, interwarp_copy_func);
3436   // or
3437   // Build res = __kmpc_reduce_teams_nowait_simple(<loc>, <gtid>, <lck>);
3438   Value *Res;
3439 
3440   // 1. Build a list of reduction variables.
3441   // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
3442   auto Size = ReductionInfos.size();
3443   Type *PtrTy = PointerType::getUnqual(Ctx);
3444   Type *RedArrayTy = ArrayType::get(PtrTy, Size);
3445   CodeGenIP = Builder.saveIP();
3446   Builder.restoreIP(AllocaIP);
3447   Value *ReductionListAlloca =
3448       Builder.CreateAlloca(RedArrayTy, nullptr, ".omp.reduction.red_list");
3449   Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
3450       ReductionListAlloca, PtrTy, ReductionListAlloca->getName() + ".ascast");
3451   Builder.restoreIP(CodeGenIP);
3452   Type *IndexTy = Builder.getIndexTy(
3453       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3454   for (auto En : enumerate(ReductionInfos)) {
3455     const ReductionInfo &RI = En.value();
3456     Value *ElemPtr = Builder.CreateInBoundsGEP(
3457         RedArrayTy, ReductionList,
3458         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3459     Value *CastElem =
3460         Builder.CreatePointerBitCastOrAddrSpaceCast(RI.PrivateVariable, PtrTy);
3461     Builder.CreateStore(CastElem, ElemPtr);
3462   }
3463   CodeGenIP = Builder.saveIP();
3464   Function *SarFunc =
3465       emitShuffleAndReduceFunction(ReductionInfos, ReductionFunc, FuncAttrs);
3466   Function *WcFunc = emitInterWarpCopyFunction(Loc, ReductionInfos, FuncAttrs);
3467   Builder.restoreIP(CodeGenIP);
3468 
3469   Value *RL = Builder.CreatePointerBitCastOrAddrSpaceCast(ReductionList, PtrTy);
3470 
3471   unsigned MaxDataSize = 0;
3472   SmallVector<Type *> ReductionTypeArgs;
3473   for (auto En : enumerate(ReductionInfos)) {
3474     auto Size = M.getDataLayout().getTypeStoreSize(En.value().ElementType);
3475     if (Size > MaxDataSize)
3476       MaxDataSize = Size;
3477     ReductionTypeArgs.emplace_back(En.value().ElementType);
3478   }
3479   Value *ReductionDataSize =
3480       Builder.getInt64(MaxDataSize * ReductionInfos.size());
3481   if (!IsTeamsReduction) {
3482     Value *SarFuncCast =
3483         Builder.CreatePointerBitCastOrAddrSpaceCast(SarFunc, PtrTy);
3484     Value *WcFuncCast =
3485         Builder.CreatePointerBitCastOrAddrSpaceCast(WcFunc, PtrTy);
3486     Value *Args[] = {RTLoc, ReductionDataSize, RL, SarFuncCast, WcFuncCast};
3487     Function *Pv2Ptr = getOrCreateRuntimeFunctionPtr(
3488         RuntimeFunction::OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2);
3489     Res = Builder.CreateCall(Pv2Ptr, Args);
3490   } else {
3491     CodeGenIP = Builder.saveIP();
3492     StructType *ReductionsBufferTy = StructType::create(
3493         Ctx, ReductionTypeArgs, "struct._globalized_locals_ty");
3494     Function *RedFixedBuferFn = getOrCreateRuntimeFunctionPtr(
3495         RuntimeFunction::OMPRTL___kmpc_reduction_get_fixed_buffer);
3496     Function *LtGCFunc = emitListToGlobalCopyFunction(
3497         ReductionInfos, ReductionsBufferTy, FuncAttrs);
3498     Function *LtGRFunc = emitListToGlobalReduceFunction(
3499         ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
3500     Function *GtLCFunc = emitGlobalToListCopyFunction(
3501         ReductionInfos, ReductionsBufferTy, FuncAttrs);
3502     Function *GtLRFunc = emitGlobalToListReduceFunction(
3503         ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
3504     Builder.restoreIP(CodeGenIP);
3505 
3506     Value *KernelTeamsReductionPtr = Builder.CreateCall(
3507         RedFixedBuferFn, {}, "_openmp_teams_reductions_buffer_$_$ptr");
3508 
3509     Value *Args3[] = {RTLoc,
3510                       KernelTeamsReductionPtr,
3511                       Builder.getInt32(ReductionBufNum),
3512                       ReductionDataSize,
3513                       RL,
3514                       SarFunc,
3515                       WcFunc,
3516                       LtGCFunc,
3517                       LtGRFunc,
3518                       GtLCFunc,
3519                       GtLRFunc};
3520 
3521     Function *TeamsReduceFn = getOrCreateRuntimeFunctionPtr(
3522         RuntimeFunction::OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2);
3523     Res = Builder.CreateCall(TeamsReduceFn, Args3);
3524   }
3525 
3526   // 5. Build if (res == 1)
3527   BasicBlock *ExitBB = BasicBlock::Create(Ctx, ".omp.reduction.done");
3528   BasicBlock *ThenBB = BasicBlock::Create(Ctx, ".omp.reduction.then");
3529   Value *Cond = Builder.CreateICmpEQ(Res, Builder.getInt32(1));
3530   Builder.CreateCondBr(Cond, ThenBB, ExitBB);
3531 
3532   // 6. Build then branch: where we have reduced values in the master
3533   //    thread in each team.
3534   //    __kmpc_end_reduce{_nowait}(<gtid>);
3535   //    break;
3536   emitBlock(ThenBB, CurFunc);
3537 
3538   // Add emission of __kmpc_end_reduce{_nowait}(<gtid>);
3539   for (auto En : enumerate(ReductionInfos)) {
3540     const ReductionInfo &RI = En.value();
3541     Value *LHS = RI.Variable;
3542     Value *RHS =
3543         Builder.CreatePointerBitCastOrAddrSpaceCast(RI.PrivateVariable, PtrTy);
3544 
3545     if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
3546       Value *LHSPtr, *RHSPtr;
3547       Builder.restoreIP(RI.ReductionGenClang(Builder.saveIP(), En.index(),
3548                                              &LHSPtr, &RHSPtr, CurFunc));
3549 
3550       // Fix the CallBack code genereated to use the correct Values for the LHS
3551       // and RHS
3552       LHSPtr->replaceUsesWithIf(LHS, [ReductionFunc](const Use &U) {
3553         return cast<Instruction>(U.getUser())->getParent()->getParent() ==
3554                ReductionFunc;
3555       });
3556       RHSPtr->replaceUsesWithIf(RHS, [ReductionFunc](const Use &U) {
3557         return cast<Instruction>(U.getUser())->getParent()->getParent() ==
3558                ReductionFunc;
3559       });
3560     } else {
3561       assert(false && "Unhandled ReductionGenCBKind");
3562     }
3563   }
3564   emitBlock(ExitBB, CurFunc);
3565 
3566   Config.setEmitLLVMUsed();
3567 
3568   return Builder.saveIP();
3569 }
3570 
3571 static Function *getFreshReductionFunc(Module &M) {
3572   Type *VoidTy = Type::getVoidTy(M.getContext());
3573   Type *Int8PtrTy = PointerType::getUnqual(M.getContext());
3574   auto *FuncTy =
3575       FunctionType::get(VoidTy, {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ false);
3576   return Function::Create(FuncTy, GlobalVariable::InternalLinkage,
3577                           ".omp.reduction.func", &M);
3578 }
3579 
3580 OpenMPIRBuilder::InsertPointTy
3581 OpenMPIRBuilder::createReductions(const LocationDescription &Loc,
3582                                   InsertPointTy AllocaIP,
3583                                   ArrayRef<ReductionInfo> ReductionInfos,
3584                                   ArrayRef<bool> IsByRef, bool IsNoWait) {
3585   assert(ReductionInfos.size() == IsByRef.size());
3586   for (const ReductionInfo &RI : ReductionInfos) {
3587     (void)RI;
3588     assert(RI.Variable && "expected non-null variable");
3589     assert(RI.PrivateVariable && "expected non-null private variable");
3590     assert(RI.ReductionGen && "expected non-null reduction generator callback");
3591     assert(RI.Variable->getType() == RI.PrivateVariable->getType() &&
3592            "expected variables and their private equivalents to have the same "
3593            "type");
3594     assert(RI.Variable->getType()->isPointerTy() &&
3595            "expected variables to be pointers");
3596   }
3597 
3598   if (!updateToLocation(Loc))
3599     return InsertPointTy();
3600 
3601   BasicBlock *InsertBlock = Loc.IP.getBlock();
3602   BasicBlock *ContinuationBlock =
3603       InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
3604   InsertBlock->getTerminator()->eraseFromParent();
3605 
3606   // Create and populate array of type-erased pointers to private reduction
3607   // values.
3608   unsigned NumReductions = ReductionInfos.size();
3609   Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), NumReductions);
3610   Builder.SetInsertPoint(AllocaIP.getBlock()->getTerminator());
3611   Value *RedArray = Builder.CreateAlloca(RedArrayTy, nullptr, "red.array");
3612 
3613   Builder.SetInsertPoint(InsertBlock, InsertBlock->end());
3614 
3615   for (auto En : enumerate(ReductionInfos)) {
3616     unsigned Index = En.index();
3617     const ReductionInfo &RI = En.value();
3618     Value *RedArrayElemPtr = Builder.CreateConstInBoundsGEP2_64(
3619         RedArrayTy, RedArray, 0, Index, "red.array.elem." + Twine(Index));
3620     Builder.CreateStore(RI.PrivateVariable, RedArrayElemPtr);
3621   }
3622 
3623   // Emit a call to the runtime function that orchestrates the reduction.
3624   // Declare the reduction function in the process.
3625   Function *Func = Builder.GetInsertBlock()->getParent();
3626   Module *Module = Func->getParent();
3627   uint32_t SrcLocStrSize;
3628   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3629   bool CanGenerateAtomic = all_of(ReductionInfos, [](const ReductionInfo &RI) {
3630     return RI.AtomicReductionGen;
3631   });
3632   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize,
3633                                   CanGenerateAtomic
3634                                       ? IdentFlag::OMP_IDENT_FLAG_ATOMIC_REDUCE
3635                                       : IdentFlag(0));
3636   Value *ThreadId = getOrCreateThreadID(Ident);
3637   Constant *NumVariables = Builder.getInt32(NumReductions);
3638   const DataLayout &DL = Module->getDataLayout();
3639   unsigned RedArrayByteSize = DL.getTypeStoreSize(RedArrayTy);
3640   Constant *RedArraySize = Builder.getInt64(RedArrayByteSize);
3641   Function *ReductionFunc = getFreshReductionFunc(*Module);
3642   Value *Lock = getOMPCriticalRegionLock(".reduction");
3643   Function *ReduceFunc = getOrCreateRuntimeFunctionPtr(
3644       IsNoWait ? RuntimeFunction::OMPRTL___kmpc_reduce_nowait
3645                : RuntimeFunction::OMPRTL___kmpc_reduce);
3646   CallInst *ReduceCall =
3647       Builder.CreateCall(ReduceFunc,
3648                          {Ident, ThreadId, NumVariables, RedArraySize, RedArray,
3649                           ReductionFunc, Lock},
3650                          "reduce");
3651 
3652   // Create final reduction entry blocks for the atomic and non-atomic case.
3653   // Emit IR that dispatches control flow to one of the blocks based on the
3654   // reduction supporting the atomic mode.
3655   BasicBlock *NonAtomicRedBlock =
3656       BasicBlock::Create(Module->getContext(), "reduce.switch.nonatomic", Func);
3657   BasicBlock *AtomicRedBlock =
3658       BasicBlock::Create(Module->getContext(), "reduce.switch.atomic", Func);
3659   SwitchInst *Switch =
3660       Builder.CreateSwitch(ReduceCall, ContinuationBlock, /* NumCases */ 2);
3661   Switch->addCase(Builder.getInt32(1), NonAtomicRedBlock);
3662   Switch->addCase(Builder.getInt32(2), AtomicRedBlock);
3663 
3664   // Populate the non-atomic reduction using the elementwise reduction function.
3665   // This loads the elements from the global and private variables and reduces
3666   // them before storing back the result to the global variable.
3667   Builder.SetInsertPoint(NonAtomicRedBlock);
3668   for (auto En : enumerate(ReductionInfos)) {
3669     const ReductionInfo &RI = En.value();
3670     Type *ValueType = RI.ElementType;
3671     // We have one less load for by-ref case because that load is now inside of
3672     // the reduction region
3673     Value *RedValue = nullptr;
3674     if (!IsByRef[En.index()]) {
3675       RedValue = Builder.CreateLoad(ValueType, RI.Variable,
3676                                     "red.value." + Twine(En.index()));
3677     }
3678     Value *PrivateRedValue =
3679         Builder.CreateLoad(ValueType, RI.PrivateVariable,
3680                            "red.private.value." + Twine(En.index()));
3681     Value *Reduced;
3682     if (IsByRef[En.index()]) {
3683       Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), RI.Variable,
3684                                         PrivateRedValue, Reduced));
3685     } else {
3686       Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), RedValue,
3687                                         PrivateRedValue, Reduced));
3688     }
3689     if (!Builder.GetInsertBlock())
3690       return InsertPointTy();
3691     // for by-ref case, the load is inside of the reduction region
3692     if (!IsByRef[En.index()])
3693       Builder.CreateStore(Reduced, RI.Variable);
3694   }
3695   Function *EndReduceFunc = getOrCreateRuntimeFunctionPtr(
3696       IsNoWait ? RuntimeFunction::OMPRTL___kmpc_end_reduce_nowait
3697                : RuntimeFunction::OMPRTL___kmpc_end_reduce);
3698   Builder.CreateCall(EndReduceFunc, {Ident, ThreadId, Lock});
3699   Builder.CreateBr(ContinuationBlock);
3700 
3701   // Populate the atomic reduction using the atomic elementwise reduction
3702   // function. There are no loads/stores here because they will be happening
3703   // inside the atomic elementwise reduction.
3704   Builder.SetInsertPoint(AtomicRedBlock);
3705   if (CanGenerateAtomic && llvm::none_of(IsByRef, [](bool P) { return P; })) {
3706     for (const ReductionInfo &RI : ReductionInfos) {
3707       Builder.restoreIP(RI.AtomicReductionGen(Builder.saveIP(), RI.ElementType,
3708                                               RI.Variable, RI.PrivateVariable));
3709       if (!Builder.GetInsertBlock())
3710         return InsertPointTy();
3711     }
3712     Builder.CreateBr(ContinuationBlock);
3713   } else {
3714     Builder.CreateUnreachable();
3715   }
3716 
3717   // Populate the outlined reduction function using the elementwise reduction
3718   // function. Partial values are extracted from the type-erased array of
3719   // pointers to private variables.
3720   BasicBlock *ReductionFuncBlock =
3721       BasicBlock::Create(Module->getContext(), "", ReductionFunc);
3722   Builder.SetInsertPoint(ReductionFuncBlock);
3723   Value *LHSArrayPtr = ReductionFunc->getArg(0);
3724   Value *RHSArrayPtr = ReductionFunc->getArg(1);
3725 
3726   for (auto En : enumerate(ReductionInfos)) {
3727     const ReductionInfo &RI = En.value();
3728     Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
3729         RedArrayTy, LHSArrayPtr, 0, En.index());
3730     Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
3731     Value *LHSPtr = Builder.CreateBitCast(LHSI8Ptr, RI.Variable->getType());
3732     Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
3733     Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
3734         RedArrayTy, RHSArrayPtr, 0, En.index());
3735     Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
3736     Value *RHSPtr =
3737         Builder.CreateBitCast(RHSI8Ptr, RI.PrivateVariable->getType());
3738     Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
3739     Value *Reduced;
3740     Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced));
3741     if (!Builder.GetInsertBlock())
3742       return InsertPointTy();
3743     // store is inside of the reduction region when using by-ref
3744     if (!IsByRef[En.index()])
3745       Builder.CreateStore(Reduced, LHSPtr);
3746   }
3747   Builder.CreateRetVoid();
3748 
3749   Builder.SetInsertPoint(ContinuationBlock);
3750   return Builder.saveIP();
3751 }
3752 
3753 OpenMPIRBuilder::InsertPointTy
3754 OpenMPIRBuilder::createMaster(const LocationDescription &Loc,
3755                               BodyGenCallbackTy BodyGenCB,
3756                               FinalizeCallbackTy FiniCB) {
3757 
3758   if (!updateToLocation(Loc))
3759     return Loc.IP;
3760 
3761   Directive OMPD = Directive::OMPD_master;
3762   uint32_t SrcLocStrSize;
3763   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3764   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3765   Value *ThreadId = getOrCreateThreadID(Ident);
3766   Value *Args[] = {Ident, ThreadId};
3767 
3768   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_master);
3769   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
3770 
3771   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_master);
3772   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
3773 
3774   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
3775                               /*Conditional*/ true, /*hasFinalize*/ true);
3776 }
3777 
3778 OpenMPIRBuilder::InsertPointTy
3779 OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
3780                               BodyGenCallbackTy BodyGenCB,
3781                               FinalizeCallbackTy FiniCB, Value *Filter) {
3782   if (!updateToLocation(Loc))
3783     return Loc.IP;
3784 
3785   Directive OMPD = Directive::OMPD_masked;
3786   uint32_t SrcLocStrSize;
3787   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3788   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3789   Value *ThreadId = getOrCreateThreadID(Ident);
3790   Value *Args[] = {Ident, ThreadId, Filter};
3791   Value *ArgsEnd[] = {Ident, ThreadId};
3792 
3793   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_masked);
3794   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
3795 
3796   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_masked);
3797   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, ArgsEnd);
3798 
3799   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
3800                               /*Conditional*/ true, /*hasFinalize*/ true);
3801 }
3802 
3803 CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
3804     DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
3805     BasicBlock *PostInsertBefore, const Twine &Name) {
3806   Module *M = F->getParent();
3807   LLVMContext &Ctx = M->getContext();
3808   Type *IndVarTy = TripCount->getType();
3809 
3810   // Create the basic block structure.
3811   BasicBlock *Preheader =
3812       BasicBlock::Create(Ctx, "omp_" + Name + ".preheader", F, PreInsertBefore);
3813   BasicBlock *Header =
3814       BasicBlock::Create(Ctx, "omp_" + Name + ".header", F, PreInsertBefore);
3815   BasicBlock *Cond =
3816       BasicBlock::Create(Ctx, "omp_" + Name + ".cond", F, PreInsertBefore);
3817   BasicBlock *Body =
3818       BasicBlock::Create(Ctx, "omp_" + Name + ".body", F, PreInsertBefore);
3819   BasicBlock *Latch =
3820       BasicBlock::Create(Ctx, "omp_" + Name + ".inc", F, PostInsertBefore);
3821   BasicBlock *Exit =
3822       BasicBlock::Create(Ctx, "omp_" + Name + ".exit", F, PostInsertBefore);
3823   BasicBlock *After =
3824       BasicBlock::Create(Ctx, "omp_" + Name + ".after", F, PostInsertBefore);
3825 
3826   // Use specified DebugLoc for new instructions.
3827   Builder.SetCurrentDebugLocation(DL);
3828 
3829   Builder.SetInsertPoint(Preheader);
3830   Builder.CreateBr(Header);
3831 
3832   Builder.SetInsertPoint(Header);
3833   PHINode *IndVarPHI = Builder.CreatePHI(IndVarTy, 2, "omp_" + Name + ".iv");
3834   IndVarPHI->addIncoming(ConstantInt::get(IndVarTy, 0), Preheader);
3835   Builder.CreateBr(Cond);
3836 
3837   Builder.SetInsertPoint(Cond);
3838   Value *Cmp =
3839       Builder.CreateICmpULT(IndVarPHI, TripCount, "omp_" + Name + ".cmp");
3840   Builder.CreateCondBr(Cmp, Body, Exit);
3841 
3842   Builder.SetInsertPoint(Body);
3843   Builder.CreateBr(Latch);
3844 
3845   Builder.SetInsertPoint(Latch);
3846   Value *Next = Builder.CreateAdd(IndVarPHI, ConstantInt::get(IndVarTy, 1),
3847                                   "omp_" + Name + ".next", /*HasNUW=*/true);
3848   Builder.CreateBr(Header);
3849   IndVarPHI->addIncoming(Next, Latch);
3850 
3851   Builder.SetInsertPoint(Exit);
3852   Builder.CreateBr(After);
3853 
3854   // Remember and return the canonical control flow.
3855   LoopInfos.emplace_front();
3856   CanonicalLoopInfo *CL = &LoopInfos.front();
3857 
3858   CL->Header = Header;
3859   CL->Cond = Cond;
3860   CL->Latch = Latch;
3861   CL->Exit = Exit;
3862 
3863 #ifndef NDEBUG
3864   CL->assertOK();
3865 #endif
3866   return CL;
3867 }
3868 
3869 CanonicalLoopInfo *
3870 OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
3871                                      LoopBodyGenCallbackTy BodyGenCB,
3872                                      Value *TripCount, const Twine &Name) {
3873   BasicBlock *BB = Loc.IP.getBlock();
3874   BasicBlock *NextBB = BB->getNextNode();
3875 
3876   CanonicalLoopInfo *CL = createLoopSkeleton(Loc.DL, TripCount, BB->getParent(),
3877                                              NextBB, NextBB, Name);
3878   BasicBlock *After = CL->getAfter();
3879 
3880   // If location is not set, don't connect the loop.
3881   if (updateToLocation(Loc)) {
3882     // Split the loop at the insertion point: Branch to the preheader and move
3883     // every following instruction to after the loop (the After BB). Also, the
3884     // new successor is the loop's after block.
3885     spliceBB(Builder, After, /*CreateBranch=*/false);
3886     Builder.CreateBr(CL->getPreheader());
3887   }
3888 
3889   // Emit the body content. We do it after connecting the loop to the CFG to
3890   // avoid that the callback encounters degenerate BBs.
3891   BodyGenCB(CL->getBodyIP(), CL->getIndVar());
3892 
3893 #ifndef NDEBUG
3894   CL->assertOK();
3895 #endif
3896   return CL;
3897 }
3898 
3899 CanonicalLoopInfo *OpenMPIRBuilder::createCanonicalLoop(
3900     const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
3901     Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
3902     InsertPointTy ComputeIP, const Twine &Name) {
3903 
3904   // Consider the following difficulties (assuming 8-bit signed integers):
3905   //  * Adding \p Step to the loop counter which passes \p Stop may overflow:
3906   //      DO I = 1, 100, 50
3907   ///  * A \p Step of INT_MIN cannot not be normalized to a positive direction:
3908   //      DO I = 100, 0, -128
3909 
3910   // Start, Stop and Step must be of the same integer type.
3911   auto *IndVarTy = cast<IntegerType>(Start->getType());
3912   assert(IndVarTy == Stop->getType() && "Stop type mismatch");
3913   assert(IndVarTy == Step->getType() && "Step type mismatch");
3914 
3915   LocationDescription ComputeLoc =
3916       ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
3917   updateToLocation(ComputeLoc);
3918 
3919   ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
3920   ConstantInt *One = ConstantInt::get(IndVarTy, 1);
3921 
3922   // Like Step, but always positive.
3923   Value *Incr = Step;
3924 
3925   // Distance between Start and Stop; always positive.
3926   Value *Span;
3927 
3928   // Condition whether there are no iterations are executed at all, e.g. because
3929   // UB < LB.
3930   Value *ZeroCmp;
3931 
3932   if (IsSigned) {
3933     // Ensure that increment is positive. If not, negate and invert LB and UB.
3934     Value *IsNeg = Builder.CreateICmpSLT(Step, Zero);
3935     Incr = Builder.CreateSelect(IsNeg, Builder.CreateNeg(Step), Step);
3936     Value *LB = Builder.CreateSelect(IsNeg, Stop, Start);
3937     Value *UB = Builder.CreateSelect(IsNeg, Start, Stop);
3938     Span = Builder.CreateSub(UB, LB, "", false, true);
3939     ZeroCmp = Builder.CreateICmp(
3940         InclusiveStop ? CmpInst::ICMP_SLT : CmpInst::ICMP_SLE, UB, LB);
3941   } else {
3942     Span = Builder.CreateSub(Stop, Start, "", true);
3943     ZeroCmp = Builder.CreateICmp(
3944         InclusiveStop ? CmpInst::ICMP_ULT : CmpInst::ICMP_ULE, Stop, Start);
3945   }
3946 
3947   Value *CountIfLooping;
3948   if (InclusiveStop) {
3949     CountIfLooping = Builder.CreateAdd(Builder.CreateUDiv(Span, Incr), One);
3950   } else {
3951     // Avoid incrementing past stop since it could overflow.
3952     Value *CountIfTwo = Builder.CreateAdd(
3953         Builder.CreateUDiv(Builder.CreateSub(Span, One), Incr), One);
3954     Value *OneCmp = Builder.CreateICmp(CmpInst::ICMP_ULE, Span, Incr);
3955     CountIfLooping = Builder.CreateSelect(OneCmp, One, CountIfTwo);
3956   }
3957   Value *TripCount = Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
3958                                           "omp_" + Name + ".tripcount");
3959 
3960   auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
3961     Builder.restoreIP(CodeGenIP);
3962     Value *Span = Builder.CreateMul(IV, Step);
3963     Value *IndVar = Builder.CreateAdd(Span, Start);
3964     BodyGenCB(Builder.saveIP(), IndVar);
3965   };
3966   LocationDescription LoopLoc = ComputeIP.isSet() ? Loc.IP : Builder.saveIP();
3967   return createCanonicalLoop(LoopLoc, BodyGen, TripCount, Name);
3968 }
3969 
3970 // Returns an LLVM function to call for initializing loop bounds using OpenMP
3971 // static scheduling depending on `type`. Only i32 and i64 are supported by the
3972 // runtime. Always interpret integers as unsigned similarly to
3973 // CanonicalLoopInfo.
3974 static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
3975                                                   OpenMPIRBuilder &OMPBuilder) {
3976   unsigned Bitwidth = Ty->getIntegerBitWidth();
3977   if (Bitwidth == 32)
3978     return OMPBuilder.getOrCreateRuntimeFunction(
3979         M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_4u);
3980   if (Bitwidth == 64)
3981     return OMPBuilder.getOrCreateRuntimeFunction(
3982         M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_8u);
3983   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
3984 }
3985 
/// Lower the canonical loop \p CLI to a statically scheduled, unchunked
/// worksharing loop driven by the OpenMP runtime.
///
///  - Allocas for the last-iteration flag, lower bound, upper bound and
///    stride are placed in \p AllocaIP's block, before its first
///    non-PHI/debug/alloca instruction (not at AllocaIP's exact point).
///  - At the end of the preheader, the bounds [0, tripcount - 1] and stride 1
///    are stored, and __kmpc_for_static_init_{4u,8u} is called. The loop's
///    trip count and induction variable are then rewritten to the
///    [lower, upper] range the call wrote back.
///  - The exit block calls __kmpc_for_static_fini and, if \p NeedsBarrier,
///    emits a barrier.
///
/// \p CLI is invalidated. Returns the insertion point after the loop.
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::applyStaticWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
                                          InsertPointTy AllocaIP,
                                          bool NeedsBarrier) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
         "Require dedicated allocate IP");

  // Set up the source location value for OpenMP runtime.
  Builder.restoreIP(CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  // Declare useful OpenMP runtime functions.
  Value *IV = CLI->getIndVar();
  Type *IVTy = IV->getType();
  FunctionCallee StaticInit = getKmpcForStaticInitForType(IVTy, M, *this);
  FunctionCallee StaticFini =
      getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());

  Type *I32Type = Type::getInt32Ty(M.getContext());
  Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
  Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
  Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
  Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");

  // At the end of the preheader, prepare for calling the "init" function by
  // storing the current loop bounds into the allocated space. A canonical loop
  // always iterates from 0 to trip-count with step 1. Note that "init" expects
  // and produces an inclusive upper bound.
  // NOTE(review): a trip count of 0 makes UpperBound wrap to the all-ones
  // value; confirm callers never reach this with a zero-trip loop.
  Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
  Constant *Zero = ConstantInt::get(IVTy, 0);
  Constant *One = ConstantInt::get(IVTy, 1);
  Builder.CreateStore(Zero, PLowerBound);
  Value *UpperBound = Builder.CreateSub(CLI->getTripCount(), One);
  Builder.CreateStore(UpperBound, PUpperBound);
  Builder.CreateStore(One, PStride);

  Value *ThreadNum = getOrCreateThreadID(SrcLoc);

  Constant *SchedulingType = ConstantInt::get(
      I32Type, static_cast<int>(OMPScheduleType::UnorderedStatic));

  // Call the "init" function and update the trip count of the loop with the
  // value it produced. The two trailing arguments are the increment (1) and
  // the chunk size (0, i.e. unchunked).
  Builder.CreateCall(StaticInit,
                     {SrcLoc, ThreadNum, SchedulingType, PLastIter, PLowerBound,
                      PUpperBound, PStride, One, Zero});
  Value *LowerBound = Builder.CreateLoad(IVTy, PLowerBound);
  Value *InclusiveUpperBound = Builder.CreateLoad(IVTy, PUpperBound);
  // New trip count = inclusive upper - lower + 1.
  Value *TripCountMinusOne = Builder.CreateSub(InclusiveUpperBound, LowerBound);
  Value *TripCount = Builder.CreateAdd(TripCountMinusOne, One);
  CLI->setTripCount(TripCount);

  // Update all uses of the induction variable except the one in the condition
  // block that compares it with the actual upper bound, and the increment in
  // the latch block.
  // Each remaining use sees (counter + LowerBound), computed at the top of the
  // body.

  CLI->mapIndVar([&](Instruction *OldIV) -> Value * {
    Builder.SetInsertPoint(CLI->getBody(),
                           CLI->getBody()->getFirstInsertionPt());
    Builder.SetCurrentDebugLocation(DL);
    return Builder.CreateAdd(OldIV, LowerBound);
  });

  // In the "exit" block, call the "fini" function.
  Builder.SetInsertPoint(CLI->getExit(),
                         CLI->getExit()->getTerminator()->getIterator());
  Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});

  // Add the barrier if requested.
  if (NeedsBarrier)
    createBarrier(LocationDescription(Builder.saveIP(), DL),
                  omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
                  /* CheckCancelFlag */ false);

  // Capture the after-IP before invalidating; CLI must not be used afterwards.
  InsertPointTy AfterIP = CLI->getAfterIP();
  CLI->invalidate();

  return AfterIP;
}
4073 
4074 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
4075     DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
4076     bool NeedsBarrier, Value *ChunkSize) {
4077   assert(CLI->isValid() && "Requires a valid canonical loop");
4078   assert(ChunkSize && "Chunk size is required");
4079 
4080   LLVMContext &Ctx = CLI->getFunction()->getContext();
4081   Value *IV = CLI->getIndVar();
4082   Value *OrigTripCount = CLI->getTripCount();
4083   Type *IVTy = IV->getType();
4084   assert(IVTy->getIntegerBitWidth() <= 64 &&
4085          "Max supported tripcount bitwidth is 64 bits");
4086   Type *InternalIVTy = IVTy->getIntegerBitWidth() <= 32 ? Type::getInt32Ty(Ctx)
4087                                                         : Type::getInt64Ty(Ctx);
4088   Type *I32Type = Type::getInt32Ty(M.getContext());
4089   Constant *Zero = ConstantInt::get(InternalIVTy, 0);
4090   Constant *One = ConstantInt::get(InternalIVTy, 1);
4091 
4092   // Declare useful OpenMP runtime functions.
4093   FunctionCallee StaticInit =
4094       getKmpcForStaticInitForType(InternalIVTy, M, *this);
4095   FunctionCallee StaticFini =
4096       getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);
4097 
4098   // Allocate space for computed loop bounds as expected by the "init" function.
4099   Builder.restoreIP(AllocaIP);
4100   Builder.SetCurrentDebugLocation(DL);
4101   Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
4102   Value *PLowerBound =
4103       Builder.CreateAlloca(InternalIVTy, nullptr, "p.lowerbound");
4104   Value *PUpperBound =
4105       Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
4106   Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");
4107 
4108   // Set up the source location value for the OpenMP runtime.
4109   Builder.restoreIP(CLI->getPreheaderIP());
4110   Builder.SetCurrentDebugLocation(DL);
4111 
4112   // TODO: Detect overflow in ubsan or max-out with current tripcount.
4113   Value *CastedChunkSize =
4114       Builder.CreateZExtOrTrunc(ChunkSize, InternalIVTy, "chunksize");
4115   Value *CastedTripCount =
4116       Builder.CreateZExt(OrigTripCount, InternalIVTy, "tripcount");
4117 
4118   Constant *SchedulingType = ConstantInt::get(
4119       I32Type, static_cast<int>(OMPScheduleType::UnorderedStaticChunked));
4120   Builder.CreateStore(Zero, PLowerBound);
4121   Value *OrigUpperBound = Builder.CreateSub(CastedTripCount, One);
4122   Builder.CreateStore(OrigUpperBound, PUpperBound);
4123   Builder.CreateStore(One, PStride);
4124 
4125   // Call the "init" function and update the trip count of the loop with the
4126   // value it produced.
4127   uint32_t SrcLocStrSize;
4128   Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
4129   Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4130   Value *ThreadNum = getOrCreateThreadID(SrcLoc);
4131   Builder.CreateCall(StaticInit,
4132                      {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
4133                       /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
4134                       /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
4135                       /*pstride=*/PStride, /*incr=*/One,
4136                       /*chunk=*/CastedChunkSize});
4137 
4138   // Load values written by the "init" function.
4139   Value *FirstChunkStart =
4140       Builder.CreateLoad(InternalIVTy, PLowerBound, "omp_firstchunk.lb");
4141   Value *FirstChunkStop =
4142       Builder.CreateLoad(InternalIVTy, PUpperBound, "omp_firstchunk.ub");
4143   Value *FirstChunkEnd = Builder.CreateAdd(FirstChunkStop, One);
4144   Value *ChunkRange =
4145       Builder.CreateSub(FirstChunkEnd, FirstChunkStart, "omp_chunk.range");
4146   Value *NextChunkStride =
4147       Builder.CreateLoad(InternalIVTy, PStride, "omp_dispatch.stride");
4148 
4149   // Create outer "dispatch" loop for enumerating the chunks.
4150   BasicBlock *DispatchEnter = splitBB(Builder, true);
4151   Value *DispatchCounter;
4152   CanonicalLoopInfo *DispatchCLI = createCanonicalLoop(
4153       {Builder.saveIP(), DL},
4154       [&](InsertPointTy BodyIP, Value *Counter) { DispatchCounter = Counter; },
4155       FirstChunkStart, CastedTripCount, NextChunkStride,
4156       /*IsSigned=*/false, /*InclusiveStop=*/false, /*ComputeIP=*/{},
4157       "dispatch");
4158 
4159   // Remember the BasicBlocks of the dispatch loop we need, then invalidate to
4160   // not have to preserve the canonical invariant.
4161   BasicBlock *DispatchBody = DispatchCLI->getBody();
4162   BasicBlock *DispatchLatch = DispatchCLI->getLatch();
4163   BasicBlock *DispatchExit = DispatchCLI->getExit();
4164   BasicBlock *DispatchAfter = DispatchCLI->getAfter();
4165   DispatchCLI->invalidate();
4166 
4167   // Rewire the original loop to become the chunk loop inside the dispatch loop.
4168   redirectTo(DispatchAfter, CLI->getAfter(), DL);
4169   redirectTo(CLI->getExit(), DispatchLatch, DL);
4170   redirectTo(DispatchBody, DispatchEnter, DL);
4171 
4172   // Prepare the prolog of the chunk loop.
4173   Builder.restoreIP(CLI->getPreheaderIP());
4174   Builder.SetCurrentDebugLocation(DL);
4175 
4176   // Compute the number of iterations of the chunk loop.
4177   Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
4178   Value *ChunkEnd = Builder.CreateAdd(DispatchCounter, ChunkRange);
4179   Value *IsLastChunk =
4180       Builder.CreateICmpUGE(ChunkEnd, CastedTripCount, "omp_chunk.is_last");
4181   Value *CountUntilOrigTripCount =
4182       Builder.CreateSub(CastedTripCount, DispatchCounter);
4183   Value *ChunkTripCount = Builder.CreateSelect(
4184       IsLastChunk, CountUntilOrigTripCount, ChunkRange, "omp_chunk.tripcount");
4185   Value *BackcastedChunkTC =
4186       Builder.CreateTrunc(ChunkTripCount, IVTy, "omp_chunk.tripcount.trunc");
4187   CLI->setTripCount(BackcastedChunkTC);
4188 
4189   // Update all uses of the induction variable except the one in the condition
4190   // block that compares it with the actual upper bound, and the increment in
4191   // the latch block.
4192   Value *BackcastedDispatchCounter =
4193       Builder.CreateTrunc(DispatchCounter, IVTy, "omp_dispatch.iv.trunc");
4194   CLI->mapIndVar([&](Instruction *) -> Value * {
4195     Builder.restoreIP(CLI->getBodyIP());
4196     return Builder.CreateAdd(IV, BackcastedDispatchCounter);
4197   });
4198 
4199   // In the "exit" block, call the "fini" function.
4200   Builder.SetInsertPoint(DispatchExit, DispatchExit->getFirstInsertionPt());
4201   Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});
4202 
4203   // Add the barrier if requested.
4204   if (NeedsBarrier)
4205     createBarrier(LocationDescription(Builder.saveIP(), DL), OMPD_for,
4206                   /*ForceSimpleCall=*/false, /*CheckCancelFlag=*/false);
4207 
4208 #ifndef NDEBUG
4209   // Even though we currently do not support applying additional methods to it,
4210   // the chunk loop should remain a canonical loop.
4211   CLI->assertOK();
4212 #endif
4213 
4214   return {DispatchAfter, DispatchAfter->getFirstInsertionPt()};
4215 }
4216 
4217 // Returns an LLVM function to call for executing an OpenMP static worksharing
4218 // for loop depending on `type`. Only i32 and i64 are supported by the runtime.
4219 // Always interpret integers as unsigned similarly to CanonicalLoopInfo.
4220 static FunctionCallee
4221 getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
4222                             WorksharingLoopType LoopType) {
4223   unsigned Bitwidth = Ty->getIntegerBitWidth();
4224   Module &M = OMPBuilder->M;
4225   switch (LoopType) {
4226   case WorksharingLoopType::ForStaticLoop:
4227     if (Bitwidth == 32)
4228       return OMPBuilder->getOrCreateRuntimeFunction(
4229           M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
4230     if (Bitwidth == 64)
4231       return OMPBuilder->getOrCreateRuntimeFunction(
4232           M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
4233     break;
4234   case WorksharingLoopType::DistributeStaticLoop:
4235     if (Bitwidth == 32)
4236       return OMPBuilder->getOrCreateRuntimeFunction(
4237           M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
4238     if (Bitwidth == 64)
4239       return OMPBuilder->getOrCreateRuntimeFunction(
4240           M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
4241     break;
4242   case WorksharingLoopType::DistributeForStaticLoop:
4243     if (Bitwidth == 32)
4244       return OMPBuilder->getOrCreateRuntimeFunction(
4245           M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
4246     if (Bitwidth == 64)
4247       return OMPBuilder->getOrCreateRuntimeFunction(
4248           M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
4249     break;
4250   }
4251   if (Bitwidth != 32 && Bitwidth != 64) {
4252     llvm_unreachable("Unknown OpenMP loop iterator bitwidth");
4253   }
4254   llvm_unreachable("Unknown type of OpenMP worksharing loop");
4255 }
4256 
4257 // Inserts a call to proper OpenMP Device RTL function which handles
4258 // loop worksharing.
4259 static void createTargetLoopWorkshareCall(
4260     OpenMPIRBuilder *OMPBuilder, WorksharingLoopType LoopType,
4261     BasicBlock *InsertBlock, Value *Ident, Value *LoopBodyArg,
4262     Type *ParallelTaskPtr, Value *TripCount, Function &LoopBodyFn) {
4263   Type *TripCountTy = TripCount->getType();
4264   Module &M = OMPBuilder->M;
4265   IRBuilder<> &Builder = OMPBuilder->Builder;
4266   FunctionCallee RTLFn =
4267       getKmpcForStaticLoopForType(TripCountTy, OMPBuilder, LoopType);
4268   SmallVector<Value *, 8> RealArgs;
4269   RealArgs.push_back(Ident);
4270   RealArgs.push_back(Builder.CreateBitCast(&LoopBodyFn, ParallelTaskPtr));
4271   RealArgs.push_back(LoopBodyArg);
4272   RealArgs.push_back(TripCount);
4273   if (LoopType == WorksharingLoopType::DistributeStaticLoop) {
4274     RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
4275     Builder.CreateCall(RTLFn, RealArgs);
4276     return;
4277   }
4278   FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
4279       M, omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
4280   Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
4281   Value *NumThreads = Builder.CreateCall(RTLNumThreads, {});
4282 
4283   RealArgs.push_back(
4284       Builder.CreateZExtOrTrunc(NumThreads, TripCountTy, "num.threads.cast"));
4285   RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
4286   if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
4287     RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
4288   }
4289 
4290   Builder.CreateCall(RTLFn, RealArgs);
4291 }
4292 
/// Post-outlining callback for target-device worksharing loops. Once the body
/// of \p CLI has been outlined into \p OutlinedFn, this removes the original
/// loop control flow. It replaces that flow with a single call into the
/// OpenMP device runtime, emitted by createTargetLoopWorkshareCall, which then
/// drives the iterations.
///
/// \param OMPIRBuilder    Builder whose IRBuilder is used for emission.
/// \param CLI             Loop being lowered; invalidated before returning.
/// \param Ident           Source-location ident forwarded to the runtime.
/// \param OutlinedFn      Outlined loop body. It must have exactly one
///                        undroppable user: a call located in the preheader.
/// \param ParallelTaskPtr Pointer type the body function is cast to when it
///                        is handed to the runtime.
/// \param ToBeDeleted     Instructions erased once the runtime call exists.
/// \param LoopType        Selects which runtime entry point is called.
static void
workshareLoopTargetCallback(OpenMPIRBuilder *OMPIRBuilder,
                            CanonicalLoopInfo *CLI, Value *Ident,
                            Function &OutlinedFn, Type *ParallelTaskPtr,
                            const SmallVector<Instruction *, 4> &ToBeDeleted,
                            WorksharingLoopType LoopType) {
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  BasicBlock *Preheader = CLI->getPreheader();
  // Capture the trip count before the loop control blocks are deleted below.
  Value *TripCount = CLI->getTripCount();

  // After loop body outlining, the loop body contains only the setup of the
  // loop body argument structure and the call to the outlined loop body
  // function. First, move everything in the body except its terminator to
  // just before the preheader's terminator.
  Preheader->splice(std::prev(Preheader->end()), CLI->getBody(),
                    CLI->getBody()->begin(), std::prev(CLI->getBody()->end()));

  // The next step is to remove the whole loop, because we do not need it
  // anymore. To do that, replace the preheader's terminator with an
  // unconditional branch to the loop exit block.
  // NOTE(review): after this, the builder's insertion point is the end of
  // Preheader, i.e. after the new branch. Later emission must reposition.
  Builder.restoreIP({Preheader, Preheader->end()});
  Preheader->getTerminator()->eraseFromParent();
  Builder.CreateBr(CLI->getExit());

  // Delete the now-dead loop blocks: the region that starts at the loop
  // header and is bounded by the loop exit.
  OpenMPIRBuilder::OutlineInfo CleanUpInfo;
  SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
  SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
  CleanUpInfo.EntryBB = CLI->getHeader();
  CleanUpInfo.ExitBB = CLI->getExit();
  CleanUpInfo.collectBlocks(RegionBlockSet, BlocksToBeRemoved);
  DeleteDeadBlocks(BlocksToBeRemoved);

  // Find the value passed as the loop body argument structure, then remove
  // the call to the outlined loop body function.
  Value *LoopBodyArg;
  User *OutlinedFnUser = OutlinedFn.getUniqueUndroppableUser();
  assert(OutlinedFnUser &&
         "Expected unique undroppable user of outlined function");
  CallInst *OutlinedFnCallInstruction = dyn_cast<CallInst>(OutlinedFnUser);
  assert(OutlinedFnCallInstruction && "Expected outlined function call");
  assert((OutlinedFnCallInstruction->getParent() == Preheader) &&
         "Expected outlined function call to be located in loop preheader");
  // Handle the case where no argument structure has been passed; a null
  // pointer is forwarded instead.
  // NOTE(review): this presumably relies on operand 0 being the loop counter
  // (excluded from the aggregate by applyWorkshareLoopTarget) and operand 1
  // being the aggregate. TODO confirm.
  if (OutlinedFnCallInstruction->arg_size() > 1)
    LoopBodyArg = OutlinedFnCallInstruction->getArgOperand(1);
  else
    LoopBodyArg = Constant::getNullValue(Builder.getPtrTy());
  OutlinedFnCallInstruction->eraseFromParent();

  // Emit the device runtime call into the preheader in place of the loop.
  createTargetLoopWorkshareCall(OMPIRBuilder, LoopType, Preheader, Ident,
                                LoopBodyArg, ParallelTaskPtr, TripCount,
                                OutlinedFn);

  // Erase the instructions the caller registered for deletion. For
  // applyWorkshareLoopTarget, these are the placeholder loop-counter alloca
  // and load that stood in for the induction variable during outlining.
  for (auto &ToBeDeletedItem : ToBeDeleted)
    ToBeDeletedItem->eraseFromParent();
  CLI->invalidate();
}
4351 
4352 OpenMPIRBuilder::InsertPointTy
4353 OpenMPIRBuilder::applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI,
4354                                           InsertPointTy AllocaIP,
4355                                           WorksharingLoopType LoopType) {
4356   uint32_t SrcLocStrSize;
4357   Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
4358   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4359 
4360   OutlineInfo OI;
4361   OI.OuterAllocaBB = CLI->getPreheader();
4362   Function *OuterFn = CLI->getPreheader()->getParent();
4363 
4364   // Instructions which need to be deleted at the end of code generation
4365   SmallVector<Instruction *, 4> ToBeDeleted;
4366 
4367   OI.OuterAllocaBB = AllocaIP.getBlock();
4368 
4369   // Mark the body loop as region which needs to be extracted
4370   OI.EntryBB = CLI->getBody();
4371   OI.ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(),
4372                                                "omp.prelatch", true);
4373 
4374   // Prepare loop body for extraction
4375   Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()});
4376 
4377   // Insert new loop counter variable which will be used only in loop
4378   // body.
4379   AllocaInst *NewLoopCnt = Builder.CreateAlloca(CLI->getIndVarType(), 0, "");
4380   Instruction *NewLoopCntLoad =
4381       Builder.CreateLoad(CLI->getIndVarType(), NewLoopCnt);
4382   // New loop counter instructions are redundant in the loop preheader when
4383   // code generation for workshare loop is finshed. That's why mark them as
4384   // ready for deletion.
4385   ToBeDeleted.push_back(NewLoopCntLoad);
4386   ToBeDeleted.push_back(NewLoopCnt);
4387 
4388   // Analyse loop body region. Find all input variables which are used inside
4389   // loop body region.
4390   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
4391   SmallVector<BasicBlock *, 32> Blocks;
4392   OI.collectBlocks(ParallelRegionBlockSet, Blocks);
4393   SmallVector<BasicBlock *, 32> BlocksT(ParallelRegionBlockSet.begin(),
4394                                         ParallelRegionBlockSet.end());
4395 
4396   CodeExtractorAnalysisCache CEAC(*OuterFn);
4397   CodeExtractor Extractor(Blocks,
4398                           /* DominatorTree */ nullptr,
4399                           /* AggregateArgs */ true,
4400                           /* BlockFrequencyInfo */ nullptr,
4401                           /* BranchProbabilityInfo */ nullptr,
4402                           /* AssumptionCache */ nullptr,
4403                           /* AllowVarArgs */ true,
4404                           /* AllowAlloca */ true,
4405                           /* AllocationBlock */ CLI->getPreheader(),
4406                           /* Suffix */ ".omp_wsloop",
4407                           /* AggrArgsIn0AddrSpace */ true);
4408 
4409   BasicBlock *CommonExit = nullptr;
4410   SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
4411 
4412   // Find allocas outside the loop body region which are used inside loop
4413   // body
4414   Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
4415 
4416   // We need to model loop body region as the function f(cnt, loop_arg).
4417   // That's why we replace loop induction variable by the new counter
4418   // which will be one of loop body function argument
4419   SmallVector<User *> Users(CLI->getIndVar()->user_begin(),
4420                             CLI->getIndVar()->user_end());
4421   for (auto Use : Users) {
4422     if (Instruction *Inst = dyn_cast<Instruction>(Use)) {
4423       if (ParallelRegionBlockSet.count(Inst->getParent())) {
4424         Inst->replaceUsesOfWith(CLI->getIndVar(), NewLoopCntLoad);
4425       }
4426     }
4427   }
4428   // Make sure that loop counter variable is not merged into loop body
4429   // function argument structure and it is passed as separate variable
4430   OI.ExcludeArgsFromAggregate.push_back(NewLoopCntLoad);
4431 
4432   // PostOutline CB is invoked when loop body function is outlined and
4433   // loop body is replaced by call to outlined function. We need to add
4434   // call to OpenMP device rtl inside loop preheader. OpenMP device rtl
4435   // function will handle loop control logic.
4436   //
4437   OI.PostOutlineCB = [=, ToBeDeletedVec =
4438                              std::move(ToBeDeleted)](Function &OutlinedFn) {
4439     workshareLoopTargetCallback(this, CLI, Ident, OutlinedFn, ParallelTaskPtr,
4440                                 ToBeDeletedVec, LoopType);
4441   };
4442   addOutlineInfo(std::move(OI));
4443   return CLI->getAfterIP();
4444 }
4445 
4446 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoop(
4447     DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
4448     bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
4449     bool HasSimdModifier, bool HasMonotonicModifier,
4450     bool HasNonmonotonicModifier, bool HasOrderedClause,
4451     WorksharingLoopType LoopType) {
4452   if (Config.isTargetDevice())
4453     return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType);
4454   OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
4455       SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
4456       HasNonmonotonicModifier, HasOrderedClause);
4457 
4458   bool IsOrdered = (EffectiveScheduleType & OMPScheduleType::ModifierOrdered) ==
4459                    OMPScheduleType::ModifierOrdered;
4460   switch (EffectiveScheduleType & ~OMPScheduleType::ModifierMask) {
4461   case OMPScheduleType::BaseStatic:
4462     assert(!ChunkSize && "No chunk size with static-chunked schedule");
4463     if (IsOrdered)
4464       return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
4465                                        NeedsBarrier, ChunkSize);
4466     // FIXME: Monotonicity ignored?
4467     return applyStaticWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier);
4468 
4469   case OMPScheduleType::BaseStaticChunked:
4470     if (IsOrdered)
4471       return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
4472                                        NeedsBarrier, ChunkSize);
4473     // FIXME: Monotonicity ignored?
4474     return applyStaticChunkedWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier,
4475                                            ChunkSize);
4476 
4477   case OMPScheduleType::BaseRuntime:
4478   case OMPScheduleType::BaseAuto:
4479   case OMPScheduleType::BaseGreedy:
4480   case OMPScheduleType::BaseBalanced:
4481   case OMPScheduleType::BaseSteal:
4482   case OMPScheduleType::BaseGuidedSimd:
4483   case OMPScheduleType::BaseRuntimeSimd:
4484     assert(!ChunkSize &&
4485            "schedule type does not support user-defined chunk sizes");
4486     [[fallthrough]];
4487   case OMPScheduleType::BaseDynamicChunked:
4488   case OMPScheduleType::BaseGuidedChunked:
4489   case OMPScheduleType::BaseGuidedIterativeChunked:
4490   case OMPScheduleType::BaseGuidedAnalyticalChunked:
4491   case OMPScheduleType::BaseStaticBalancedChunked:
4492     return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
4493                                      NeedsBarrier, ChunkSize);
4494 
4495   default:
4496     llvm_unreachable("Unknown/unimplemented schedule kind");
4497   }
4498 }
4499 
4500 /// Returns an LLVM function to call for initializing loop bounds using OpenMP
4501 /// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
4502 /// the runtime. Always interpret integers as unsigned similarly to
4503 /// CanonicalLoopInfo.
4504 static FunctionCallee
4505 getKmpcForDynamicInitForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
4506   unsigned Bitwidth = Ty->getIntegerBitWidth();
4507   if (Bitwidth == 32)
4508     return OMPBuilder.getOrCreateRuntimeFunction(
4509         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_4u);
4510   if (Bitwidth == 64)
4511     return OMPBuilder.getOrCreateRuntimeFunction(
4512         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_8u);
4513   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
4514 }
4515 
4516 /// Returns an LLVM function to call for updating the next loop using OpenMP
4517 /// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
4518 /// the runtime. Always interpret integers as unsigned similarly to
4519 /// CanonicalLoopInfo.
4520 static FunctionCallee
4521 getKmpcForDynamicNextForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
4522   unsigned Bitwidth = Ty->getIntegerBitWidth();
4523   if (Bitwidth == 32)
4524     return OMPBuilder.getOrCreateRuntimeFunction(
4525         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_4u);
4526   if (Bitwidth == 64)
4527     return OMPBuilder.getOrCreateRuntimeFunction(
4528         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_8u);
4529   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
4530 }
4531 
4532 /// Returns an LLVM function to call for finalizing the dynamic loop using
4533 /// depending on `type`. Only i32 and i64 are supported by the runtime. Always
4534 /// interpret integers as unsigned similarly to CanonicalLoopInfo.
4535 static FunctionCallee
4536 getKmpcForDynamicFiniForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
4537   unsigned Bitwidth = Ty->getIntegerBitWidth();
4538   if (Bitwidth == 32)
4539     return OMPBuilder.getOrCreateRuntimeFunction(
4540         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_4u);
4541   if (Bitwidth == 64)
4542     return OMPBuilder.getOrCreateRuntimeFunction(
4543         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_8u);
4544   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
4545 }
4546 
/// Lower \p CLI into a dynamically scheduled worksharing loop built on the
/// __kmpc_dispatch_{init,next}_{4u,8u} runtime interface.
///
/// The preheader calls "init" once with the inclusive range [1, trip count]
/// and step 1. A new outer block, "<preheader>.outer.cond", then repeatedly
/// calls "next" to obtain a chunk:
///  - If a chunk is returned, control re-enters the original loop header,
///    with the IV starting at (chunk lower bound - 1).
///  - Otherwise, control branches to the loop exit.
/// The inner loop condition is rewired to compare against the chunk's upper
/// bound and to return to the outer block when the chunk is exhausted. With
/// the "ordered" modifier in \p SchedType, the matching "fini" function is
/// called in the latch. With \p NeedsBarrier, a barrier is emitted in the
/// exit block.
///
/// \param DL           Debug location for the emitted code.
/// \param CLI          Loop to lower. It is invalidated on return, because the
///                     result is no longer a canonical loop.
/// \param AllocaIP     Its block receives the bound/stride/last-iter allocas.
///                     It must not conflict with the preheader IP.
/// \param SchedType    Schedule type passed to the "init" function.
/// \param NeedsBarrier Whether to emit a barrier after the loop.
/// \param Chunk        Chunk size passed to "init". When null, 1 is used.
/// \returns the original loop's "after" insertion point.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyDynamicWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    OMPScheduleType SchedType, bool NeedsBarrier, Value *Chunk) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
         "Require dedicated allocate IP");
  assert(isValidWorkshareLoopScheduleType(SchedType) &&
         "Require valid schedule type");

  bool Ordered = (SchedType & OMPScheduleType::ModifierOrdered) ==
                 OMPScheduleType::ModifierOrdered;

  // Set up the source location value for the OpenMP runtime.
  Builder.SetCurrentDebugLocation(DL);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  // Declare useful OpenMP runtime functions. The runtime variant (4u/8u) is
  // chosen from the induction variable's bit width.
  Value *IV = CLI->getIndVar();
  Type *IVTy = IV->getType();
  FunctionCallee DynamicInit = getKmpcForDynamicInitForType(IVTy, M, *this);
  FunctionCallee DynamicNext = getKmpcForDynamicNextForType(IVTy, M, *this);

  // Allocate space for the loop bounds computed by the "next" function. The
  // allocas go after any existing PHIs, debug intrinsics and allocas in the
  // alloca block.
  Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());
  Type *I32Type = Type::getInt32Ty(M.getContext());
  Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
  Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
  Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
  Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");

  // At the end of the preheader, store the initial loop bounds into the
  // allocated space. A canonical loop always iterates from 0 to trip-count
  // with step 1. The runtime range used here is 1-based and inclusive,
  // [1, trip-count], which is why "init" expects and produces an inclusive
  // upper bound.
  BasicBlock *PreHeader = CLI->getPreheader();
  Builder.SetInsertPoint(PreHeader->getTerminator());
  Constant *One = ConstantInt::get(IVTy, 1);
  Builder.CreateStore(One, PLowerBound);
  Value *UpperBound = CLI->getTripCount();
  Builder.CreateStore(UpperBound, PUpperBound);
  Builder.CreateStore(One, PStride);

  // Remember the loop's blocks before the CFG is rewired below.
  BasicBlock *Header = CLI->getHeader();
  BasicBlock *Exit = CLI->getExit();
  BasicBlock *Cond = CLI->getCond();
  BasicBlock *Latch = CLI->getLatch();
  InsertPointTy AfterIP = CLI->getAfterIP();

  // The CLI will be "broken" in the code below, as the loop is no longer
  // a valid canonical loop.

  // Default chunk size is 1.
  // NOTE(review): Chunk is passed to "init" without conversion; presumably
  // callers already supply it in IVTy. Confirm.
  if (!Chunk)
    Chunk = One;

  Value *ThreadNum = getOrCreateThreadID(SrcLoc);

  Constant *SchedulingType =
      ConstantInt::get(I32Type, static_cast<int>(SchedType));

  // Call the "init" function at the end of the preheader.
  Builder.CreateCall(DynamicInit,
                     {SrcLoc, ThreadNum, SchedulingType, /* LowerBound */ One,
                      UpperBound, /* step */ One, Chunk});

  // An outer loop around the existing one. Each trip through OuterCond asks
  // the runtime for the next chunk.
  BasicBlock *OuterCond = BasicBlock::Create(
      PreHeader->getContext(), Twine(PreHeader->getName()) + ".outer.cond",
      PreHeader->getParent());
  Builder.SetInsertPoint(OuterCond, OuterCond->getFirstInsertionPt());
  Value *Res =
      Builder.CreateCall(DynamicNext, {SrcLoc, ThreadNum, PLastIter,
                                       PLowerBound, PUpperBound, PStride});
  // The "next" result is compared as an i32 independently of IVTy, so use a
  // dedicated 32-bit zero here.
  Constant *Zero32 = ConstantInt::get(I32Type, 0);
  Value *MoreWork = Builder.CreateCmp(CmpInst::ICMP_NE, Res, Zero32);
  // Convert the 1-based chunk lower bound back to a 0-based IV start.
  Value *LowerBound =
      Builder.CreateSub(Builder.CreateLoad(IVTy, PLowerBound), One, "lb");
  Builder.CreateCondBr(MoreWork, Header, Exit);

  // Change the PHI node in the loop header to use OuterCond rather than the
  // preheader as its incoming block 0, and set the IV to LowerBound.
  Instruction *Phi = &Header->front();
  auto *PI = cast<PHINode>(Phi);
  PI->setIncomingBlock(0, OuterCond);
  PI->setIncomingValue(0, LowerBound);

  // Then make the preheader jump to OuterCond.
  Instruction *Term = PreHeader->getTerminator();
  auto *Br = cast<BranchInst>(Term);
  Br->setSuccessor(0, OuterCond);

  // Modify the inner condition:
  // * Use the UpperBound returned from the DynamicNext call. It is 1-based
  //   inclusive, so it also serves as the 0-based exclusive bound.
  // * Jump back to the outer loop when done with one of the inner loops.
  Builder.SetInsertPoint(Cond, Cond->getFirstInsertionPt());
  UpperBound = Builder.CreateLoad(IVTy, PUpperBound, "ub");
  // The instruction right after the new load is the original compare.
  Instruction *Comp = &*Builder.GetInsertPoint();
  auto *CI = cast<CmpInst>(Comp);
  CI->setOperand(1, UpperBound);
  // Redirect the inner exit to branch to the outer condition.
  Instruction *Branch = &Cond->back();
  auto *BI = cast<BranchInst>(Branch);
  assert(BI->getSuccessor(1) == Exit);
  BI->setSuccessor(1, OuterCond);

  // Call the "fini" function, before the latch terminator, if "ordered" is
  // present in the wsloop directive.
  if (Ordered) {
    Builder.SetInsertPoint(&Latch->back());
    FunctionCallee DynamicFini = getKmpcForDynamicFiniForType(IVTy, M, *this);
    Builder.CreateCall(DynamicFini, {SrcLoc, ThreadNum});
  }

  // Add the barrier, before the exit block's terminator, if requested.
  if (NeedsBarrier) {
    Builder.SetInsertPoint(&Exit->back());
    createBarrier(LocationDescription(Builder.saveIP(), DL),
                  omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
                  /* CheckCancelFlag */ false);
  }

  CLI->invalidate();
  return AfterIP;
}
4673 
4674 /// Redirect all edges that branch to \p OldTarget to \p NewTarget. That is,
4675 /// after this \p OldTarget will be orphaned.
4676 static void redirectAllPredecessorsTo(BasicBlock *OldTarget,
4677                                       BasicBlock *NewTarget, DebugLoc DL) {
4678   for (BasicBlock *Pred : make_early_inc_range(predecessors(OldTarget)))
4679     redirectTo(Pred, NewTarget, DL);
4680 }
4681 
4682 /// Determine which blocks in \p BBs are reachable from outside and remove the
4683 /// ones that are not reachable from the function.
4684 static void removeUnusedBlocksFromParent(ArrayRef<BasicBlock *> BBs) {
4685   SmallPtrSet<BasicBlock *, 6> BBsToErase{BBs.begin(), BBs.end()};
4686   auto HasRemainingUses = [&BBsToErase](BasicBlock *BB) {
4687     for (Use &U : BB->uses()) {
4688       auto *UseInst = dyn_cast<Instruction>(U.getUser());
4689       if (!UseInst)
4690         continue;
4691       if (BBsToErase.count(UseInst->getParent()))
4692         continue;
4693       return true;
4694     }
4695     return false;
4696   };
4697 
4698   while (BBsToErase.remove_if(HasRemainingUses)) {
4699     // Try again if anything was removed.
4700   }
4701 
4702   SmallVector<BasicBlock *, 7> BBVec(BBsToErase.begin(), BBsToErase.end());
4703   DeleteDeadBlocks(BBVec);
4704 }
4705 
4706 CanonicalLoopInfo *
4707 OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
4708                                InsertPointTy ComputeIP) {
4709   assert(Loops.size() >= 1 && "At least one loop required");
4710   size_t NumLoops = Loops.size();
4711 
4712   // Nothing to do if there is already just one loop.
4713   if (NumLoops == 1)
4714     return Loops.front();
4715 
4716   CanonicalLoopInfo *Outermost = Loops.front();
4717   CanonicalLoopInfo *Innermost = Loops.back();
4718   BasicBlock *OrigPreheader = Outermost->getPreheader();
4719   BasicBlock *OrigAfter = Outermost->getAfter();
4720   Function *F = OrigPreheader->getParent();
4721 
4722   // Loop control blocks that may become orphaned later.
4723   SmallVector<BasicBlock *, 12> OldControlBBs;
4724   OldControlBBs.reserve(6 * Loops.size());
4725   for (CanonicalLoopInfo *Loop : Loops)
4726     Loop->collectControlBlocks(OldControlBBs);
4727 
4728   // Setup the IRBuilder for inserting the trip count computation.
4729   Builder.SetCurrentDebugLocation(DL);
4730   if (ComputeIP.isSet())
4731     Builder.restoreIP(ComputeIP);
4732   else
4733     Builder.restoreIP(Outermost->getPreheaderIP());
4734 
4735   // Derive the collapsed' loop trip count.
4736   // TODO: Find common/largest indvar type.
4737   Value *CollapsedTripCount = nullptr;
4738   for (CanonicalLoopInfo *L : Loops) {
4739     assert(L->isValid() &&
4740            "All loops to collapse must be valid canonical loops");
4741     Value *OrigTripCount = L->getTripCount();
4742     if (!CollapsedTripCount) {
4743       CollapsedTripCount = OrigTripCount;
4744       continue;
4745     }
4746 
4747     // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
4748     CollapsedTripCount = Builder.CreateMul(CollapsedTripCount, OrigTripCount,
4749                                            {}, /*HasNUW=*/true);
4750   }
4751 
4752   // Create the collapsed loop control flow.
4753   CanonicalLoopInfo *Result =
4754       createLoopSkeleton(DL, CollapsedTripCount, F,
4755                          OrigPreheader->getNextNode(), OrigAfter, "collapsed");
4756 
4757   // Build the collapsed loop body code.
4758   // Start with deriving the input loop induction variables from the collapsed
4759   // one, using a divmod scheme. To preserve the original loops' order, the
4760   // innermost loop use the least significant bits.
4761   Builder.restoreIP(Result->getBodyIP());
4762 
4763   Value *Leftover = Result->getIndVar();
4764   SmallVector<Value *> NewIndVars;
4765   NewIndVars.resize(NumLoops);
4766   for (int i = NumLoops - 1; i >= 1; --i) {
4767     Value *OrigTripCount = Loops[i]->getTripCount();
4768 
4769     Value *NewIndVar = Builder.CreateURem(Leftover, OrigTripCount);
4770     NewIndVars[i] = NewIndVar;
4771 
4772     Leftover = Builder.CreateUDiv(Leftover, OrigTripCount);
4773   }
4774   // Outermost loop gets all the remaining bits.
4775   NewIndVars[0] = Leftover;
4776 
4777   // Construct the loop body control flow.
4778   // We progressively construct the branch structure following in direction of
4779   // the control flow, from the leading in-between code, the loop nest body, the
4780   // trailing in-between code, and rejoining the collapsed loop's latch.
4781   // ContinueBlock and ContinuePred keep track of the source(s) of next edge. If
4782   // the ContinueBlock is set, continue with that block. If ContinuePred, use
4783   // its predecessors as sources.
4784   BasicBlock *ContinueBlock = Result->getBody();
4785   BasicBlock *ContinuePred = nullptr;
4786   auto ContinueWith = [&ContinueBlock, &ContinuePred, DL](BasicBlock *Dest,
4787                                                           BasicBlock *NextSrc) {
4788     if (ContinueBlock)
4789       redirectTo(ContinueBlock, Dest, DL);
4790     else
4791       redirectAllPredecessorsTo(ContinuePred, Dest, DL);
4792 
4793     ContinueBlock = nullptr;
4794     ContinuePred = NextSrc;
4795   };
4796 
4797   // The code before the nested loop of each level.
4798   // Because we are sinking it into the nest, it will be executed more often
4799   // that the original loop. More sophisticated schemes could keep track of what
4800   // the in-between code is and instantiate it only once per thread.
4801   for (size_t i = 0; i < NumLoops - 1; ++i)
4802     ContinueWith(Loops[i]->getBody(), Loops[i + 1]->getHeader());
4803 
4804   // Connect the loop nest body.
4805   ContinueWith(Innermost->getBody(), Innermost->getLatch());
4806 
4807   // The code after the nested loop at each level.
4808   for (size_t i = NumLoops - 1; i > 0; --i)
4809     ContinueWith(Loops[i]->getAfter(), Loops[i - 1]->getLatch());
4810 
4811   // Connect the finished loop to the collapsed loop latch.
4812   ContinueWith(Result->getLatch(), nullptr);
4813 
4814   // Replace the input loops with the new collapsed loop.
4815   redirectTo(Outermost->getPreheader(), Result->getPreheader(), DL);
4816   redirectTo(Result->getAfter(), Outermost->getAfter(), DL);
4817 
4818   // Replace the input loop indvars with the derived ones.
4819   for (size_t i = 0; i < NumLoops; ++i)
4820     Loops[i]->getIndVar()->replaceAllUsesWith(NewIndVars[i]);
4821 
4822   // Remove unused parts of the input loops.
4823   removeUnusedBlocksFromParent(OldControlBBs);
4824 
4825   for (CanonicalLoopInfo *L : Loops)
4826     L->invalidate();
4827 
4828 #ifndef NDEBUG
4829   Result->assertOK();
4830 #endif
4831   return Result;
4832 }
4833 
4834 std::vector<CanonicalLoopInfo *>
4835 OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
4836                            ArrayRef<Value *> TileSizes) {
4837   assert(TileSizes.size() == Loops.size() &&
4838          "Must pass as many tile sizes as there are loops");
4839   int NumLoops = Loops.size();
4840   assert(NumLoops >= 1 && "At least one loop to tile required");
4841 
4842   CanonicalLoopInfo *OutermostLoop = Loops.front();
4843   CanonicalLoopInfo *InnermostLoop = Loops.back();
4844   Function *F = OutermostLoop->getBody()->getParent();
4845   BasicBlock *InnerEnter = InnermostLoop->getBody();
4846   BasicBlock *InnerLatch = InnermostLoop->getLatch();
4847 
4848   // Loop control blocks that may become orphaned later.
4849   SmallVector<BasicBlock *, 12> OldControlBBs;
4850   OldControlBBs.reserve(6 * Loops.size());
4851   for (CanonicalLoopInfo *Loop : Loops)
4852     Loop->collectControlBlocks(OldControlBBs);
4853 
4854   // Collect original trip counts and induction variable to be accessible by
4855   // index. Also, the structure of the original loops is not preserved during
4856   // the construction of the tiled loops, so do it before we scavenge the BBs of
4857   // any original CanonicalLoopInfo.
4858   SmallVector<Value *, 4> OrigTripCounts, OrigIndVars;
4859   for (CanonicalLoopInfo *L : Loops) {
4860     assert(L->isValid() && "All input loops must be valid canonical loops");
4861     OrigTripCounts.push_back(L->getTripCount());
4862     OrigIndVars.push_back(L->getIndVar());
4863   }
4864 
4865   // Collect the code between loop headers. These may contain SSA definitions
4866   // that are used in the loop nest body. To be usable with in the innermost
4867   // body, these BasicBlocks will be sunk into the loop nest body. That is,
4868   // these instructions may be executed more often than before the tiling.
4869   // TODO: It would be sufficient to only sink them into body of the
4870   // corresponding tile loop.
4871   SmallVector<std::pair<BasicBlock *, BasicBlock *>, 4> InbetweenCode;
4872   for (int i = 0; i < NumLoops - 1; ++i) {
4873     CanonicalLoopInfo *Surrounding = Loops[i];
4874     CanonicalLoopInfo *Nested = Loops[i + 1];
4875 
4876     BasicBlock *EnterBB = Surrounding->getBody();
4877     BasicBlock *ExitBB = Nested->getHeader();
4878     InbetweenCode.emplace_back(EnterBB, ExitBB);
4879   }
4880 
4881   // Compute the trip counts of the floor loops.
4882   Builder.SetCurrentDebugLocation(DL);
4883   Builder.restoreIP(OutermostLoop->getPreheaderIP());
4884   SmallVector<Value *, 4> FloorCount, FloorRems;
4885   for (int i = 0; i < NumLoops; ++i) {
4886     Value *TileSize = TileSizes[i];
4887     Value *OrigTripCount = OrigTripCounts[i];
4888     Type *IVType = OrigTripCount->getType();
4889 
4890     Value *FloorTripCount = Builder.CreateUDiv(OrigTripCount, TileSize);
4891     Value *FloorTripRem = Builder.CreateURem(OrigTripCount, TileSize);
4892 
4893     // 0 if tripcount divides the tilesize, 1 otherwise.
4894     // 1 means we need an additional iteration for a partial tile.
4895     //
4896     // Unfortunately we cannot just use the roundup-formula
4897     //   (tripcount + tilesize - 1)/tilesize
4898     // because the summation might overflow. We do not want introduce undefined
4899     // behavior when the untiled loop nest did not.
4900     Value *FloorTripOverflow =
4901         Builder.CreateICmpNE(FloorTripRem, ConstantInt::get(IVType, 0));
4902 
4903     FloorTripOverflow = Builder.CreateZExt(FloorTripOverflow, IVType);
4904     FloorTripCount =
4905         Builder.CreateAdd(FloorTripCount, FloorTripOverflow,
4906                           "omp_floor" + Twine(i) + ".tripcount", true);
4907 
4908     // Remember some values for later use.
4909     FloorCount.push_back(FloorTripCount);
4910     FloorRems.push_back(FloorTripRem);
4911   }
4912 
4913   // Generate the new loop nest, from the outermost to the innermost.
4914   std::vector<CanonicalLoopInfo *> Result;
4915   Result.reserve(NumLoops * 2);
4916 
4917   // The basic block of the surrounding loop that enters the nest generated
4918   // loop.
4919   BasicBlock *Enter = OutermostLoop->getPreheader();
4920 
4921   // The basic block of the surrounding loop where the inner code should
4922   // continue.
4923   BasicBlock *Continue = OutermostLoop->getAfter();
4924 
4925   // Where the next loop basic block should be inserted.
4926   BasicBlock *OutroInsertBefore = InnermostLoop->getExit();
4927 
4928   auto EmbeddNewLoop =
4929       [this, DL, F, InnerEnter, &Enter, &Continue, &OutroInsertBefore](
4930           Value *TripCount, const Twine &Name) -> CanonicalLoopInfo * {
4931     CanonicalLoopInfo *EmbeddedLoop = createLoopSkeleton(
4932         DL, TripCount, F, InnerEnter, OutroInsertBefore, Name);
4933     redirectTo(Enter, EmbeddedLoop->getPreheader(), DL);
4934     redirectTo(EmbeddedLoop->getAfter(), Continue, DL);
4935 
4936     // Setup the position where the next embedded loop connects to this loop.
4937     Enter = EmbeddedLoop->getBody();
4938     Continue = EmbeddedLoop->getLatch();
4939     OutroInsertBefore = EmbeddedLoop->getLatch();
4940     return EmbeddedLoop;
4941   };
4942 
4943   auto EmbeddNewLoops = [&Result, &EmbeddNewLoop](ArrayRef<Value *> TripCounts,
4944                                                   const Twine &NameBase) {
4945     for (auto P : enumerate(TripCounts)) {
4946       CanonicalLoopInfo *EmbeddedLoop =
4947           EmbeddNewLoop(P.value(), NameBase + Twine(P.index()));
4948       Result.push_back(EmbeddedLoop);
4949     }
4950   };
4951 
4952   EmbeddNewLoops(FloorCount, "floor");
4953 
4954   // Within the innermost floor loop, emit the code that computes the tile
4955   // sizes.
4956   Builder.SetInsertPoint(Enter->getTerminator());
4957   SmallVector<Value *, 4> TileCounts;
4958   for (int i = 0; i < NumLoops; ++i) {
4959     CanonicalLoopInfo *FloorLoop = Result[i];
4960     Value *TileSize = TileSizes[i];
4961 
4962     Value *FloorIsEpilogue =
4963         Builder.CreateICmpEQ(FloorLoop->getIndVar(), FloorCount[i]);
4964     Value *TileTripCount =
4965         Builder.CreateSelect(FloorIsEpilogue, FloorRems[i], TileSize);
4966 
4967     TileCounts.push_back(TileTripCount);
4968   }
4969 
4970   // Create the tile loops.
4971   EmbeddNewLoops(TileCounts, "tile");
4972 
4973   // Insert the inbetween code into the body.
4974   BasicBlock *BodyEnter = Enter;
4975   BasicBlock *BodyEntered = nullptr;
4976   for (std::pair<BasicBlock *, BasicBlock *> P : InbetweenCode) {
4977     BasicBlock *EnterBB = P.first;
4978     BasicBlock *ExitBB = P.second;
4979 
4980     if (BodyEnter)
4981       redirectTo(BodyEnter, EnterBB, DL);
4982     else
4983       redirectAllPredecessorsTo(BodyEntered, EnterBB, DL);
4984 
4985     BodyEnter = nullptr;
4986     BodyEntered = ExitBB;
4987   }
4988 
4989   // Append the original loop nest body into the generated loop nest body.
4990   if (BodyEnter)
4991     redirectTo(BodyEnter, InnerEnter, DL);
4992   else
4993     redirectAllPredecessorsTo(BodyEntered, InnerEnter, DL);
4994   redirectAllPredecessorsTo(InnerLatch, Continue, DL);
4995 
4996   // Replace the original induction variable with an induction variable computed
4997   // from the tile and floor induction variables.
4998   Builder.restoreIP(Result.back()->getBodyIP());
4999   for (int i = 0; i < NumLoops; ++i) {
5000     CanonicalLoopInfo *FloorLoop = Result[i];
5001     CanonicalLoopInfo *TileLoop = Result[NumLoops + i];
5002     Value *OrigIndVar = OrigIndVars[i];
5003     Value *Size = TileSizes[i];
5004 
5005     Value *Scale =
5006         Builder.CreateMul(Size, FloorLoop->getIndVar(), {}, /*HasNUW=*/true);
5007     Value *Shift =
5008         Builder.CreateAdd(Scale, TileLoop->getIndVar(), {}, /*HasNUW=*/true);
5009     OrigIndVar->replaceAllUsesWith(Shift);
5010   }
5011 
5012   // Remove unused parts of the original loops.
5013   removeUnusedBlocksFromParent(OldControlBBs);
5014 
5015   for (CanonicalLoopInfo *L : Loops)
5016     L->invalidate();
5017 
5018 #ifndef NDEBUG
5019   for (CanonicalLoopInfo *GenL : Result)
5020     GenL->assertOK();
5021 #endif
5022   return Result;
5023 }
5024 
5025 /// Attach metadata \p Properties to the basic block described by \p BB. If the
5026 /// basic block already has metadata, the basic block properties are appended.
5027 static void addBasicBlockMetadata(BasicBlock *BB,
5028                                   ArrayRef<Metadata *> Properties) {
5029   // Nothing to do if no property to attach.
5030   if (Properties.empty())
5031     return;
5032 
5033   LLVMContext &Ctx = BB->getContext();
5034   SmallVector<Metadata *> NewProperties;
5035   NewProperties.push_back(nullptr);
5036 
5037   // If the basic block already has metadata, prepend it to the new metadata.
5038   MDNode *Existing = BB->getTerminator()->getMetadata(LLVMContext::MD_loop);
5039   if (Existing)
5040     append_range(NewProperties, drop_begin(Existing->operands(), 1));
5041 
5042   append_range(NewProperties, Properties);
5043   MDNode *BasicBlockID = MDNode::getDistinct(Ctx, NewProperties);
5044   BasicBlockID->replaceOperandWith(0, BasicBlockID);
5045 
5046   BB->getTerminator()->setMetadata(LLVMContext::MD_loop, BasicBlockID);
5047 }
5048 
5049 /// Attach loop metadata \p Properties to the loop described by \p Loop. If the
5050 /// loop already has metadata, the loop properties are appended.
5051 static void addLoopMetadata(CanonicalLoopInfo *Loop,
5052                             ArrayRef<Metadata *> Properties) {
5053   assert(Loop->isValid() && "Expecting a valid CanonicalLoopInfo");
5054 
5055   // Attach metadata to the loop's latch
5056   BasicBlock *Latch = Loop->getLatch();
5057   assert(Latch && "A valid CanonicalLoopInfo must have a unique latch");
5058   addBasicBlockMetadata(Latch, Properties);
5059 }
5060 
5061 /// Attach llvm.access.group metadata to the memref instructions of \p Block
5062 static void addSimdMetadata(BasicBlock *Block, MDNode *AccessGroup,
5063                             LoopInfo &LI) {
5064   for (Instruction &I : *Block) {
5065     if (I.mayReadOrWriteMemory()) {
5066       // TODO: This instruction may already have access group from
5067       // other pragmas e.g. #pragma clang loop vectorize.  Append
5068       // so that the existing metadata is not overwritten.
5069       I.setMetadata(LLVMContext::MD_access_group, AccessGroup);
5070     }
5071   }
5072 }
5073 
5074 void OpenMPIRBuilder::unrollLoopFull(DebugLoc, CanonicalLoopInfo *Loop) {
5075   LLVMContext &Ctx = Builder.getContext();
5076   addLoopMetadata(
5077       Loop, {MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
5078              MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.full"))});
5079 }
5080 
5081 void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) {
5082   LLVMContext &Ctx = Builder.getContext();
5083   addLoopMetadata(
5084       Loop, {
5085                 MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
5086             });
5087 }
5088 
5089 void OpenMPIRBuilder::createIfVersion(CanonicalLoopInfo *CanonicalLoop,
5090                                       Value *IfCond, ValueToValueMapTy &VMap,
5091                                       const Twine &NamePrefix) {
5092   Function *F = CanonicalLoop->getFunction();
5093 
5094   // Define where if branch should be inserted
5095   Instruction *SplitBefore;
5096   if (Instruction::classof(IfCond)) {
5097     SplitBefore = dyn_cast<Instruction>(IfCond);
5098   } else {
5099     SplitBefore = CanonicalLoop->getPreheader()->getTerminator();
5100   }
5101 
5102   // TODO: We should not rely on pass manager. Currently we use pass manager
5103   // only for getting llvm::Loop which corresponds to given CanonicalLoopInfo
5104   // object. We should have a method  which returns all blocks between
5105   // CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter()
5106   FunctionAnalysisManager FAM;
5107   FAM.registerPass([]() { return DominatorTreeAnalysis(); });
5108   FAM.registerPass([]() { return LoopAnalysis(); });
5109   FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
5110 
5111   // Get the loop which needs to be cloned
5112   LoopAnalysis LIA;
5113   LoopInfo &&LI = LIA.run(*F, FAM);
5114   Loop *L = LI.getLoopFor(CanonicalLoop->getHeader());
5115 
5116   // Create additional blocks for the if statement
5117   BasicBlock *Head = SplitBefore->getParent();
5118   Instruction *HeadOldTerm = Head->getTerminator();
5119   llvm::LLVMContext &C = Head->getContext();
5120   llvm::BasicBlock *ThenBlock = llvm::BasicBlock::Create(
5121       C, NamePrefix + ".if.then", Head->getParent(), Head->getNextNode());
5122   llvm::BasicBlock *ElseBlock = llvm::BasicBlock::Create(
5123       C, NamePrefix + ".if.else", Head->getParent(), CanonicalLoop->getExit());
5124 
5125   // Create if condition branch.
5126   Builder.SetInsertPoint(HeadOldTerm);
5127   Instruction *BrInstr =
5128       Builder.CreateCondBr(IfCond, ThenBlock, /*ifFalse*/ ElseBlock);
5129   InsertPointTy IP{BrInstr->getParent(), ++BrInstr->getIterator()};
5130   // Then block contains branch to omp loop which needs to be vectorized
5131   spliceBB(IP, ThenBlock, false);
5132   ThenBlock->replaceSuccessorsPhiUsesWith(Head, ThenBlock);
5133 
5134   Builder.SetInsertPoint(ElseBlock);
5135 
5136   // Clone loop for the else branch
5137   SmallVector<BasicBlock *, 8> NewBlocks;
5138 
5139   VMap[CanonicalLoop->getPreheader()] = ElseBlock;
5140   for (BasicBlock *Block : L->getBlocks()) {
5141     BasicBlock *NewBB = CloneBasicBlock(Block, VMap, "", F);
5142     NewBB->moveBefore(CanonicalLoop->getExit());
5143     VMap[Block] = NewBB;
5144     NewBlocks.push_back(NewBB);
5145   }
5146   remapInstructionsInBlocks(NewBlocks, VMap);
5147   Builder.CreateBr(NewBlocks.front());
5148 }
5149 
5150 unsigned
5151 OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
5152                                            const StringMap<bool> &Features) {
5153   if (TargetTriple.isX86()) {
5154     if (Features.lookup("avx512f"))
5155       return 512;
5156     else if (Features.lookup("avx"))
5157       return 256;
5158     return 128;
5159   }
5160   if (TargetTriple.isPPC())
5161     return 128;
5162   if (TargetTriple.isWasm())
5163     return 128;
5164   return 0;
5165 }
5166 
5167 void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
5168                                 MapVector<Value *, Value *> AlignedVars,
5169                                 Value *IfCond, OrderKind Order,
5170                                 ConstantInt *Simdlen, ConstantInt *Safelen) {
5171   LLVMContext &Ctx = Builder.getContext();
5172 
5173   Function *F = CanonicalLoop->getFunction();
5174 
5175   // TODO: We should not rely on pass manager. Currently we use pass manager
5176   // only for getting llvm::Loop which corresponds to given CanonicalLoopInfo
5177   // object. We should have a method  which returns all blocks between
5178   // CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter()
5179   FunctionAnalysisManager FAM;
5180   FAM.registerPass([]() { return DominatorTreeAnalysis(); });
5181   FAM.registerPass([]() { return LoopAnalysis(); });
5182   FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
5183 
5184   LoopAnalysis LIA;
5185   LoopInfo &&LI = LIA.run(*F, FAM);
5186 
5187   Loop *L = LI.getLoopFor(CanonicalLoop->getHeader());
5188   if (AlignedVars.size()) {
5189     InsertPointTy IP = Builder.saveIP();
5190     Builder.SetInsertPoint(CanonicalLoop->getPreheader()->getTerminator());
5191     for (auto &AlignedItem : AlignedVars) {
5192       Value *AlignedPtr = AlignedItem.first;
5193       Value *Alignment = AlignedItem.second;
5194       Builder.CreateAlignmentAssumption(F->getDataLayout(),
5195                                         AlignedPtr, Alignment);
5196     }
5197     Builder.restoreIP(IP);
5198   }
5199 
5200   if (IfCond) {
5201     ValueToValueMapTy VMap;
5202     createIfVersion(CanonicalLoop, IfCond, VMap, "simd");
5203     // Add metadata to the cloned loop which disables vectorization
5204     Value *MappedLatch = VMap.lookup(CanonicalLoop->getLatch());
5205     assert(MappedLatch &&
5206            "Cannot find value which corresponds to original loop latch");
5207     assert(isa<BasicBlock>(MappedLatch) &&
5208            "Cannot cast mapped latch block value to BasicBlock");
5209     BasicBlock *NewLatchBlock = dyn_cast<BasicBlock>(MappedLatch);
5210     ConstantAsMetadata *BoolConst =
5211         ConstantAsMetadata::get(ConstantInt::getFalse(Type::getInt1Ty(Ctx)));
5212     addBasicBlockMetadata(
5213         NewLatchBlock,
5214         {MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.enable"),
5215                            BoolConst})});
5216   }
5217 
5218   SmallSet<BasicBlock *, 8> Reachable;
5219 
5220   // Get the basic blocks from the loop in which memref instructions
5221   // can be found.
5222   // TODO: Generalize getting all blocks inside a CanonicalizeLoopInfo,
5223   // preferably without running any passes.
5224   for (BasicBlock *Block : L->getBlocks()) {
5225     if (Block == CanonicalLoop->getCond() ||
5226         Block == CanonicalLoop->getHeader())
5227       continue;
5228     Reachable.insert(Block);
5229   }
5230 
5231   SmallVector<Metadata *> LoopMDList;
5232 
5233   // In presence of finite 'safelen', it may be unsafe to mark all
5234   // the memory instructions parallel, because loop-carried
5235   // dependences of 'safelen' iterations are possible.
5236   // If clause order(concurrent) is specified then the memory instructions
5237   // are marked parallel even if 'safelen' is finite.
5238   if ((Safelen == nullptr) || (Order == OrderKind::OMP_ORDER_concurrent)) {
5239     // Add access group metadata to memory-access instructions.
5240     MDNode *AccessGroup = MDNode::getDistinct(Ctx, {});
5241     for (BasicBlock *BB : Reachable)
5242       addSimdMetadata(BB, AccessGroup, LI);
5243     // TODO:  If the loop has existing parallel access metadata, have
5244     // to combine two lists.
5245     LoopMDList.push_back(MDNode::get(
5246         Ctx, {MDString::get(Ctx, "llvm.loop.parallel_accesses"), AccessGroup}));
5247   }
5248 
5249   // Use the above access group metadata to create loop level
5250   // metadata, which should be distinct for each loop.
5251   ConstantAsMetadata *BoolConst =
5252       ConstantAsMetadata::get(ConstantInt::getTrue(Type::getInt1Ty(Ctx)));
5253   LoopMDList.push_back(MDNode::get(
5254       Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.enable"), BoolConst}));
5255 
5256   if (Simdlen || Safelen) {
5257     // If both simdlen and safelen clauses are specified, the value of the
5258     // simdlen parameter must be less than or equal to the value of the safelen
5259     // parameter. Therefore, use safelen only in the absence of simdlen.
5260     ConstantInt *VectorizeWidth = Simdlen == nullptr ? Safelen : Simdlen;
5261     LoopMDList.push_back(
5262         MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.width"),
5263                           ConstantAsMetadata::get(VectorizeWidth)}));
5264   }
5265 
5266   addLoopMetadata(CanonicalLoop, LoopMDList);
5267 }
5268 
5269 /// Create the TargetMachine object to query the backend for optimization
5270 /// preferences.
5271 ///
5272 /// Ideally, this would be passed from the front-end to the OpenMPBuilder, but
5273 /// e.g. Clang does not pass it to its CodeGen layer and creates it only when
5274 /// needed for the LLVM pass pipline. We use some default options to avoid
5275 /// having to pass too many settings from the frontend that probably do not
5276 /// matter.
5277 ///
5278 /// Currently, TargetMachine is only used sometimes by the unrollLoopPartial
5279 /// method. If we are going to use TargetMachine for more purposes, especially
5280 /// those that are sensitive to TargetOptions, RelocModel and CodeModel, it
5281 /// might become be worth requiring front-ends to pass on their TargetMachine,
5282 /// or at least cache it between methods. Note that while fontends such as Clang
5283 /// have just a single main TargetMachine per translation unit, "target-cpu" and
5284 /// "target-features" that determine the TargetMachine are per-function and can
5285 /// be overrided using __attribute__((target("OPTIONS"))).
5286 static std::unique_ptr<TargetMachine>
5287 createTargetMachine(Function *F, CodeGenOptLevel OptLevel) {
5288   Module *M = F->getParent();
5289 
5290   StringRef CPU = F->getFnAttribute("target-cpu").getValueAsString();
5291   StringRef Features = F->getFnAttribute("target-features").getValueAsString();
5292   const std::string &Triple = M->getTargetTriple();
5293 
5294   std::string Error;
5295   const llvm::Target *TheTarget = TargetRegistry::lookupTarget(Triple, Error);
5296   if (!TheTarget)
5297     return {};
5298 
5299   llvm::TargetOptions Options;
5300   return std::unique_ptr<TargetMachine>(TheTarget->createTargetMachine(
5301       Triple, CPU, Features, Options, /*RelocModel=*/std::nullopt,
5302       /*CodeModel=*/std::nullopt, OptLevel));
5303 }
5304 
/// Heuristically determine the best-performing unroll factor for \p CLI. The
/// result depends on the target processor. The same heuristics as the
/// LoopUnrollPass are reused (gatherUnrollingPreferences, UnrollCostEstimator
/// and computeUnrollCount).
///
/// \returns the suggested unroll factor. 1 means "do not unroll"; it is
///          returned when the loop is not unrollable or when the computed
///          count is 0.
static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
  Function *F = CLI->getFunction();

  // Assume the user requests the most aggressive unrolling, even if the rest of
  // the code is optimized using a lower setting.
  CodeGenOptLevel OptLevel = CodeGenOptLevel::Aggressive;
  std::unique_ptr<TargetMachine> TM = createTargetMachine(F, OptLevel);

  // Standalone analysis manager that registers only the analyses the unroll
  // heuristics below need.
  FunctionAnalysisManager FAM;
  FAM.registerPass([]() { return TargetLibraryAnalysis(); });
  FAM.registerPass([]() { return AssumptionAnalysis(); });
  FAM.registerPass([]() { return DominatorTreeAnalysis(); });
  FAM.registerPass([]() { return LoopAnalysis(); });
  FAM.registerPass([]() { return ScalarEvolutionAnalysis(); });
  FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
  // Use the target's TTI if a TargetMachine could be created; otherwise keep
  // the default-constructed TargetIRAnalysis.
  TargetIRAnalysis TIRA;
  if (TM)
    TIRA = TargetIRAnalysis(
        [&](const Function &F) { return TM->getTargetTransformInfo(F); });
  FAM.registerPass([&]() { return TIRA; });

  // Run each analysis directly on F. The results are bound to rvalue
  // references, which extend the temporaries' lifetime to this scope.
  TargetIRAnalysis::Result &&TTI = TIRA.run(*F, FAM);
  ScalarEvolutionAnalysis SEA;
  ScalarEvolution &&SE = SEA.run(*F, FAM);
  DominatorTreeAnalysis DTA;
  DominatorTree &&DT = DTA.run(*F, FAM);
  LoopAnalysis LIA;
  LoopInfo &&LI = LIA.run(*F, FAM);
  AssumptionAnalysis ACT;
  AssumptionCache &&AC = ACT.run(*F, FAM);
  OptimizationRemarkEmitter ORE{F};

  Loop *L = LI.getLoopFor(CLI->getHeader());
  assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop");

  // Start from the default unrolling preferences. Partial and runtime
  // unrolling are allowed; all other user overrides are left unset.
  TargetTransformInfo::UnrollingPreferences UP =
      gatherUnrollingPreferences(L, SE, TTI,
                                 /*BlockFrequencyInfo=*/nullptr,
                                 /*ProfileSummaryInfo=*/nullptr, ORE, static_cast<int>(OptLevel),
                                 /*UserThreshold=*/std::nullopt,
                                 /*UserCount=*/std::nullopt,
                                 /*UserAllowPartial=*/true,
                                 /*UserAllowRuntime=*/true,
                                 /*UserUpperBound=*/std::nullopt,
                                 /*UserFullUnrollMaxCount=*/std::nullopt);

  // NOTE(review): presumably set so computeUnrollCount proposes a count even
  // where it otherwise would not; see UnrollingPreferences::Force.
  UP.Force = true;

  // Account for additional optimizations taking place before the LoopUnrollPass
  // would unroll the loop.
  UP.Threshold *= UnrollThresholdFactor;
  UP.PartialThreshold *= UnrollThresholdFactor;

  // Use normal unroll factors even if the rest of the code is optimized for
  // size.
  UP.OptSizeThreshold = UP.Threshold;
  UP.PartialOptSizeThreshold = UP.PartialThreshold;

  LLVM_DEBUG(dbgs() << "Unroll heuristic thresholds:\n"
                    << "  Threshold=" << UP.Threshold << "\n"
                    << "  PartialThreshold=" << UP.PartialThreshold << "\n"
                    << "  OptSizeThreshold=" << UP.OptSizeThreshold << "\n"
                    << "  PartialOptSizeThreshold="
                    << UP.PartialOptSizeThreshold << "\n");

  // Disable peeling.
  TargetTransformInfo::PeelingPreferences PP =
      gatherPeelingPreferences(L, SE, TTI,
                               /*UserAllowPeeling=*/false,
                               /*UserAllowProfileBasedPeeling=*/false,
                               /*UnrollingSpecficValues=*/false);

  // Instructions in EphValues are not counted towards the loop size.
  SmallPtrSet<const Value *, 32> EphValues;
  CodeMetrics::collectEphemeralValues(L, &AC, EphValues);

  // Assume that reads and writes to stack variables can be eliminated by
  // Mem2Reg, SROA or LICM. That is, don't count them towards the loop body's
  // size.
  for (BasicBlock *BB : L->blocks()) {
    for (Instruction &I : *BB) {
      Value *Ptr;
      if (auto *Load = dyn_cast<LoadInst>(&I)) {
        Ptr = Load->getPointerOperand();
      } else if (auto *Store = dyn_cast<StoreInst>(&I)) {
        Ptr = Store->getPointerOperand();
      } else
        continue;

      Ptr = Ptr->stripPointerCasts();

      // Only allocas in the entry block are treated as promotable.
      if (auto *Alloca = dyn_cast<AllocaInst>(Ptr)) {
        if (Alloca->getParent() == &F->getEntryBlock())
          EphValues.insert(&I);
      }
    }
  }

  UnrollCostEstimator UCE(L, TTI, EphValues, UP.BEInsns);

  // Loop is not unrollable if the loop contains certain instructions.
  if (!UCE.canUnroll()) {
    LLVM_DEBUG(dbgs() << "Loop not considered unrollable\n");
    return 1;
  }

  LLVM_DEBUG(dbgs() << "Estimated loop size is " << UCE.getRolledLoopSize()
                    << "\n");

  // TODO: Determine trip count of \p CLI if constant, computeUnrollCount might
  // be able to use it.
  // The trip count is passed as unknown (0).
  int TripCount = 0;
  int MaxTripCount = 0;
  bool MaxOrZero = false;
  unsigned TripMultiple = 0;

  // computeUnrollCount stores its decision in UP.Count.
  bool UseUpperBound = false;
  computeUnrollCount(L, TTI, DT, &LI, &AC, SE, EphValues, &ORE, TripCount,
                     MaxTripCount, MaxOrZero, TripMultiple, UCE, UP, PP,
                     UseUpperBound);
  unsigned Factor = UP.Count;
  LLVM_DEBUG(dbgs() << "Suggesting unroll factor of " << Factor << "\n");

  // This function returns 1 to signal to not unroll a loop.
  if (Factor == 0)
    return 1;
  return Factor;
}
5435 
5436 void OpenMPIRBuilder::unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop,
5437                                         int32_t Factor,
5438                                         CanonicalLoopInfo **UnrolledCLI) {
5439   assert(Factor >= 0 && "Unroll factor must not be negative");
5440 
5441   Function *F = Loop->getFunction();
5442   LLVMContext &Ctx = F->getContext();
5443 
5444   // If the unrolled loop is not used for another loop-associated directive, it
5445   // is sufficient to add metadata for the LoopUnrollPass.
5446   if (!UnrolledCLI) {
5447     SmallVector<Metadata *, 2> LoopMetadata;
5448     LoopMetadata.push_back(
5449         MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")));
5450 
5451     if (Factor >= 1) {
5452       ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
5453           ConstantInt::get(Type::getInt32Ty(Ctx), APInt(32, Factor)));
5454       LoopMetadata.push_back(MDNode::get(
5455           Ctx, {MDString::get(Ctx, "llvm.loop.unroll.count"), FactorConst}));
5456     }
5457 
5458     addLoopMetadata(Loop, LoopMetadata);
5459     return;
5460   }
5461 
5462   // Heuristically determine the unroll factor.
5463   if (Factor == 0)
5464     Factor = computeHeuristicUnrollFactor(Loop);
5465 
5466   // No change required with unroll factor 1.
5467   if (Factor == 1) {
5468     *UnrolledCLI = Loop;
5469     return;
5470   }
5471 
5472   assert(Factor >= 2 &&
5473          "unrolling only makes sense with a factor of 2 or larger");
5474 
5475   Type *IndVarTy = Loop->getIndVarType();
5476 
5477   // Apply partial unrolling by tiling the loop by the unroll-factor, then fully
5478   // unroll the inner loop.
5479   Value *FactorVal =
5480       ConstantInt::get(IndVarTy, APInt(IndVarTy->getIntegerBitWidth(), Factor,
5481                                        /*isSigned=*/false));
5482   std::vector<CanonicalLoopInfo *> LoopNest =
5483       tileLoops(DL, {Loop}, {FactorVal});
5484   assert(LoopNest.size() == 2 && "Expect 2 loops after tiling");
5485   *UnrolledCLI = LoopNest[0];
5486   CanonicalLoopInfo *InnerLoop = LoopNest[1];
5487 
5488   // LoopUnrollPass can only fully unroll loops with constant trip count.
5489   // Unroll by the unroll factor with a fallback epilog for the remainder
5490   // iterations if necessary.
5491   ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
5492       ConstantInt::get(Type::getInt32Ty(Ctx), APInt(32, Factor)));
5493   addLoopMetadata(
5494       InnerLoop,
5495       {MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
5496        MDNode::get(
5497            Ctx, {MDString::get(Ctx, "llvm.loop.unroll.count"), FactorConst})});
5498 
5499 #ifndef NDEBUG
5500   (*UnrolledCLI)->assertOK();
5501 #endif
5502 }
5503 
5504 OpenMPIRBuilder::InsertPointTy
5505 OpenMPIRBuilder::createCopyPrivate(const LocationDescription &Loc,
5506                                    llvm::Value *BufSize, llvm::Value *CpyBuf,
5507                                    llvm::Value *CpyFn, llvm::Value *DidIt) {
5508   if (!updateToLocation(Loc))
5509     return Loc.IP;
5510 
5511   uint32_t SrcLocStrSize;
5512   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5513   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5514   Value *ThreadId = getOrCreateThreadID(Ident);
5515 
5516   llvm::Value *DidItLD = Builder.CreateLoad(Builder.getInt32Ty(), DidIt);
5517 
5518   Value *Args[] = {Ident, ThreadId, BufSize, CpyBuf, CpyFn, DidItLD};
5519 
5520   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_copyprivate);
5521   Builder.CreateCall(Fn, Args);
5522 
5523   return Builder.saveIP();
5524 }
5525 
5526 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSingle(
5527     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
5528     FinalizeCallbackTy FiniCB, bool IsNowait, ArrayRef<llvm::Value *> CPVars,
5529     ArrayRef<llvm::Function *> CPFuncs) {
5530 
5531   if (!updateToLocation(Loc))
5532     return Loc.IP;
5533 
5534   // If needed allocate and initialize `DidIt` with 0.
5535   // DidIt: flag variable: 1=single thread; 0=not single thread.
5536   llvm::Value *DidIt = nullptr;
5537   if (!CPVars.empty()) {
5538     DidIt = Builder.CreateAlloca(llvm::Type::getInt32Ty(Builder.getContext()));
5539     Builder.CreateStore(Builder.getInt32(0), DidIt);
5540   }
5541 
5542   Directive OMPD = Directive::OMPD_single;
5543   uint32_t SrcLocStrSize;
5544   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5545   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5546   Value *ThreadId = getOrCreateThreadID(Ident);
5547   Value *Args[] = {Ident, ThreadId};
5548 
5549   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_single);
5550   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
5551 
5552   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_single);
5553   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
5554 
5555   auto FiniCBWrapper = [&](InsertPointTy IP) {
5556     FiniCB(IP);
5557 
5558     // The thread that executes the single region must set `DidIt` to 1.
5559     // This is used by __kmpc_copyprivate, to know if the caller is the
5560     // single thread or not.
5561     if (DidIt)
5562       Builder.CreateStore(Builder.getInt32(1), DidIt);
5563   };
5564 
5565   // generates the following:
5566   // if (__kmpc_single()) {
5567   //		.... single region ...
5568   // 		__kmpc_end_single
5569   // }
5570   // __kmpc_copyprivate
5571   // __kmpc_barrier
5572 
5573   EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCBWrapper,
5574                        /*Conditional*/ true,
5575                        /*hasFinalize*/ true);
5576 
5577   if (DidIt) {
5578     for (size_t I = 0, E = CPVars.size(); I < E; ++I)
5579       // NOTE BufSize is currently unused, so just pass 0.
5580       createCopyPrivate(LocationDescription(Builder.saveIP(), Loc.DL),
5581                         /*BufSize=*/ConstantInt::get(Int64, 0), CPVars[I],
5582                         CPFuncs[I], DidIt);
5583     // NOTE __kmpc_copyprivate already inserts a barrier
5584   } else if (!IsNowait)
5585     createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
5586                   omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
5587                   /* CheckCancelFlag */ false);
5588   return Builder.saveIP();
5589 }
5590 
5591 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCritical(
5592     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
5593     FinalizeCallbackTy FiniCB, StringRef CriticalName, Value *HintInst) {
5594 
5595   if (!updateToLocation(Loc))
5596     return Loc.IP;
5597 
5598   Directive OMPD = Directive::OMPD_critical;
5599   uint32_t SrcLocStrSize;
5600   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5601   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5602   Value *ThreadId = getOrCreateThreadID(Ident);
5603   Value *LockVar = getOMPCriticalRegionLock(CriticalName);
5604   Value *Args[] = {Ident, ThreadId, LockVar};
5605 
5606   SmallVector<llvm::Value *, 4> EnterArgs(std::begin(Args), std::end(Args));
5607   Function *RTFn = nullptr;
5608   if (HintInst) {
5609     // Add Hint to entry Args and create call
5610     EnterArgs.push_back(HintInst);
5611     RTFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_critical_with_hint);
5612   } else {
5613     RTFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_critical);
5614   }
5615   Instruction *EntryCall = Builder.CreateCall(RTFn, EnterArgs);
5616 
5617   Function *ExitRTLFn =
5618       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_critical);
5619   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
5620 
5621   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
5622                               /*Conditional*/ false, /*hasFinalize*/ true);
5623 }
5624 
5625 OpenMPIRBuilder::InsertPointTy
5626 OpenMPIRBuilder::createOrderedDepend(const LocationDescription &Loc,
5627                                      InsertPointTy AllocaIP, unsigned NumLoops,
5628                                      ArrayRef<llvm::Value *> StoreValues,
5629                                      const Twine &Name, bool IsDependSource) {
5630   assert(
5631       llvm::all_of(StoreValues,
5632                    [](Value *SV) { return SV->getType()->isIntegerTy(64); }) &&
5633       "OpenMP runtime requires depend vec with i64 type");
5634 
5635   if (!updateToLocation(Loc))
5636     return Loc.IP;
5637 
5638   // Allocate space for vector and generate alloc instruction.
5639   auto *ArrI64Ty = ArrayType::get(Int64, NumLoops);
5640   Builder.restoreIP(AllocaIP);
5641   AllocaInst *ArgsBase = Builder.CreateAlloca(ArrI64Ty, nullptr, Name);
5642   ArgsBase->setAlignment(Align(8));
5643   Builder.restoreIP(Loc.IP);
5644 
5645   // Store the index value with offset in depend vector.
5646   for (unsigned I = 0; I < NumLoops; ++I) {
5647     Value *DependAddrGEPIter = Builder.CreateInBoundsGEP(
5648         ArrI64Ty, ArgsBase, {Builder.getInt64(0), Builder.getInt64(I)});
5649     StoreInst *STInst = Builder.CreateStore(StoreValues[I], DependAddrGEPIter);
5650     STInst->setAlignment(Align(8));
5651   }
5652 
5653   Value *DependBaseAddrGEP = Builder.CreateInBoundsGEP(
5654       ArrI64Ty, ArgsBase, {Builder.getInt64(0), Builder.getInt64(0)});
5655 
5656   uint32_t SrcLocStrSize;
5657   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5658   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5659   Value *ThreadId = getOrCreateThreadID(Ident);
5660   Value *Args[] = {Ident, ThreadId, DependBaseAddrGEP};
5661 
5662   Function *RTLFn = nullptr;
5663   if (IsDependSource)
5664     RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_doacross_post);
5665   else
5666     RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_doacross_wait);
5667   Builder.CreateCall(RTLFn, Args);
5668 
5669   return Builder.saveIP();
5670 }
5671 
5672 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createOrderedThreadsSimd(
5673     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
5674     FinalizeCallbackTy FiniCB, bool IsThreads) {
5675   if (!updateToLocation(Loc))
5676     return Loc.IP;
5677 
5678   Directive OMPD = Directive::OMPD_ordered;
5679   Instruction *EntryCall = nullptr;
5680   Instruction *ExitCall = nullptr;
5681 
5682   if (IsThreads) {
5683     uint32_t SrcLocStrSize;
5684     Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5685     Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5686     Value *ThreadId = getOrCreateThreadID(Ident);
5687     Value *Args[] = {Ident, ThreadId};
5688 
5689     Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_ordered);
5690     EntryCall = Builder.CreateCall(EntryRTLFn, Args);
5691 
5692     Function *ExitRTLFn =
5693         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_ordered);
5694     ExitCall = Builder.CreateCall(ExitRTLFn, Args);
5695   }
5696 
5697   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
5698                               /*Conditional*/ false, /*hasFinalize*/ true);
5699 }
5700 
// Emit an inlined (not outlined) region for directive \p OMPD around the
// current insertion block.
//
// \p EntryCall and \p ExitCall are runtime calls that the caller has already
// created at the current insertion point; either may be null. Before blocks
// are merged back together, the CFG looks like this:
//
//   EntryBB:              ... EntryCall ...; br omp_region.finalize
//                         (Conditional && EntryCall: instead
//                          br (EntryCall != 0), omp_region.body, omp_region.end)
//   omp_region.body:      body from BodyGenCB; br omp_region.finalize
//                         (only created in the conditional case)
//   omp_region.finalize:  FiniCB output (when HasFinalize), ExitCall;
//                         br omp_region.end
//   omp_region.end:       EntryBB's original branch terminator, or a temporary
//                         unreachable that is erased before returning
//
// When \p HasFinalize is set, {FiniCB, OMPD, IsCancellable} is pushed onto
// FinalizationStack here and popped again by emitCommonDirectiveExit.
// The returned insertion point is the end of the block that ends up holding
// the region's exit (omp_region.end, or its predecessor if it was merged).
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::EmitOMPInlinedRegion(
    Directive OMPD, Instruction *EntryCall, Instruction *ExitCall,
    BodyGenCallbackTy BodyGenCB, FinalizeCallbackTy FiniCB, bool Conditional,
    bool HasFinalize, bool IsCancellable) {

  if (HasFinalize)
    FinalizationStack.push_back({FiniCB, OMPD, IsCancellable});

  // Create inlined region's entry and body blocks, in preparation
  // for conditional creation.
  // Split at EntryBB's terminator if it is a branch; otherwise append a
  // temporary unreachable to split at (it is erased again at the end).
  // NOTE(review): assumes EntryBB is either unterminated or ends in a branch.
  BasicBlock *EntryBB = Builder.GetInsertBlock();
  Instruction *SplitPos = EntryBB->getTerminator();
  if (!isa_and_nonnull<BranchInst>(SplitPos))
    SplitPos = new UnreachableInst(Builder.getContext(), EntryBB);
  // After these two splits: EntryBB -> omp_region.finalize -> omp_region.end.
  BasicBlock *ExitBB = EntryBB->splitBasicBlock(SplitPos, "omp_region.end");
  BasicBlock *FiniBB =
      EntryBB->splitBasicBlock(EntryBB->getTerminator(), "omp_region.finalize");

  // Emit the (optional) conditional guard just before EntryBB's branch.
  Builder.SetInsertPoint(EntryBB->getTerminator());
  emitCommonDirectiveEntry(OMPD, EntryCall, ExitBB, Conditional);

  // generate body at wherever the guard left the builder (the new
  // omp_region.body block in the conditional case, EntryBB otherwise).
  BodyGenCB(/* AllocaIP */ InsertPointTy(),
            /* CodeGenIP */ Builder.saveIP());

  // emit exit call and do any needed finalization.
  auto FinIP = InsertPointTy(FiniBB, FiniBB->getFirstInsertionPt());
  assert(FiniBB->getTerminator()->getNumSuccessors() == 1 &&
         FiniBB->getTerminator()->getSuccessor(0) == ExitBB &&
         "Unexpected control flow graph state!!");
  emitCommonDirectiveExit(OMPD, FinIP, ExitCall, HasFinalize);
  assert(FiniBB->getUniquePredecessor()->getUniqueSuccessor() == FiniBB &&
         "Unexpected Control Flow State!");
  // Fold omp_region.finalize back into its single predecessor.
  MergeBlockIntoPredecessor(FiniBB);

  // Try to fold omp_region.end into its predecessor as well; this fails in
  // the conditional case, where ExitBB has two predecessors. Track where
  // SplitPos ended up, drop the temporary unreachable if one was created,
  // and leave the builder at the end of the resulting exit block.
  assert(SplitPos->getParent() == ExitBB &&
         "Unexpected Insertion point location!");
  auto merged = MergeBlockIntoPredecessor(ExitBB);
  BasicBlock *ExitPredBB = SplitPos->getParent();
  auto InsertBB = merged ? ExitPredBB : ExitBB;
  if (!isa_and_nonnull<BranchInst>(SplitPos))
    SplitPos->eraseFromParent();
  Builder.SetInsertPoint(InsertBB);

  return Builder.saveIP();
}
5749 
// Emit the guard of a conditional inlined region.
//
// When \p Conditional is set and \p EntryCall is non-null, the terminator of
// the current block is replaced by
//   br (EntryCall != 0), omp_region.body, ExitBB
// and the old terminator is moved into the new omp_region.body block. The
// builder is then positioned just before that moved terminator, so the region
// body can be emitted there. Otherwise nothing is emitted.
//
// Returns an insertion point at the first insertion position of \p ExitBB, or
// the unchanged current insertion point when there is nothing to do.
// \p OMPD is not used by this function.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveEntry(
    Directive OMPD, Value *EntryCall, BasicBlock *ExitBB, bool Conditional) {
  // if nothing to do, Return current insertion point.
  if (!Conditional || !EntryCall)
    return Builder.saveIP();

  BasicBlock *EntryBB = Builder.GetInsertBlock();
  // i1 that is true when the entry runtime call returned a non-zero value.
  Value *CallBool = Builder.CreateIsNotNull(EntryCall);
  auto *ThenBB = BasicBlock::Create(M.getContext(), "omp_region.body");
  // Temporary placeholder giving ThenBB an insertion anchor; erased below.
  auto *UI = new UnreachableInst(Builder.getContext(), ThenBB);

  // Emit thenBB and set the Builder's insertion point there for
  // body generation next. Place the block after the current block.
  Function *CurFn = EntryBB->getParent();
  CurFn->insert(std::next(EntryBB->getIterator()), ThenBB);

  // Move Entry branch to end of ThenBB, and replace with conditional
  // branch (If-stmt).
  // The conditional branch is created at the builder's current position,
  // which assumes the builder sits just before EntryBB's terminator (as
  // EmitOMPInlinedRegion arranges).
  Instruction *EntryBBTI = EntryBB->getTerminator();
  Builder.CreateCondBr(CallBool, ThenBB, ExitBB);
  EntryBBTI->removeFromParent();
  Builder.SetInsertPoint(UI);
  Builder.Insert(EntryBBTI);
  UI->eraseFromParent();
  Builder.SetInsertPoint(ThenBB->getTerminator());

  // return an insertion point to ExitBB.
  return IRBuilder<>::InsertPoint(ExitBB, ExitBB->getFirstInsertionPt());
}
5779 
// Emit the exit sequence of an inlined region at \p FinIP.
//
// When \p HasFinalize is set, the top FinalizationStack entry is popped (it
// must belong to \p OMPD) and its callback is invoked at \p FinIP. The builder
// is then moved to just before the terminator of FinIP's block. When
// \p ExitCall is non-null, it is unlinked from its current position and
// re-inserted at the builder's position.
//
// Returns an insertion point just before \p ExitCall, or the builder's current
// insertion point when \p ExitCall is null.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveExit(
    omp::Directive OMPD, InsertPointTy FinIP, Instruction *ExitCall,
    bool HasFinalize) {

  Builder.restoreIP(FinIP);

  // If there is finalization to do, emit it before the exit call
  if (HasFinalize) {
    assert(!FinalizationStack.empty() &&
           "Unexpected finalization stack state!");

    FinalizationInfo Fi = FinalizationStack.pop_back_val();
    assert(Fi.DK == OMPD && "Unexpected Directive for Finalization call!");

    Fi.FiniCB(FinIP);

    BasicBlock *FiniBB = FinIP.getBlock();
    Instruction *FiniBBTI = FiniBB->getTerminator();

    // set Builder IP for call creation
    Builder.SetInsertPoint(FiniBBTI);
  }

  if (!ExitCall)
    return Builder.saveIP();

  // place the Exitcall as last instruction before Finalization block terminator
  // (or at FinIP itself when there was no finalization).
  ExitCall->removeFromParent();
  Builder.Insert(ExitCall);

  return IRBuilder<>::InsertPoint(ExitCall->getParent(),
                                  ExitCall->getIterator());
}
5813 
// Emit the control flow for a `copyin` clause at the end of \p IP's block.
//
// The block compares \p MasterAddr and \p PrivateAddr, both converted to
// \p IntPtrTy integers. It branches to "copyin.not.master" when they differ
// and to "copyin.not.master.end" otherwise.
//
// The returned insertion point lies inside copyin.not.master, where the caller
// emits the actual copy:
//   - with \p BranchtoEnd, just before the branch to copyin.not.master.end;
//   - otherwise, at the end of copyin.not.master, which is left unterminated.
//
// Returns \p IP unchanged when it is not set. The builder's own insertion
// point is restored on return (InsertPointGuard). Note that only IP's block is
// used: emission happens at the end of that block, not at IP's position.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCopyinClauseBlocks(
    InsertPointTy IP, Value *MasterAddr, Value *PrivateAddr,
    llvm::IntegerType *IntPtrTy, bool BranchtoEnd) {
  if (!IP.isSet())
    return IP;

  IRBuilder<>::InsertPointGuard IPG(Builder);

  // Creates the following CFG structure:
  //
  //   OMP_Entry: (MasterAddr != PrivateAddr) ?
  //      F |         | T
  //        |         v
  //        |   copyin.not.master
  //        |         |   (branch emitted only if BranchtoEnd)
  //        v         v
  //   copyin.not.master.end
  //             |
  //             v
  //   OMP.Entry.Next   (only when OMP_Entry was branch-terminated)

  BasicBlock *OMP_Entry = IP.getBlock();
  Function *CurFn = OMP_Entry->getParent();
  BasicBlock *CopyBegin =
      BasicBlock::Create(M.getContext(), "copyin.not.master", CurFn);
  BasicBlock *CopyEnd = nullptr;

  // If entry block is terminated, split to preserve the branch to following
  // basic block (i.e. OMP.Entry.Next), otherwise, leave everything as is.
  // The split leaves an unconditional branch to CopyEnd in OMP_Entry; it is
  // erased so the conditional branch below can terminate OMP_Entry instead.
  // NOTE(review): a non-branch terminator (e.g. ret) is not removed, so the
  // conditional branch would be appended after it — confirm callers never
  // pass such a block.
  if (isa_and_nonnull<BranchInst>(OMP_Entry->getTerminator())) {
    CopyEnd = OMP_Entry->splitBasicBlock(OMP_Entry->getTerminator(),
                                         "copyin.not.master.end");
    OMP_Entry->getTerminator()->eraseFromParent();
  } else {
    CopyEnd =
        BasicBlock::Create(M.getContext(), "copyin.not.master.end", CurFn);
  }

  // Compare the two addresses as integers and branch on inequality.
  Builder.SetInsertPoint(OMP_Entry);
  Value *MasterPtr = Builder.CreatePtrToInt(MasterAddr, IntPtrTy);
  Value *PrivatePtr = Builder.CreatePtrToInt(PrivateAddr, IntPtrTy);
  Value *cmp = Builder.CreateICmpNE(MasterPtr, PrivatePtr);
  Builder.CreateCondBr(cmp, CopyBegin, CopyEnd);

  // Position for the caller's copy code: before the branch to CopyEnd when
  // requested, else at the end of CopyBegin.
  Builder.SetInsertPoint(CopyBegin);
  if (BranchtoEnd)
    Builder.SetInsertPoint(Builder.CreateBr(CopyEnd));

  return Builder.saveIP();
}
5864 CallInst *OpenMPIRBuilder::createOMPAlloc(const LocationDescription &Loc,
5865                                           Value *Size, Value *Allocator,
5866                                           std::string Name) {
5867   IRBuilder<>::InsertPointGuard IPG(Builder);
5868   updateToLocation(Loc);
5869 
5870   uint32_t SrcLocStrSize;
5871   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5872   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5873   Value *ThreadId = getOrCreateThreadID(Ident);
5874   Value *Args[] = {ThreadId, Size, Allocator};
5875 
5876   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_alloc);
5877 
5878   return Builder.CreateCall(Fn, Args, Name);
5879 }
5880 
5881 CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
5882                                          Value *Addr, Value *Allocator,
5883                                          std::string Name) {
5884   IRBuilder<>::InsertPointGuard IPG(Builder);
5885   updateToLocation(Loc);
5886 
5887   uint32_t SrcLocStrSize;
5888   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5889   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5890   Value *ThreadId = getOrCreateThreadID(Ident);
5891   Value *Args[] = {ThreadId, Addr, Allocator};
5892   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_free);
5893   return Builder.CreateCall(Fn, Args, Name);
5894 }
5895 
5896 CallInst *OpenMPIRBuilder::createOMPInteropInit(
5897     const LocationDescription &Loc, Value *InteropVar,
5898     omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
5899     Value *DependenceAddress, bool HaveNowaitClause) {
5900   IRBuilder<>::InsertPointGuard IPG(Builder);
5901   updateToLocation(Loc);
5902 
5903   uint32_t SrcLocStrSize;
5904   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5905   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5906   Value *ThreadId = getOrCreateThreadID(Ident);
5907   if (Device == nullptr)
5908     Device = ConstantInt::get(Int32, -1);
5909   Constant *InteropTypeVal = ConstantInt::get(Int32, (int)InteropType);
5910   if (NumDependences == nullptr) {
5911     NumDependences = ConstantInt::get(Int32, 0);
5912     PointerType *PointerTypeVar = PointerType::getUnqual(M.getContext());
5913     DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
5914   }
5915   Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
5916   Value *Args[] = {
5917       Ident,  ThreadId,       InteropVar,        InteropTypeVal,
5918       Device, NumDependences, DependenceAddress, HaveNowaitClauseVal};
5919 
5920   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_init);
5921 
5922   return Builder.CreateCall(Fn, Args);
5923 }
5924 
5925 CallInst *OpenMPIRBuilder::createOMPInteropDestroy(
5926     const LocationDescription &Loc, Value *InteropVar, Value *Device,
5927     Value *NumDependences, Value *DependenceAddress, bool HaveNowaitClause) {
5928   IRBuilder<>::InsertPointGuard IPG(Builder);
5929   updateToLocation(Loc);
5930 
5931   uint32_t SrcLocStrSize;
5932   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5933   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5934   Value *ThreadId = getOrCreateThreadID(Ident);
5935   if (Device == nullptr)
5936     Device = ConstantInt::get(Int32, -1);
5937   if (NumDependences == nullptr) {
5938     NumDependences = ConstantInt::get(Int32, 0);
5939     PointerType *PointerTypeVar = PointerType::getUnqual(M.getContext());
5940     DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
5941   }
5942   Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
5943   Value *Args[] = {
5944       Ident,          ThreadId,          InteropVar,         Device,
5945       NumDependences, DependenceAddress, HaveNowaitClauseVal};
5946 
5947   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_destroy);
5948 
5949   return Builder.CreateCall(Fn, Args);
5950 }
5951 
5952 CallInst *OpenMPIRBuilder::createOMPInteropUse(const LocationDescription &Loc,
5953                                                Value *InteropVar, Value *Device,
5954                                                Value *NumDependences,
5955                                                Value *DependenceAddress,
5956                                                bool HaveNowaitClause) {
5957   IRBuilder<>::InsertPointGuard IPG(Builder);
5958   updateToLocation(Loc);
5959   uint32_t SrcLocStrSize;
5960   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5961   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5962   Value *ThreadId = getOrCreateThreadID(Ident);
5963   if (Device == nullptr)
5964     Device = ConstantInt::get(Int32, -1);
5965   if (NumDependences == nullptr) {
5966     NumDependences = ConstantInt::get(Int32, 0);
5967     PointerType *PointerTypeVar = PointerType::getUnqual(M.getContext());
5968     DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
5969   }
5970   Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
5971   Value *Args[] = {
5972       Ident,          ThreadId,          InteropVar,         Device,
5973       NumDependences, DependenceAddress, HaveNowaitClauseVal};
5974 
5975   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_use);
5976 
5977   return Builder.CreateCall(Fn, Args);
5978 }
5979 
5980 CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
5981     const LocationDescription &Loc, llvm::Value *Pointer,
5982     llvm::ConstantInt *Size, const llvm::Twine &Name) {
5983   IRBuilder<>::InsertPointGuard IPG(Builder);
5984   updateToLocation(Loc);
5985 
5986   uint32_t SrcLocStrSize;
5987   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5988   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5989   Value *ThreadId = getOrCreateThreadID(Ident);
5990   Constant *ThreadPrivateCache =
5991       getOrCreateInternalVariable(Int8PtrPtr, Name.str());
5992   llvm::Value *Args[] = {Ident, ThreadId, Pointer, Size, ThreadPrivateCache};
5993 
5994   Function *Fn =
5995       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_threadprivate_cached);
5996 
5997   return Builder.CreateCall(Fn, Args);
5998 }
5999 
6000 OpenMPIRBuilder::InsertPointTy
6001 OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
6002                                   int32_t MinThreadsVal, int32_t MaxThreadsVal,
6003                                   int32_t MinTeamsVal, int32_t MaxTeamsVal) {
6004   if (!updateToLocation(Loc))
6005     return Loc.IP;
6006 
6007   uint32_t SrcLocStrSize;
6008   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6009   Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6010   Constant *IsSPMDVal = ConstantInt::getSigned(
6011       Int8, IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
6012   Constant *UseGenericStateMachineVal = ConstantInt::getSigned(Int8, !IsSPMD);
6013   Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Int8, true);
6014   Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Int16, 0);
6015 
6016   Function *Kernel = Builder.GetInsertBlock()->getParent();
6017 
6018   // Manifest the launch configuration in the metadata matching the kernel
6019   // environment.
6020   if (MinTeamsVal > 1 || MaxTeamsVal > 0)
6021     writeTeamsForKernel(T, *Kernel, MinTeamsVal, MaxTeamsVal);
6022 
6023   // For max values, < 0 means unset, == 0 means set but unknown.
6024   if (MaxThreadsVal < 0)
6025     MaxThreadsVal = std::max(
6026         int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), MinThreadsVal);
6027 
6028   if (MaxThreadsVal > 0)
6029     writeThreadBoundsForKernel(T, *Kernel, MinThreadsVal, MaxThreadsVal);
6030 
6031   Constant *MinThreads = ConstantInt::getSigned(Int32, MinThreadsVal);
6032   Constant *MaxThreads = ConstantInt::getSigned(Int32, MaxThreadsVal);
6033   Constant *MinTeams = ConstantInt::getSigned(Int32, MinTeamsVal);
6034   Constant *MaxTeams = ConstantInt::getSigned(Int32, MaxTeamsVal);
6035   Constant *ReductionDataSize = ConstantInt::getSigned(Int32, 0);
6036   Constant *ReductionBufferLength = ConstantInt::getSigned(Int32, 0);
6037 
6038   // We need to strip the debug prefix to get the correct kernel name.
6039   StringRef KernelName = Kernel->getName();
6040   const std::string DebugPrefix = "_debug__";
6041   if (KernelName.ends_with(DebugPrefix))
6042     KernelName = KernelName.drop_back(DebugPrefix.length());
6043 
6044   Function *Fn = getOrCreateRuntimeFunctionPtr(
6045       omp::RuntimeFunction::OMPRTL___kmpc_target_init);
6046   const DataLayout &DL = Fn->getDataLayout();
6047 
6048   Twine DynamicEnvironmentName = KernelName + "_dynamic_environment";
6049   Constant *DynamicEnvironmentInitializer =
6050       ConstantStruct::get(DynamicEnvironment, {DebugIndentionLevelVal});
6051   GlobalVariable *DynamicEnvironmentGV = new GlobalVariable(
6052       M, DynamicEnvironment, /*IsConstant=*/false, GlobalValue::WeakODRLinkage,
6053       DynamicEnvironmentInitializer, DynamicEnvironmentName,
6054       /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
6055       DL.getDefaultGlobalsAddressSpace());
6056   DynamicEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);
6057 
6058   Constant *DynamicEnvironment =
6059       DynamicEnvironmentGV->getType() == DynamicEnvironmentPtr
6060           ? DynamicEnvironmentGV
6061           : ConstantExpr::getAddrSpaceCast(DynamicEnvironmentGV,
6062                                            DynamicEnvironmentPtr);
6063 
6064   Constant *ConfigurationEnvironmentInitializer = ConstantStruct::get(
6065       ConfigurationEnvironment, {
6066                                     UseGenericStateMachineVal,
6067                                     MayUseNestedParallelismVal,
6068                                     IsSPMDVal,
6069                                     MinThreads,
6070                                     MaxThreads,
6071                                     MinTeams,
6072                                     MaxTeams,
6073                                     ReductionDataSize,
6074                                     ReductionBufferLength,
6075                                 });
6076   Constant *KernelEnvironmentInitializer = ConstantStruct::get(
6077       KernelEnvironment, {
6078                              ConfigurationEnvironmentInitializer,
6079                              Ident,
6080                              DynamicEnvironment,
6081                          });
6082   std::string KernelEnvironmentName =
6083       (KernelName + "_kernel_environment").str();
6084   GlobalVariable *KernelEnvironmentGV = new GlobalVariable(
6085       M, KernelEnvironment, /*IsConstant=*/true, GlobalValue::WeakODRLinkage,
6086       KernelEnvironmentInitializer, KernelEnvironmentName,
6087       /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
6088       DL.getDefaultGlobalsAddressSpace());
6089   KernelEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);
6090 
6091   Constant *KernelEnvironment =
6092       KernelEnvironmentGV->getType() == KernelEnvironmentPtr
6093           ? KernelEnvironmentGV
6094           : ConstantExpr::getAddrSpaceCast(KernelEnvironmentGV,
6095                                            KernelEnvironmentPtr);
6096   Value *KernelLaunchEnvironment = Kernel->getArg(0);
6097   CallInst *ThreadKind =
6098       Builder.CreateCall(Fn, {KernelEnvironment, KernelLaunchEnvironment});
6099 
6100   Value *ExecUserCode = Builder.CreateICmpEQ(
6101       ThreadKind, ConstantInt::get(ThreadKind->getType(), -1),
6102       "exec_user_code");
6103 
6104   // ThreadKind = __kmpc_target_init(...)
6105   // if (ThreadKind == -1)
6106   //   user_code
6107   // else
6108   //   return;
6109 
6110   auto *UI = Builder.CreateUnreachable();
6111   BasicBlock *CheckBB = UI->getParent();
6112   BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(UI, "user_code.entry");
6113 
6114   BasicBlock *WorkerExitBB = BasicBlock::Create(
6115       CheckBB->getContext(), "worker.exit", CheckBB->getParent());
6116   Builder.SetInsertPoint(WorkerExitBB);
6117   Builder.CreateRetVoid();
6118 
6119   auto *CheckBBTI = CheckBB->getTerminator();
6120   Builder.SetInsertPoint(CheckBBTI);
6121   Builder.CreateCondBr(ExecUserCode, UI->getParent(), WorkerExitBB);
6122 
6123   CheckBBTI->eraseFromParent();
6124   UI->eraseFromParent();
6125 
6126   // Continue in the "user_code" block, see diagram above and in
6127   // openmp/libomptarget/deviceRTLs/common/include/target.h .
6128   return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt());
6129 }
6130 
6131 void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc,
6132                                          int32_t TeamsReductionDataSize,
6133                                          int32_t TeamsReductionBufferLength) {
6134   if (!updateToLocation(Loc))
6135     return;
6136 
6137   Function *Fn = getOrCreateRuntimeFunctionPtr(
6138       omp::RuntimeFunction::OMPRTL___kmpc_target_deinit);
6139 
6140   Builder.CreateCall(Fn, {});
6141 
6142   if (!TeamsReductionBufferLength || !TeamsReductionDataSize)
6143     return;
6144 
6145   Function *Kernel = Builder.GetInsertBlock()->getParent();
6146   // We need to strip the debug prefix to get the correct kernel name.
6147   StringRef KernelName = Kernel->getName();
6148   const std::string DebugPrefix = "_debug__";
6149   if (KernelName.ends_with(DebugPrefix))
6150     KernelName = KernelName.drop_back(DebugPrefix.length());
6151   auto *KernelEnvironmentGV =
6152       M.getNamedGlobal((KernelName + "_kernel_environment").str());
6153   assert(KernelEnvironmentGV && "Expected kernel environment global\n");
6154   auto *KernelEnvironmentInitializer = KernelEnvironmentGV->getInitializer();
6155   auto *NewInitializer = ConstantFoldInsertValueInstruction(
6156       KernelEnvironmentInitializer,
6157       ConstantInt::get(Int32, TeamsReductionDataSize), {0, 7});
6158   NewInitializer = ConstantFoldInsertValueInstruction(
6159       NewInitializer, ConstantInt::get(Int32, TeamsReductionBufferLength),
6160       {0, 8});
6161   KernelEnvironmentGV->setInitializer(NewInitializer);
6162 }
6163 
6164 static MDNode *getNVPTXMDNode(Function &Kernel, StringRef Name) {
6165   Module &M = *Kernel.getParent();
6166   NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
6167   for (auto *Op : MD->operands()) {
6168     if (Op->getNumOperands() != 3)
6169       continue;
6170     auto *KernelOp = dyn_cast<ConstantAsMetadata>(Op->getOperand(0));
6171     if (!KernelOp || KernelOp->getValue() != &Kernel)
6172       continue;
6173     auto *Prop = dyn_cast<MDString>(Op->getOperand(1));
6174     if (!Prop || Prop->getString() != Name)
6175       continue;
6176     return Op;
6177   }
6178   return nullptr;
6179 }
6180 
6181 static void updateNVPTXMetadata(Function &Kernel, StringRef Name, int32_t Value,
6182                                 bool Min) {
6183   // Update the "maxntidx" metadata for NVIDIA, or add it.
6184   MDNode *ExistingOp = getNVPTXMDNode(Kernel, Name);
6185   if (ExistingOp) {
6186     auto *OldVal = cast<ConstantAsMetadata>(ExistingOp->getOperand(2));
6187     int32_t OldLimit = cast<ConstantInt>(OldVal->getValue())->getZExtValue();
6188     ExistingOp->replaceOperandWith(
6189         2, ConstantAsMetadata::get(ConstantInt::get(
6190                OldVal->getValue()->getType(),
6191                Min ? std::min(OldLimit, Value) : std::max(OldLimit, Value))));
6192   } else {
6193     LLVMContext &Ctx = Kernel.getContext();
6194     Metadata *MDVals[] = {ConstantAsMetadata::get(&Kernel),
6195                           MDString::get(Ctx, Name),
6196                           ConstantAsMetadata::get(
6197                               ConstantInt::get(Type::getInt32Ty(Ctx), Value))};
6198     // Append metadata to nvvm.annotations
6199     Module &M = *Kernel.getParent();
6200     NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
6201     MD->addOperand(MDNode::get(Ctx, MDVals));
6202   }
6203 }
6204 
6205 std::pair<int32_t, int32_t>
6206 OpenMPIRBuilder::readThreadBoundsForKernel(const Triple &T, Function &Kernel) {
6207   int32_t ThreadLimit =
6208       Kernel.getFnAttributeAsParsedInteger("omp_target_thread_limit");
6209 
6210   if (T.isAMDGPU()) {
6211     const auto &Attr = Kernel.getFnAttribute("amdgpu-flat-work-group-size");
6212     if (!Attr.isValid() || !Attr.isStringAttribute())
6213       return {0, ThreadLimit};
6214     auto [LBStr, UBStr] = Attr.getValueAsString().split(',');
6215     int32_t LB, UB;
6216     if (!llvm::to_integer(UBStr, UB, 10))
6217       return {0, ThreadLimit};
6218     UB = ThreadLimit ? std::min(ThreadLimit, UB) : UB;
6219     if (!llvm::to_integer(LBStr, LB, 10))
6220       return {0, UB};
6221     return {LB, UB};
6222   }
6223 
6224   if (MDNode *ExistingOp = getNVPTXMDNode(Kernel, "maxntidx")) {
6225     auto *OldVal = cast<ConstantAsMetadata>(ExistingOp->getOperand(2));
6226     int32_t UB = cast<ConstantInt>(OldVal->getValue())->getZExtValue();
6227     return {0, ThreadLimit ? std::min(ThreadLimit, UB) : UB};
6228   }
6229   return {0, ThreadLimit};
6230 }
6231 
6232 void OpenMPIRBuilder::writeThreadBoundsForKernel(const Triple &T,
6233                                                  Function &Kernel, int32_t LB,
6234                                                  int32_t UB) {
6235   Kernel.addFnAttr("omp_target_thread_limit", std::to_string(UB));
6236 
6237   if (T.isAMDGPU()) {
6238     Kernel.addFnAttr("amdgpu-flat-work-group-size",
6239                      llvm::utostr(LB) + "," + llvm::utostr(UB));
6240     return;
6241   }
6242 
6243   updateNVPTXMetadata(Kernel, "maxntidx", UB, true);
6244 }
6245 
6246 std::pair<int32_t, int32_t>
6247 OpenMPIRBuilder::readTeamBoundsForKernel(const Triple &, Function &Kernel) {
6248   // TODO: Read from backend annotations if available.
6249   return {0, Kernel.getFnAttributeAsParsedInteger("omp_target_num_teams")};
6250 }
6251 
6252 void OpenMPIRBuilder::writeTeamsForKernel(const Triple &T, Function &Kernel,
6253                                           int32_t LB, int32_t UB) {
6254   if (T.isNVPTX())
6255     if (UB > 0)
6256       updateNVPTXMetadata(Kernel, "maxclusterrank", UB, true);
6257   if (T.isAMDGPU())
6258     Kernel.addFnAttr("amdgpu-max-num-workgroups", llvm::utostr(LB) + ",1,1");
6259 
6260   Kernel.addFnAttr("omp_target_num_teams", std::to_string(LB));
6261 }
6262 
6263 void OpenMPIRBuilder::setOutlinedTargetRegionFunctionAttributes(
6264     Function *OutlinedFn) {
6265   if (Config.isTargetDevice()) {
6266     OutlinedFn->setLinkage(GlobalValue::WeakODRLinkage);
6267     // TODO: Determine if DSO local can be set to true.
6268     OutlinedFn->setDSOLocal(false);
6269     OutlinedFn->setVisibility(GlobalValue::ProtectedVisibility);
6270     if (T.isAMDGCN())
6271       OutlinedFn->setCallingConv(CallingConv::AMDGPU_KERNEL);
6272   }
6273 }
6274 
6275 Constant *OpenMPIRBuilder::createOutlinedFunctionID(Function *OutlinedFn,
6276                                                     StringRef EntryFnIDName) {
6277   if (Config.isTargetDevice()) {
6278     assert(OutlinedFn && "The outlined function must exist if embedded");
6279     return OutlinedFn;
6280   }
6281 
6282   return new GlobalVariable(
6283       M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
6284       Constant::getNullValue(Builder.getInt8Ty()), EntryFnIDName);
6285 }
6286 
6287 Constant *OpenMPIRBuilder::createTargetRegionEntryAddr(Function *OutlinedFn,
6288                                                        StringRef EntryFnName) {
6289   if (OutlinedFn)
6290     return OutlinedFn;
6291 
6292   assert(!M.getGlobalVariable(EntryFnName, true) &&
6293          "Named kernel already exists?");
6294   return new GlobalVariable(
6295       M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::InternalLinkage,
6296       Constant::getNullValue(Builder.getInt8Ty()), EntryFnName);
6297 }
6298 
6299 void OpenMPIRBuilder::emitTargetRegionFunction(
6300     TargetRegionEntryInfo &EntryInfo,
6301     FunctionGenCallback &GenerateFunctionCallback, bool IsOffloadEntry,
6302     Function *&OutlinedFn, Constant *&OutlinedFnID) {
6303 
6304   SmallString<64> EntryFnName;
6305   OffloadInfoManager.getTargetRegionEntryFnName(EntryFnName, EntryInfo);
6306 
6307   OutlinedFn = Config.isTargetDevice() || !Config.openMPOffloadMandatory()
6308                    ? GenerateFunctionCallback(EntryFnName)
6309                    : nullptr;
6310 
6311   // If this target outline function is not an offload entry, we don't need to
6312   // register it. This may be in the case of a false if clause, or if there are
6313   // no OpenMP targets.
6314   if (!IsOffloadEntry)
6315     return;
6316 
6317   std::string EntryFnIDName =
6318       Config.isTargetDevice()
6319           ? std::string(EntryFnName)
6320           : createPlatformSpecificName({EntryFnName, "region_id"});
6321 
6322   OutlinedFnID = registerTargetRegionFunction(EntryInfo, OutlinedFn,
6323                                               EntryFnName, EntryFnIDName);
6324 }
6325 
6326 Constant *OpenMPIRBuilder::registerTargetRegionFunction(
6327     TargetRegionEntryInfo &EntryInfo, Function *OutlinedFn,
6328     StringRef EntryFnName, StringRef EntryFnIDName) {
6329   if (OutlinedFn)
6330     setOutlinedTargetRegionFunctionAttributes(OutlinedFn);
6331   auto OutlinedFnID = createOutlinedFunctionID(OutlinedFn, EntryFnIDName);
6332   auto EntryAddr = createTargetRegionEntryAddr(OutlinedFn, EntryFnName);
6333   OffloadInfoManager.registerTargetRegionEntryInfo(
6334       EntryInfo, EntryAddr, OutlinedFnID,
6335       OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion);
6336   return OutlinedFnID;
6337 }
6338 
6339 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
6340     const LocationDescription &Loc, InsertPointTy AllocaIP,
6341     InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
6342     TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
6343     omp::RuntimeFunction *MapperFunc,
6344     function_ref<InsertPointTy(InsertPointTy CodeGenIP, BodyGenTy BodyGenType)>
6345         BodyGenCB,
6346     function_ref<void(unsigned int, Value *)> DeviceAddrCB,
6347     function_ref<Value *(unsigned int)> CustomMapperCB, Value *SrcLocInfo) {
6348   if (!updateToLocation(Loc))
6349     return InsertPointTy();
6350 
6351   // Disable TargetData CodeGen on Device pass.
6352   if (Config.IsTargetDevice.value_or(false)) {
6353     if (BodyGenCB)
6354       Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv));
6355     return Builder.saveIP();
6356   }
6357 
6358   Builder.restoreIP(CodeGenIP);
6359   bool IsStandAlone = !BodyGenCB;
6360   MapInfosTy *MapInfo;
6361   // Generate the code for the opening of the data environment. Capture all the
6362   // arguments of the runtime call by reference because they are used in the
6363   // closing of the region.
6364   auto BeginThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6365     MapInfo = &GenMapInfoCB(Builder.saveIP());
6366     emitOffloadingArrays(AllocaIP, Builder.saveIP(), *MapInfo, Info,
6367                          /*IsNonContiguous=*/true, DeviceAddrCB,
6368                          CustomMapperCB);
6369 
6370     TargetDataRTArgs RTArgs;
6371     emitOffloadingArraysArgument(Builder, RTArgs, Info,
6372                                  !MapInfo->Names.empty());
6373 
6374     // Emit the number of elements in the offloading arrays.
6375     Value *PointerNum = Builder.getInt32(Info.NumberOfPtrs);
6376 
6377     // Source location for the ident struct
6378     if (!SrcLocInfo) {
6379       uint32_t SrcLocStrSize;
6380       Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6381       SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6382     }
6383 
6384     Value *OffloadingArgs[] = {SrcLocInfo,           DeviceID,
6385                                PointerNum,           RTArgs.BasePointersArray,
6386                                RTArgs.PointersArray, RTArgs.SizesArray,
6387                                RTArgs.MapTypesArray, RTArgs.MapNamesArray,
6388                                RTArgs.MappersArray};
6389 
6390     if (IsStandAlone) {
6391       assert(MapperFunc && "MapperFunc missing for standalone target data");
6392       Builder.CreateCall(getOrCreateRuntimeFunctionPtr(*MapperFunc),
6393                          OffloadingArgs);
6394     } else {
6395       Function *BeginMapperFunc = getOrCreateRuntimeFunctionPtr(
6396           omp::OMPRTL___tgt_target_data_begin_mapper);
6397 
6398       Builder.CreateCall(BeginMapperFunc, OffloadingArgs);
6399 
6400       for (auto DeviceMap : Info.DevicePtrInfoMap) {
6401         if (isa<AllocaInst>(DeviceMap.second.second)) {
6402           auto *LI =
6403               Builder.CreateLoad(Builder.getPtrTy(), DeviceMap.second.first);
6404           Builder.CreateStore(LI, DeviceMap.second.second);
6405         }
6406       }
6407 
6408       // If device pointer privatization is required, emit the body of the
6409       // region here. It will have to be duplicated: with and without
6410       // privatization.
6411       Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::Priv));
6412     }
6413   };
6414 
6415   // If we need device pointer privatization, we need to emit the body of the
6416   // region with no privatization in the 'else' branch of the conditional.
6417   // Otherwise, we don't have to do anything.
6418   auto BeginElseGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6419     Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::DupNoPriv));
6420   };
6421 
6422   // Generate code for the closing of the data region.
6423   auto EndThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6424     TargetDataRTArgs RTArgs;
6425     emitOffloadingArraysArgument(Builder, RTArgs, Info, !MapInfo->Names.empty(),
6426                                  /*ForEndCall=*/true);
6427 
6428     // Emit the number of elements in the offloading arrays.
6429     Value *PointerNum = Builder.getInt32(Info.NumberOfPtrs);
6430 
6431     // Source location for the ident struct
6432     if (!SrcLocInfo) {
6433       uint32_t SrcLocStrSize;
6434       Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6435       SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6436     }
6437 
6438     Value *OffloadingArgs[] = {SrcLocInfo,           DeviceID,
6439                                PointerNum,           RTArgs.BasePointersArray,
6440                                RTArgs.PointersArray, RTArgs.SizesArray,
6441                                RTArgs.MapTypesArray, RTArgs.MapNamesArray,
6442                                RTArgs.MappersArray};
6443     Function *EndMapperFunc =
6444         getOrCreateRuntimeFunctionPtr(omp::OMPRTL___tgt_target_data_end_mapper);
6445 
6446     Builder.CreateCall(EndMapperFunc, OffloadingArgs);
6447   };
6448 
6449   // We don't have to do anything to close the region if the if clause evaluates
6450   // to false.
6451   auto EndElseGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
6452 
6453   if (BodyGenCB) {
6454     if (IfCond) {
6455       emitIfClause(IfCond, BeginThenGen, BeginElseGen, AllocaIP);
6456     } else {
6457       BeginThenGen(AllocaIP, Builder.saveIP());
6458     }
6459 
6460     // If we don't require privatization of device pointers, we emit the body in
6461     // between the runtime calls. This avoids duplicating the body code.
6462     Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv));
6463 
6464     if (IfCond) {
6465       emitIfClause(IfCond, EndThenGen, EndElseGen, AllocaIP);
6466     } else {
6467       EndThenGen(AllocaIP, Builder.saveIP());
6468     }
6469   } else {
6470     if (IfCond) {
6471       emitIfClause(IfCond, BeginThenGen, EndElseGen, AllocaIP);
6472     } else {
6473       BeginThenGen(AllocaIP, Builder.saveIP());
6474     }
6475   }
6476 
6477   return Builder.saveIP();
6478 }
6479 
6480 FunctionCallee
6481 OpenMPIRBuilder::createForStaticInitFunction(unsigned IVSize, bool IVSigned,
6482                                              bool IsGPUDistribute) {
6483   assert((IVSize == 32 || IVSize == 64) &&
6484          "IV size is not compatible with the omp runtime");
6485   RuntimeFunction Name;
6486   if (IsGPUDistribute)
6487     Name = IVSize == 32
6488                ? (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_4
6489                            : omp::OMPRTL___kmpc_distribute_static_init_4u)
6490                : (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_8
6491                            : omp::OMPRTL___kmpc_distribute_static_init_8u);
6492   else
6493     Name = IVSize == 32 ? (IVSigned ? omp::OMPRTL___kmpc_for_static_init_4
6494                                     : omp::OMPRTL___kmpc_for_static_init_4u)
6495                         : (IVSigned ? omp::OMPRTL___kmpc_for_static_init_8
6496                                     : omp::OMPRTL___kmpc_for_static_init_8u);
6497 
6498   return getOrCreateRuntimeFunction(M, Name);
6499 }
6500 
6501 FunctionCallee OpenMPIRBuilder::createDispatchInitFunction(unsigned IVSize,
6502                                                            bool IVSigned) {
6503   assert((IVSize == 32 || IVSize == 64) &&
6504          "IV size is not compatible with the omp runtime");
6505   RuntimeFunction Name = IVSize == 32
6506                              ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_4
6507                                          : omp::OMPRTL___kmpc_dispatch_init_4u)
6508                              : (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_8
6509                                          : omp::OMPRTL___kmpc_dispatch_init_8u);
6510 
6511   return getOrCreateRuntimeFunction(M, Name);
6512 }
6513 
6514 FunctionCallee OpenMPIRBuilder::createDispatchNextFunction(unsigned IVSize,
6515                                                            bool IVSigned) {
6516   assert((IVSize == 32 || IVSize == 64) &&
6517          "IV size is not compatible with the omp runtime");
6518   RuntimeFunction Name = IVSize == 32
6519                              ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_4
6520                                          : omp::OMPRTL___kmpc_dispatch_next_4u)
6521                              : (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_8
6522                                          : omp::OMPRTL___kmpc_dispatch_next_8u);
6523 
6524   return getOrCreateRuntimeFunction(M, Name);
6525 }
6526 
6527 FunctionCallee OpenMPIRBuilder::createDispatchFiniFunction(unsigned IVSize,
6528                                                            bool IVSigned) {
6529   assert((IVSize == 32 || IVSize == 64) &&
6530          "IV size is not compatible with the omp runtime");
6531   RuntimeFunction Name = IVSize == 32
6532                              ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_4
6533                                          : omp::OMPRTL___kmpc_dispatch_fini_4u)
6534                              : (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_8
6535                                          : omp::OMPRTL___kmpc_dispatch_fini_8u);
6536 
6537   return getOrCreateRuntimeFunction(M, Name);
6538 }
6539 
6540 FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
6541   return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit);
6542 }
6543 
6544 static Function *createOutlinedFunction(
6545     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName,
6546     SmallVectorImpl<Value *> &Inputs,
6547     OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
6548     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
6549   SmallVector<Type *> ParameterTypes;
6550   if (OMPBuilder.Config.isTargetDevice()) {
6551     // Add the "implicit" runtime argument we use to provide launch specific
6552     // information for target devices.
6553     auto *Int8PtrTy = PointerType::getUnqual(Builder.getContext());
6554     ParameterTypes.push_back(Int8PtrTy);
6555 
6556     // All parameters to target devices are passed as pointers
6557     // or i64. This assumes 64-bit address spaces/pointers.
6558     for (auto &Arg : Inputs)
6559       ParameterTypes.push_back(Arg->getType()->isPointerTy()
6560                                    ? Arg->getType()
6561                                    : Type::getInt64Ty(Builder.getContext()));
6562   } else {
6563     for (auto &Arg : Inputs)
6564       ParameterTypes.push_back(Arg->getType());
6565   }
6566 
6567   auto FuncType = FunctionType::get(Builder.getVoidTy(), ParameterTypes,
6568                                     /*isVarArg*/ false);
6569   auto Func = Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName,
6570                                Builder.GetInsertBlock()->getModule());
6571 
6572   // Save insert point.
6573   auto OldInsertPoint = Builder.saveIP();
6574 
6575   // Generate the region into the function.
6576   BasicBlock *EntryBB = BasicBlock::Create(Builder.getContext(), "entry", Func);
6577   Builder.SetInsertPoint(EntryBB);
6578 
6579   // Insert target init call in the device compilation pass.
6580   if (OMPBuilder.Config.isTargetDevice())
6581     Builder.restoreIP(OMPBuilder.createTargetInit(Builder, /*IsSPMD*/ false));
6582 
6583   BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
6584 
6585   // As we embed the user code in the middle of our target region after we
6586   // generate entry code, we must move what allocas we can into the entry
6587   // block to avoid possible breaking optimisations for device
6588   if (OMPBuilder.Config.isTargetDevice())
6589     OMPBuilder.ConstantAllocaRaiseCandidates.emplace_back(Func);
6590 
6591   // Insert target deinit call in the device compilation pass.
6592   Builder.restoreIP(CBFunc(Builder.saveIP(), Builder.saveIP()));
6593   if (OMPBuilder.Config.isTargetDevice())
6594     OMPBuilder.createTargetDeinit(Builder);
6595 
6596   // Insert return instruction.
6597   Builder.CreateRetVoid();
6598 
6599   // New Alloca IP at entry point of created device function.
6600   Builder.SetInsertPoint(EntryBB->getFirstNonPHI());
6601   auto AllocaIP = Builder.saveIP();
6602 
6603   Builder.SetInsertPoint(UserCodeEntryBB->getFirstNonPHIOrDbg());
6604 
6605   // Skip the artificial dyn_ptr on the device.
6606   const auto &ArgRange =
6607       OMPBuilder.Config.isTargetDevice()
6608           ? make_range(Func->arg_begin() + 1, Func->arg_end())
6609           : Func->args();
6610 
6611   auto ReplaceValue = [](Value *Input, Value *InputCopy, Function *Func) {
6612     // Things like GEP's can come in the form of Constants. Constants and
6613     // ConstantExpr's do not have access to the knowledge of what they're
6614     // contained in, so we must dig a little to find an instruction so we
6615     // can tell if they're used inside of the function we're outlining. We
6616     // also replace the original constant expression with a new instruction
6617     // equivalent; an instruction as it allows easy modification in the
6618     // following loop, as we can now know the constant (instruction) is
6619     // owned by our target function and replaceUsesOfWith can now be invoked
6620     // on it (cannot do this with constants it seems). A brand new one also
6621     // allows us to be cautious as it is perhaps possible the old expression
6622     // was used inside of the function but exists and is used externally
6623     // (unlikely by the nature of a Constant, but still).
6624     // NOTE: We cannot remove dead constants that have been rewritten to
6625     // instructions at this stage, we run the risk of breaking later lowering
6626     // by doing so as we could still be in the process of lowering the module
6627     // from MLIR to LLVM-IR and the MLIR lowering may still require the original
6628     // constants we have created rewritten versions of.
6629     if (auto *Const = dyn_cast<Constant>(Input))
6630       convertUsersOfConstantsToInstructions(Const, Func, false);
6631 
6632     // Collect all the instructions
6633     for (User *User : make_early_inc_range(Input->users()))
6634       if (auto *Instr = dyn_cast<Instruction>(User))
6635         if (Instr->getFunction() == Func)
6636           Instr->replaceUsesOfWith(Input, InputCopy);
6637   };
6638 
6639   SmallVector<std::pair<Value *, Value *>> DeferredReplacement;
6640 
6641   // Rewrite uses of input valus to parameters.
6642   for (auto InArg : zip(Inputs, ArgRange)) {
6643     Value *Input = std::get<0>(InArg);
6644     Argument &Arg = std::get<1>(InArg);
6645     Value *InputCopy = nullptr;
6646 
6647     Builder.restoreIP(
6648         ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP()));
6649 
6650     // In certain cases a Global may be set up for replacement, however, this
6651     // Global may be used in multiple arguments to the kernel, just segmented
6652     // apart, for example, if we have a global array, that is sectioned into
6653     // multiple mappings (technically not legal in OpenMP, but there is a case
6654     // in Fortran for Common Blocks where this is neccesary), we will end up
6655     // with GEP's into this array inside the kernel, that refer to the Global
6656     // but are technically seperate arguments to the kernel for all intents and
6657     // purposes. If we have mapped a segment that requires a GEP into the 0-th
6658     // index, it will fold into an referal to the Global, if we then encounter
6659     // this folded GEP during replacement all of the references to the
6660     // Global in the kernel will be replaced with the argument we have generated
6661     // that corresponds to it, including any other GEP's that refer to the
6662     // Global that may be other arguments. This will invalidate all of the other
6663     // preceding mapped arguments that refer to the same global that may be
6664     // seperate segments. To prevent this, we defer global processing until all
6665     // other processing has been performed.
6666     if (llvm::isa<llvm::GlobalValue>(std::get<0>(InArg)) ||
6667         llvm::isa<llvm::GlobalObject>(std::get<0>(InArg)) ||
6668         llvm::isa<llvm::GlobalVariable>(std::get<0>(InArg))) {
6669       DeferredReplacement.push_back(std::make_pair(Input, InputCopy));
6670       continue;
6671     }
6672 
6673     ReplaceValue(Input, InputCopy, Func);
6674   }
6675 
6676   // Replace all of our deferred Input values, currently just Globals.
6677   for (auto Deferred : DeferredReplacement)
6678     ReplaceValue(std::get<0>(Deferred), std::get<1>(Deferred), Func);
6679 
6680   // Restore insert point.
6681   Builder.restoreIP(OldInsertPoint);
6682 
6683   return Func;
6684 }
6685 
6686 /// Create an entry point for a target task with the following.
6687 /// It'll have the following signature
6688 /// void @.omp_target_task_proxy_func(i32 %thread.id, ptr %task)
6689 /// This function is called from emitTargetTask once the
6690 /// code to launch the target kernel has been outlined already.
6691 static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
6692                                              IRBuilderBase &Builder,
6693                                              CallInst *StaleCI) {
6694   Module &M = OMPBuilder.M;
6695   // KernelLaunchFunction is the target launch function, i.e.
6696   // the function that sets up kernel arguments and calls
6697   // __tgt_target_kernel to launch the kernel on the device.
6698   //
6699   Function *KernelLaunchFunction = StaleCI->getCalledFunction();
6700 
6701   // StaleCI is the CallInst which is the call to the outlined
6702   // target kernel launch function. If there are values that the
6703   // outlined function uses then these are aggregated into a structure
6704   // which is passed as the second argument. If not, then there's
6705   // only one argument, the threadID. So, StaleCI can be
6706   //
6707   // %structArg = alloca { ptr, ptr }, align 8
6708   // %gep_ = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 0
6709   // store ptr %20, ptr %gep_, align 8
6710   // %gep_8 = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 1
6711   // store ptr %21, ptr %gep_8, align 8
6712   // call void @_QQmain..omp_par.1(i32 %global.tid.val6, ptr %structArg)
6713   //
6714   // OR
6715   //
6716   // call void @_QQmain..omp_par.1(i32 %global.tid.val6)
6717   OpenMPIRBuilder::InsertPointTy IP(StaleCI->getParent(),
6718                                     StaleCI->getIterator());
6719   LLVMContext &Ctx = StaleCI->getParent()->getContext();
6720   Type *ThreadIDTy = Type::getInt32Ty(Ctx);
6721   Type *TaskPtrTy = OMPBuilder.TaskPtr;
6722   Type *TaskTy = OMPBuilder.Task;
6723   auto ProxyFnTy =
6724       FunctionType::get(Builder.getVoidTy(), {ThreadIDTy, TaskPtrTy},
6725                         /* isVarArg */ false);
6726   auto ProxyFn = Function::Create(ProxyFnTy, GlobalValue::InternalLinkage,
6727                                   ".omp_target_task_proxy_func",
6728                                   Builder.GetInsertBlock()->getModule());
6729   ProxyFn->getArg(0)->setName("thread.id");
6730   ProxyFn->getArg(1)->setName("task");
6731 
6732   BasicBlock *EntryBB =
6733       BasicBlock::Create(Builder.getContext(), "entry", ProxyFn);
6734   Builder.SetInsertPoint(EntryBB);
6735 
6736   bool HasShareds = StaleCI->arg_size() > 1;
6737   // TODO: This is a temporary assert to prove to ourselves that
6738   // the outlined target launch function is always going to have
6739   // atmost two arguments if there is any data shared between
6740   // host and device.
6741   assert((!HasShareds || (StaleCI->arg_size() == 2)) &&
6742          "StaleCI with shareds should have exactly two arguments.");
6743   if (HasShareds) {
6744     auto *ArgStructAlloca = dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
6745     assert(ArgStructAlloca &&
6746            "Unable to find the alloca instruction corresponding to arguments "
6747            "for extracted function");
6748     auto *ArgStructType =
6749         dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
6750 
6751     AllocaInst *NewArgStructAlloca =
6752         Builder.CreateAlloca(ArgStructType, nullptr, "structArg");
6753     Value *TaskT = ProxyFn->getArg(1);
6754     Value *ThreadId = ProxyFn->getArg(0);
6755     Value *SharedsSize =
6756         Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
6757 
6758     Value *Shareds = Builder.CreateStructGEP(TaskTy, TaskT, 0);
6759     LoadInst *LoadShared =
6760         Builder.CreateLoad(PointerType::getUnqual(Ctx), Shareds);
6761 
6762     Builder.CreateMemCpy(
6763         NewArgStructAlloca, NewArgStructAlloca->getAlign(), LoadShared,
6764         LoadShared->getPointerAlignment(M.getDataLayout()), SharedsSize);
6765 
6766     Builder.CreateCall(KernelLaunchFunction, {ThreadId, NewArgStructAlloca});
6767   }
6768   Builder.CreateRetVoid();
6769   return ProxyFn;
6770 }
6771 static void emitTargetOutlinedFunction(
6772     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6773     TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn,
6774     Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs,
6775     OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
6776     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
6777 
6778   OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
6779       [&OMPBuilder, &Builder, &Inputs, &CBFunc,
6780        &ArgAccessorFuncCB](StringRef EntryFnName) {
6781         return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs,
6782                                       CBFunc, ArgAccessorFuncCB);
6783       };
6784 
6785   OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction, true,
6786                                       OutlinedFn, OutlinedFnID);
6787 }
6788 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
6789     Function *OutlinedFn, Value *OutlinedFnID,
6790     EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
6791     Value *DeviceID, Value *RTLoc, OpenMPIRBuilder::InsertPointTy AllocaIP,
6792     SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
6793     bool HasNoWait) {
6794 
6795   // When we arrive at this function, the target region itself has been
6796   // outlined into the function OutlinedFn.
6797   // So at ths point, for
6798   // --------------------------------------------------
6799   //   void user_code_that_offloads(...) {
6800   //     omp target depend(..) map(from:a) map(to:b, c)
6801   //        a = b + c
6802   //   }
6803   //
6804   // --------------------------------------------------
6805   //
6806   // we have
6807   //
6808   // --------------------------------------------------
6809   //
6810   //   void user_code_that_offloads(...) {
6811   //     %.offload_baseptrs = alloca [3 x ptr], align 8
6812   //     %.offload_ptrs = alloca [3 x ptr], align 8
6813   //     %.offload_mappers = alloca [3 x ptr], align 8
6814   //     ;; target region has been outlined and now we need to
6815   //     ;; offload to it via a target task.
6816   //   }
6817   //   void outlined_device_function(ptr a, ptr b, ptr c) {
6818   //     *a = *b + *c
6819   //   }
6820   //
6821   // We have to now do the following
6822   // (i)   Make an offloading call to outlined_device_function using the OpenMP
6823   //       RTL. See 'kernel_launch_function' in the pseudo code below. This is
6824   //       emitted by emitKernelLaunch
6825   // (ii)  Create a task entry point function that calls kernel_launch_function
6826   //       and is the entry point for the target task. See
6827   //       '@.omp_target_task_proxy_func in the pseudocode below.
6828   // (iii) Create a task with the task entry point created in (ii)
6829   //
6830   // That is we create the following
6831   //
6832   //   void user_code_that_offloads(...) {
6833   //     %.offload_baseptrs = alloca [3 x ptr], align 8
6834   //     %.offload_ptrs = alloca [3 x ptr], align 8
6835   //     %.offload_mappers = alloca [3 x ptr], align 8
6836   //
6837   //     %structArg = alloca { ptr, ptr, ptr }, align 8
6838   //     %strucArg[0] = %.offload_baseptrs
6839   //     %strucArg[1] = %.offload_ptrs
6840   //     %strucArg[2] = %.offload_mappers
6841   //     proxy_target_task = @__kmpc_omp_task_alloc(...,
6842   //                                               @.omp_target_task_proxy_func)
6843   //     memcpy(proxy_target_task->shareds, %structArg, sizeof(structArg))
6844   //     dependencies_array = ...
6845   //     ;; if nowait not present
6846   //     call @__kmpc_omp_wait_deps(..., dependencies_array)
6847   //     call @__kmpc_omp_task_begin_if0(...)
6848   //     call @ @.omp_target_task_proxy_func(i32 thread_id, ptr
6849   //     %proxy_target_task) call @__kmpc_omp_task_complete_if0(...)
6850   //   }
6851   //
6852   //   define internal void @.omp_target_task_proxy_func(i32 %thread.id,
6853   //                                                     ptr %task) {
6854   //       %structArg = alloca {ptr, ptr, ptr}
6855   //       %shared_data = load (getelementptr %task, 0, 0)
6856   //       mempcy(%structArg, %shared_data, sizeof(structArg))
6857   //       kernel_launch_function(%thread.id, %structArg)
6858   //   }
6859   //
6860   //   We need the proxy function because the signature of the task entry point
6861   //   expected by kmpc_omp_task is always the same and will be different from
6862   //   that of the kernel_launch function.
6863   //
6864   //   kernel_launch_function is generated by emitKernelLaunch and has the
6865   //   always_inline attribute.
6866   //   void kernel_launch_function(thread_id,
6867   //                               structArg) alwaysinline {
6868   //       %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
6869   //       offload_baseptrs = load(getelementptr structArg, 0, 0)
6870   //       offload_ptrs = load(getelementptr structArg, 0, 1)
6871   //       offload_mappers = load(getelementptr structArg, 0, 2)
6872   //       ; setup kernel_args using offload_baseptrs, offload_ptrs and
6873   //       ; offload_mappers
6874   //       call i32 @__tgt_target_kernel(...,
6875   //                                     outlined_device_function,
6876   //                                     ptr %kernel_args)
6877   //   }
6878   //   void outlined_device_function(ptr a, ptr b, ptr c) {
6879   //      *a = *b + *c
6880   //   }
6881   //
6882   BasicBlock *TargetTaskBodyBB =
6883       splitBB(Builder, /*CreateBranch=*/true, "target.task.body");
6884   BasicBlock *TargetTaskAllocaBB =
6885       splitBB(Builder, /*CreateBranch=*/true, "target.task.alloca");
6886 
6887   InsertPointTy TargetTaskAllocaIP(TargetTaskAllocaBB,
6888                                    TargetTaskAllocaBB->begin());
6889   InsertPointTy TargetTaskBodyIP(TargetTaskBodyBB, TargetTaskBodyBB->begin());
6890 
6891   OutlineInfo OI;
6892   OI.EntryBB = TargetTaskAllocaBB;
6893   OI.OuterAllocaBB = AllocaIP.getBlock();
6894 
6895   // Add the thread ID argument.
6896   SmallVector<Instruction *, 4> ToBeDeleted;
6897   OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
6898       Builder, AllocaIP, ToBeDeleted, TargetTaskAllocaIP, "global.tid", false));
6899 
6900   Builder.restoreIP(TargetTaskBodyIP);
6901 
6902   // emitKernelLaunch makes the necessary runtime call to offload the kernel.
6903   // We then outline all that code into a separate function
6904   // ('kernel_launch_function' in the pseudo code above). This function is then
6905   // called by the target task proxy function (see
6906   // '@.omp_target_task_proxy_func' in the pseudo code above)
6907   // "@.omp_target_task_proxy_func' is generated by emitTargetTaskProxyFunction
6908   Builder.restoreIP(emitKernelLaunch(Builder, OutlinedFn, OutlinedFnID,
6909                                      EmitTargetCallFallbackCB, Args, DeviceID,
6910                                      RTLoc, TargetTaskAllocaIP));
6911 
6912   OI.ExitBB = Builder.saveIP().getBlock();
6913   OI.PostOutlineCB = [this, ToBeDeleted, Dependencies,
6914                       HasNoWait](Function &OutlinedFn) mutable {
6915     assert(OutlinedFn.getNumUses() == 1 &&
6916            "there must be a single user for the outlined function");
6917 
6918     CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
6919     bool HasShareds = StaleCI->arg_size() > 1;
6920 
6921     Function *ProxyFn = emitTargetTaskProxyFunction(*this, Builder, StaleCI);
6922 
6923     LLVM_DEBUG(dbgs() << "Proxy task entry function created: " << *ProxyFn
6924                       << "\n");
6925 
6926     Builder.SetInsertPoint(StaleCI);
6927 
6928     // Gather the arguments for emitting the runtime call.
6929     uint32_t SrcLocStrSize;
6930     Constant *SrcLocStr =
6931         getOrCreateSrcLocStr(LocationDescription(Builder), SrcLocStrSize);
6932     Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6933 
6934     // @__kmpc_omp_task_alloc
6935     Function *TaskAllocFn =
6936         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);
6937 
6938     // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID)
6939     // call.
6940     Value *ThreadID = getOrCreateThreadID(Ident);
6941 
6942     // Argument - `sizeof_kmp_task_t` (TaskSize)
6943     // Tasksize refers to the size in bytes of kmp_task_t data structure
6944     // including private vars accessed in task.
6945     // TODO: add kmp_task_t_with_privates (privates)
6946     Value *TaskSize =
6947         Builder.getInt64(M.getDataLayout().getTypeStoreSize(Task));
6948 
6949     // Argument - `sizeof_shareds` (SharedsSize)
6950     // SharedsSize refers to the shareds array size in the kmp_task_t data
6951     // structure.
6952     Value *SharedsSize = Builder.getInt64(0);
6953     if (HasShareds) {
6954       auto *ArgStructAlloca = dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
6955       assert(ArgStructAlloca &&
6956              "Unable to find the alloca instruction corresponding to arguments "
6957              "for extracted function");
6958       auto *ArgStructType =
6959           dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
6960       assert(ArgStructType && "Unable to find struct type corresponding to "
6961                               "arguments for extracted function");
6962       SharedsSize =
6963           Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
6964     }
6965 
6966     // Argument - `flags`
6967     // Task is tied iff (Flags & 1) == 1.
6968     // Task is untied iff (Flags & 1) == 0.
6969     // Task is final iff (Flags & 2) == 2.
6970     // Task is not final iff (Flags & 2) == 0.
6971     // A target task is not final and is untied.
6972     Value *Flags = Builder.getInt32(0);
6973 
6974     // Emit the @__kmpc_omp_task_alloc runtime call
6975     // The runtime call returns a pointer to an area where the task captured
6976     // variables must be copied before the task is run (TaskData)
6977     CallInst *TaskData = Builder.CreateCall(
6978         TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
6979                       /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
6980                       /*task_func=*/ProxyFn});
6981 
6982     if (HasShareds) {
6983       Value *Shareds = StaleCI->getArgOperand(1);
6984       Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
6985       Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData);
6986       Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
6987                            SharedsSize);
6988     }
6989 
6990     Value *DepArray = emitTaskDependencies(*this, Dependencies);
6991 
6992     // ---------------------------------------------------------------
6993     // V5.2 13.8 target construct
6994     // If the nowait clause is present, execution of the target task
6995     // may be deferred. If the nowait clause is not present, the target task is
6996     // an included task.
6997     // ---------------------------------------------------------------
6998     // The above means that the lack of a nowait on the target construct
6999     // translates to '#pragma omp task if(0)'
7000     if (!HasNoWait) {
7001       if (DepArray) {
7002         Function *TaskWaitFn =
7003             getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_wait_deps);
7004         Builder.CreateCall(
7005             TaskWaitFn,
7006             {/*loc_ref=*/Ident, /*gtid=*/ThreadID,
7007              /*ndeps=*/Builder.getInt32(Dependencies.size()),
7008              /*dep_list=*/DepArray,
7009              /*ndeps_noalias=*/ConstantInt::get(Builder.getInt32Ty(), 0),
7010              /*noalias_dep_list=*/
7011              ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
7012       }
7013       // Included task.
7014       Function *TaskBeginFn =
7015           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0);
7016       Function *TaskCompleteFn =
7017           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
7018       Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
7019       CallInst *CI = nullptr;
7020       if (HasShareds)
7021         CI = Builder.CreateCall(ProxyFn, {ThreadID, TaskData});
7022       else
7023         CI = Builder.CreateCall(ProxyFn, {ThreadID});
7024       CI->setDebugLoc(StaleCI->getDebugLoc());
7025       Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
7026     } else if (DepArray) {
7027       // HasNoWait - meaning the task may be deferred. Call
7028       // __kmpc_omp_task_with_deps if there are dependencies,
7029       // else call __kmpc_omp_task
7030       Function *TaskFn =
7031           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
7032       Builder.CreateCall(
7033           TaskFn,
7034           {Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()),
7035            DepArray, ConstantInt::get(Builder.getInt32Ty(), 0),
7036            ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
7037     } else {
7038       // Emit the @__kmpc_omp_task runtime call to spawn the task
7039       Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
7040       Builder.CreateCall(TaskFn, {Ident, ThreadID, TaskData});
7041     }
7042 
7043     StaleCI->eraseFromParent();
7044     llvm::for_each(llvm::reverse(ToBeDeleted),
7045                    [](Instruction *I) { I->eraseFromParent(); });
7046   };
7047   addOutlineInfo(std::move(OI));
7048 
7049   LLVM_DEBUG(dbgs() << "Insert block after emitKernelLaunch = \n"
7050                     << *(Builder.GetInsertBlock()) << "\n");
7051   LLVM_DEBUG(dbgs() << "Module after emitKernelLaunch = \n"
7052                     << *(Builder.GetInsertBlock()->getParent()->getParent())
7053                     << "\n");
7054   return Builder.saveIP();
7055 }
/// Emit the host-side code that offloads an already-outlined target region.
///
/// \p OutlinedFn / \p OutlinedFnID are the outlined target-region function and
/// its offload entry ID. This helper:
///  1. obtains the map information from \p GenMapInfoCB and materializes the
///     offloading arrays (via emitOffloadingArrays / emitOffloadingArraysArgument);
///  2. builds the kernel launch arguments (item count, teams, threads, ...);
///  3. emits the launch, either directly through emitKernelLaunch or, when an
///     outer target task is required (here: \p Dependencies is non-empty),
///     through emitTargetTask.
/// A callback that calls \p OutlinedFn with \p Args on the host is handed to
/// the launch helpers as their fallback callback.
static void emitTargetCall(
    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
    OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
    Constant *OutlinedFnID, int32_t NumTeams, int32_t NumThreads,
    SmallVectorImpl<Value *> &Args,
    OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
    SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {}) {

  // Separate begin/end runtime calls; no device pointer info is requested.
  OpenMPIRBuilder::TargetDataInfo Info(
      /*RequiresDevicePointerInfo=*/false,
      /*SeparateBeginEndCalls=*/true);

  // Let the caller describe the mapped items, then emit the offloading arrays
  // (allocas at AllocaIP, stores at the current insertion point).
  OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
  OMPBuilder.emitOffloadingArrays(AllocaIP, Builder.saveIP(), MapInfo, Info,
                                  /*IsNonContiguous=*/true);

  // Decay the arrays into the pointer arguments handed to the runtime; map
  // names are only forwarded when the map info actually carries them.
  OpenMPIRBuilder::TargetDataRTArgs RTArgs;
  OMPBuilder.emitOffloadingArraysArgument(Builder, RTArgs, Info,
                                          !MapInfo.Names.empty());

  // Fallback callback for the launch helpers: a plain host call of the
  // outlined function with the original arguments.
  // NOTE(review): presumably used when offloading is unavailable or fails —
  // confirm against emitKernelLaunch.
  auto &&EmitTargetCallFallbackCB =
      [&](OpenMPIRBuilder::InsertPointTy IP) -> OpenMPIRBuilder::InsertPointTy {
    Builder.restoreIP(IP);
    Builder.CreateCall(OutlinedFn, Args);
    return Builder.saveIP();
  };

  unsigned NumTargetItems = MapInfo.BasePointers.size();
  // TODO: Use correct device ID
  Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
  Value *NumTeamsVal = Builder.getInt32(NumTeams);
  Value *NumThreadsVal = Builder.getInt32(NumThreads);
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
  Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
                                             llvm::omp::IdentFlag(0), 0);
  // TODO: Use correct NumIterations
  Value *NumIterations = Builder.getInt64(0);
  // TODO: Use correct DynCGGroupMem
  Value *DynCGGroupMem = Builder.getInt32(0);

  // NOTE(review): `nowait` is not plumbed through yet; HasNoWait is always
  // false here, so only dependencies currently force an outer target task.
  bool HasNoWait = false;
  bool HasDependencies = Dependencies.size() > 0;
  bool RequiresOuterTargetTask = HasNoWait || HasDependencies;

  OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, NumIterations,
                                          NumTeamsVal, NumThreadsVal,
                                          DynCGGroupMem, HasNoWait);

  // The presence of certain clauses on the target directive requires the
  // explicit generation of the target task.
  if (RequiresOuterTargetTask) {
    Builder.restoreIP(OMPBuilder.emitTargetTask(
        OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID,
        RTLoc, AllocaIP, Dependencies, HasNoWait));
  } else {
    Builder.restoreIP(OMPBuilder.emitKernelLaunch(
        Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs,
        DeviceID, RTLoc, AllocaIP));
  }
}
7118 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
7119     const LocationDescription &Loc, InsertPointTy AllocaIP,
7120     InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, int32_t NumTeams,
7121     int32_t NumThreads, SmallVectorImpl<Value *> &Args,
7122     GenMapInfoCallbackTy GenMapInfoCB,
7123     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
7124     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
7125     SmallVector<DependData> Dependencies) {
7126 
7127   if (!updateToLocation(Loc))
7128     return InsertPointTy();
7129 
7130   Builder.restoreIP(CodeGenIP);
7131 
7132   Function *OutlinedFn;
7133   Constant *OutlinedFnID;
7134   // The target region is outlined into its own function. The LLVM IR for
7135   // the target region itself is generated using the callbacks CBFunc
7136   // and ArgAccessorFuncCB
7137   emitTargetOutlinedFunction(*this, Builder, EntryInfo, OutlinedFn,
7138                              OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB);
7139 
7140   // If we are not on the target device, then we need to generate code
7141   // to make a remote call (offload) to the previously outlined function
7142   // that represents the target region. Do that now.
7143   if (!Config.isTargetDevice())
7144     emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
7145                    NumThreads, Args, GenMapInfoCB, Dependencies);
7146   return Builder.saveIP();
7147 }
7148 
7149 std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
7150                                                    StringRef FirstSeparator,
7151                                                    StringRef Separator) {
7152   SmallString<128> Buffer;
7153   llvm::raw_svector_ostream OS(Buffer);
7154   StringRef Sep = FirstSeparator;
7155   for (StringRef Part : Parts) {
7156     OS << Sep << Part;
7157     Sep = Separator;
7158   }
7159   return OS.str().str();
7160 }
7161 
7162 std::string
7163 OpenMPIRBuilder::createPlatformSpecificName(ArrayRef<StringRef> Parts) const {
7164   return OpenMPIRBuilder::getNameWithSeparators(Parts, Config.firstSeparator(),
7165                                                 Config.separator());
7166 }
7167 
7168 GlobalVariable *
7169 OpenMPIRBuilder::getOrCreateInternalVariable(Type *Ty, const StringRef &Name,
7170                                              unsigned AddressSpace) {
7171   auto &Elem = *InternalVars.try_emplace(Name, nullptr).first;
7172   if (Elem.second) {
7173     assert(Elem.second->getValueType() == Ty &&
7174            "OMP internal variable has different type than requested");
7175   } else {
7176     // TODO: investigate the appropriate linkage type used for the global
7177     // variable for possibly changing that to internal or private, or maybe
7178     // create different versions of the function for different OMP internal
7179     // variables.
7180     auto Linkage = this->M.getTargetTriple().rfind("wasm32") == 0
7181                        ? GlobalValue::ExternalLinkage
7182                        : GlobalValue::CommonLinkage;
7183     auto *GV = new GlobalVariable(M, Ty, /*IsConstant=*/false, Linkage,
7184                                   Constant::getNullValue(Ty), Elem.first(),
7185                                   /*InsertBefore=*/nullptr,
7186                                   GlobalValue::NotThreadLocal, AddressSpace);
7187     const DataLayout &DL = M.getDataLayout();
7188     const llvm::Align TypeAlign = DL.getABITypeAlign(Ty);
7189     const llvm::Align PtrAlign = DL.getPointerABIAlignment(AddressSpace);
7190     GV->setAlignment(std::max(TypeAlign, PtrAlign));
7191     Elem.second = GV;
7192   }
7193 
7194   return Elem.second;
7195 }
7196 
7197 Value *OpenMPIRBuilder::getOMPCriticalRegionLock(StringRef CriticalName) {
7198   std::string Prefix = Twine("gomp_critical_user_", CriticalName).str();
7199   std::string Name = getNameWithSeparators({Prefix, "var"}, ".", ".");
7200   return getOrCreateInternalVariable(KmpCriticalNameTy, Name);
7201 }
7202 
7203 Value *OpenMPIRBuilder::getSizeInBytes(Value *BasePtr) {
7204   LLVMContext &Ctx = Builder.getContext();
7205   Value *Null =
7206       Constant::getNullValue(PointerType::getUnqual(BasePtr->getContext()));
7207   Value *SizeGep =
7208       Builder.CreateGEP(BasePtr->getType(), Null, Builder.getInt32(1));
7209   Value *SizePtrToInt = Builder.CreatePtrToInt(SizeGep, Type::getInt64Ty(Ctx));
7210   return SizePtrToInt;
7211 }
7212 
7213 GlobalVariable *
7214 OpenMPIRBuilder::createOffloadMaptypes(SmallVectorImpl<uint64_t> &Mappings,
7215                                        std::string VarName) {
7216   llvm::Constant *MaptypesArrayInit =
7217       llvm::ConstantDataArray::get(M.getContext(), Mappings);
7218   auto *MaptypesArrayGlobal = new llvm::GlobalVariable(
7219       M, MaptypesArrayInit->getType(),
7220       /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MaptypesArrayInit,
7221       VarName);
7222   MaptypesArrayGlobal->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
7223   return MaptypesArrayGlobal;
7224 }
7225 
7226 void OpenMPIRBuilder::createMapperAllocas(const LocationDescription &Loc,
7227                                           InsertPointTy AllocaIP,
7228                                           unsigned NumOperands,
7229                                           struct MapperAllocas &MapperAllocas) {
7230   if (!updateToLocation(Loc))
7231     return;
7232 
7233   auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
7234   auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
7235   Builder.restoreIP(AllocaIP);
7236   AllocaInst *ArgsBase = Builder.CreateAlloca(
7237       ArrI8PtrTy, /* ArraySize = */ nullptr, ".offload_baseptrs");
7238   AllocaInst *Args = Builder.CreateAlloca(ArrI8PtrTy, /* ArraySize = */ nullptr,
7239                                           ".offload_ptrs");
7240   AllocaInst *ArgSizes = Builder.CreateAlloca(
7241       ArrI64Ty, /* ArraySize = */ nullptr, ".offload_sizes");
7242   Builder.restoreIP(Loc.IP);
7243   MapperAllocas.ArgsBase = ArgsBase;
7244   MapperAllocas.Args = Args;
7245   MapperAllocas.ArgSizes = ArgSizes;
7246 }
7247 
7248 void OpenMPIRBuilder::emitMapperCall(const LocationDescription &Loc,
7249                                      Function *MapperFunc, Value *SrcLocInfo,
7250                                      Value *MaptypesArg, Value *MapnamesArg,
7251                                      struct MapperAllocas &MapperAllocas,
7252                                      int64_t DeviceID, unsigned NumOperands) {
7253   if (!updateToLocation(Loc))
7254     return;
7255 
7256   auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
7257   auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
7258   Value *ArgsBaseGEP =
7259       Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.ArgsBase,
7260                                 {Builder.getInt32(0), Builder.getInt32(0)});
7261   Value *ArgsGEP =
7262       Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.Args,
7263                                 {Builder.getInt32(0), Builder.getInt32(0)});
7264   Value *ArgSizesGEP =
7265       Builder.CreateInBoundsGEP(ArrI64Ty, MapperAllocas.ArgSizes,
7266                                 {Builder.getInt32(0), Builder.getInt32(0)});
7267   Value *NullPtr =
7268       Constant::getNullValue(PointerType::getUnqual(Int8Ptr->getContext()));
7269   Builder.CreateCall(MapperFunc,
7270                      {SrcLocInfo, Builder.getInt64(DeviceID),
7271                       Builder.getInt32(NumOperands), ArgsBaseGEP, ArgsGEP,
7272                       ArgSizesGEP, MaptypesArg, MapnamesArg, NullPtr});
7273 }
7274 
7275 void OpenMPIRBuilder::emitOffloadingArraysArgument(IRBuilderBase &Builder,
7276                                                    TargetDataRTArgs &RTArgs,
7277                                                    TargetDataInfo &Info,
7278                                                    bool EmitDebug,
7279                                                    bool ForEndCall) {
7280   assert((!ForEndCall || Info.separateBeginEndCalls()) &&
7281          "expected region end call to runtime only when end call is separate");
7282   auto UnqualPtrTy = PointerType::getUnqual(M.getContext());
7283   auto VoidPtrTy = UnqualPtrTy;
7284   auto VoidPtrPtrTy = UnqualPtrTy;
7285   auto Int64Ty = Type::getInt64Ty(M.getContext());
7286   auto Int64PtrTy = UnqualPtrTy;
7287 
7288   if (!Info.NumberOfPtrs) {
7289     RTArgs.BasePointersArray = ConstantPointerNull::get(VoidPtrPtrTy);
7290     RTArgs.PointersArray = ConstantPointerNull::get(VoidPtrPtrTy);
7291     RTArgs.SizesArray = ConstantPointerNull::get(Int64PtrTy);
7292     RTArgs.MapTypesArray = ConstantPointerNull::get(Int64PtrTy);
7293     RTArgs.MapNamesArray = ConstantPointerNull::get(VoidPtrPtrTy);
7294     RTArgs.MappersArray = ConstantPointerNull::get(VoidPtrPtrTy);
7295     return;
7296   }
7297 
7298   RTArgs.BasePointersArray = Builder.CreateConstInBoundsGEP2_32(
7299       ArrayType::get(VoidPtrTy, Info.NumberOfPtrs),
7300       Info.RTArgs.BasePointersArray,
7301       /*Idx0=*/0, /*Idx1=*/0);
7302   RTArgs.PointersArray = Builder.CreateConstInBoundsGEP2_32(
7303       ArrayType::get(VoidPtrTy, Info.NumberOfPtrs), Info.RTArgs.PointersArray,
7304       /*Idx0=*/0,
7305       /*Idx1=*/0);
7306   RTArgs.SizesArray = Builder.CreateConstInBoundsGEP2_32(
7307       ArrayType::get(Int64Ty, Info.NumberOfPtrs), Info.RTArgs.SizesArray,
7308       /*Idx0=*/0, /*Idx1=*/0);
7309   RTArgs.MapTypesArray = Builder.CreateConstInBoundsGEP2_32(
7310       ArrayType::get(Int64Ty, Info.NumberOfPtrs),
7311       ForEndCall && Info.RTArgs.MapTypesArrayEnd ? Info.RTArgs.MapTypesArrayEnd
7312                                                  : Info.RTArgs.MapTypesArray,
7313       /*Idx0=*/0,
7314       /*Idx1=*/0);
7315 
7316   // Only emit the mapper information arrays if debug information is
7317   // requested.
7318   if (!EmitDebug)
7319     RTArgs.MapNamesArray = ConstantPointerNull::get(VoidPtrPtrTy);
7320   else
7321     RTArgs.MapNamesArray = Builder.CreateConstInBoundsGEP2_32(
7322         ArrayType::get(VoidPtrTy, Info.NumberOfPtrs), Info.RTArgs.MapNamesArray,
7323         /*Idx0=*/0,
7324         /*Idx1=*/0);
7325   // If there is no user-defined mapper, set the mapper array to nullptr to
7326   // avoid an unnecessary data privatization
7327   if (!Info.HasMapper)
7328     RTArgs.MappersArray = ConstantPointerNull::get(VoidPtrPtrTy);
7329   else
7330     RTArgs.MappersArray =
7331         Builder.CreatePointerCast(Info.RTArgs.MappersArray, VoidPtrPtrTy);
7332 }
7333 
7334 void OpenMPIRBuilder::emitNonContiguousDescriptor(InsertPointTy AllocaIP,
7335                                                   InsertPointTy CodeGenIP,
7336                                                   MapInfosTy &CombinedInfo,
7337                                                   TargetDataInfo &Info) {
7338   MapInfosTy::StructNonContiguousInfo &NonContigInfo =
7339       CombinedInfo.NonContigInfo;
7340 
7341   // Build an array of struct descriptor_dim and then assign it to
7342   // offload_args.
7343   //
7344   // struct descriptor_dim {
7345   //  uint64_t offset;
7346   //  uint64_t count;
7347   //  uint64_t stride
7348   // };
7349   Type *Int64Ty = Builder.getInt64Ty();
7350   StructType *DimTy = StructType::create(
7351       M.getContext(), ArrayRef<Type *>({Int64Ty, Int64Ty, Int64Ty}),
7352       "struct.descriptor_dim");
7353 
7354   enum { OffsetFD = 0, CountFD, StrideFD };
7355   // We need two index variable here since the size of "Dims" is the same as
7356   // the size of Components, however, the size of offset, count, and stride is
7357   // equal to the size of base declaration that is non-contiguous.
7358   for (unsigned I = 0, L = 0, E = NonContigInfo.Dims.size(); I < E; ++I) {
7359     // Skip emitting ir if dimension size is 1 since it cannot be
7360     // non-contiguous.
7361     if (NonContigInfo.Dims[I] == 1)
7362       continue;
7363     Builder.restoreIP(AllocaIP);
7364     ArrayType *ArrayTy = ArrayType::get(DimTy, NonContigInfo.Dims[I]);
7365     AllocaInst *DimsAddr =
7366         Builder.CreateAlloca(ArrayTy, /* ArraySize = */ nullptr, "dims");
7367     Builder.restoreIP(CodeGenIP);
7368     for (unsigned II = 0, EE = NonContigInfo.Dims[I]; II < EE; ++II) {
7369       unsigned RevIdx = EE - II - 1;
7370       Value *DimsLVal = Builder.CreateInBoundsGEP(
7371           DimsAddr->getAllocatedType(), DimsAddr,
7372           {Builder.getInt64(0), Builder.getInt64(II)});
7373       // Offset
7374       Value *OffsetLVal = Builder.CreateStructGEP(DimTy, DimsLVal, OffsetFD);
7375       Builder.CreateAlignedStore(
7376           NonContigInfo.Offsets[L][RevIdx], OffsetLVal,
7377           M.getDataLayout().getPrefTypeAlign(OffsetLVal->getType()));
7378       // Count
7379       Value *CountLVal = Builder.CreateStructGEP(DimTy, DimsLVal, CountFD);
7380       Builder.CreateAlignedStore(
7381           NonContigInfo.Counts[L][RevIdx], CountLVal,
7382           M.getDataLayout().getPrefTypeAlign(CountLVal->getType()));
7383       // Stride
7384       Value *StrideLVal = Builder.CreateStructGEP(DimTy, DimsLVal, StrideFD);
7385       Builder.CreateAlignedStore(
7386           NonContigInfo.Strides[L][RevIdx], StrideLVal,
7387           M.getDataLayout().getPrefTypeAlign(CountLVal->getType()));
7388     }
7389     // args[I] = &dims
7390     Builder.restoreIP(CodeGenIP);
7391     Value *DAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
7392         DimsAddr, Builder.getPtrTy());
7393     Value *P = Builder.CreateConstInBoundsGEP2_32(
7394         ArrayType::get(Builder.getPtrTy(), Info.NumberOfPtrs),
7395         Info.RTArgs.PointersArray, 0, I);
7396     Builder.CreateAlignedStore(
7397         DAddr, P, M.getDataLayout().getPrefTypeAlign(Builder.getPtrTy()));
7398     ++L;
7399   }
7400 }
7401 
7402 void OpenMPIRBuilder::emitOffloadingArrays(
7403     InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
7404     TargetDataInfo &Info, bool IsNonContiguous,
7405     function_ref<void(unsigned int, Value *)> DeviceAddrCB,
7406     function_ref<Value *(unsigned int)> CustomMapperCB) {
7407 
7408   // Reset the array information.
7409   Info.clearArrayInfo();
7410   Info.NumberOfPtrs = CombinedInfo.BasePointers.size();
7411 
7412   if (Info.NumberOfPtrs == 0)
7413     return;
7414 
7415   Builder.restoreIP(AllocaIP);
7416   // Detect if we have any capture size requiring runtime evaluation of the
7417   // size so that a constant array could be eventually used.
7418   ArrayType *PointerArrayType =
7419       ArrayType::get(Builder.getPtrTy(), Info.NumberOfPtrs);
7420 
7421   Info.RTArgs.BasePointersArray = Builder.CreateAlloca(
7422       PointerArrayType, /* ArraySize = */ nullptr, ".offload_baseptrs");
7423 
7424   Info.RTArgs.PointersArray = Builder.CreateAlloca(
7425       PointerArrayType, /* ArraySize = */ nullptr, ".offload_ptrs");
7426   AllocaInst *MappersArray = Builder.CreateAlloca(
7427       PointerArrayType, /* ArraySize = */ nullptr, ".offload_mappers");
7428   Info.RTArgs.MappersArray = MappersArray;
7429 
7430   // If we don't have any VLA types or other types that require runtime
7431   // evaluation, we can use a constant array for the map sizes, otherwise we
7432   // need to fill up the arrays as we do for the pointers.
7433   Type *Int64Ty = Builder.getInt64Ty();
7434   SmallVector<Constant *> ConstSizes(CombinedInfo.Sizes.size(),
7435                                      ConstantInt::get(Int64Ty, 0));
7436   SmallBitVector RuntimeSizes(CombinedInfo.Sizes.size());
7437   for (unsigned I = 0, E = CombinedInfo.Sizes.size(); I < E; ++I) {
7438     if (auto *CI = dyn_cast<Constant>(CombinedInfo.Sizes[I])) {
7439       if (!isa<ConstantExpr>(CI) && !isa<GlobalValue>(CI)) {
7440         if (IsNonContiguous &&
7441             static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
7442                 CombinedInfo.Types[I] &
7443                 OpenMPOffloadMappingFlags::OMP_MAP_NON_CONTIG))
7444           ConstSizes[I] =
7445               ConstantInt::get(Int64Ty, CombinedInfo.NonContigInfo.Dims[I]);
7446         else
7447           ConstSizes[I] = CI;
7448         continue;
7449       }
7450     }
7451     RuntimeSizes.set(I);
7452   }
7453 
7454   if (RuntimeSizes.all()) {
7455     ArrayType *SizeArrayType = ArrayType::get(Int64Ty, Info.NumberOfPtrs);
7456     Info.RTArgs.SizesArray = Builder.CreateAlloca(
7457         SizeArrayType, /* ArraySize = */ nullptr, ".offload_sizes");
7458     Builder.restoreIP(CodeGenIP);
7459   } else {
7460     auto *SizesArrayInit = ConstantArray::get(
7461         ArrayType::get(Int64Ty, ConstSizes.size()), ConstSizes);
7462     std::string Name = createPlatformSpecificName({"offload_sizes"});
7463     auto *SizesArrayGbl =
7464         new GlobalVariable(M, SizesArrayInit->getType(), /*isConstant=*/true,
7465                            GlobalValue::PrivateLinkage, SizesArrayInit, Name);
7466     SizesArrayGbl->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
7467 
7468     if (!RuntimeSizes.any()) {
7469       Info.RTArgs.SizesArray = SizesArrayGbl;
7470     } else {
7471       unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(0);
7472       Align OffloadSizeAlign = M.getDataLayout().getABIIntegerTypeAlignment(64);
7473       ArrayType *SizeArrayType = ArrayType::get(Int64Ty, Info.NumberOfPtrs);
7474       AllocaInst *Buffer = Builder.CreateAlloca(
7475           SizeArrayType, /* ArraySize = */ nullptr, ".offload_sizes");
7476       Buffer->setAlignment(OffloadSizeAlign);
7477       Builder.restoreIP(CodeGenIP);
7478       Builder.CreateMemCpy(
7479           Buffer, M.getDataLayout().getPrefTypeAlign(Buffer->getType()),
7480           SizesArrayGbl, OffloadSizeAlign,
7481           Builder.getIntN(
7482               IndexSize,
7483               Buffer->getAllocationSize(M.getDataLayout())->getFixedValue()));
7484 
7485       Info.RTArgs.SizesArray = Buffer;
7486     }
7487     Builder.restoreIP(CodeGenIP);
7488   }
7489 
7490   // The map types are always constant so we don't need to generate code to
7491   // fill arrays. Instead, we create an array constant.
7492   SmallVector<uint64_t, 4> Mapping;
7493   for (auto mapFlag : CombinedInfo.Types)
7494     Mapping.push_back(
7495         static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
7496             mapFlag));
7497   std::string MaptypesName = createPlatformSpecificName({"offload_maptypes"});
7498   auto *MapTypesArrayGbl = createOffloadMaptypes(Mapping, MaptypesName);
7499   Info.RTArgs.MapTypesArray = MapTypesArrayGbl;
7500 
7501   // The information types are only built if provided.
7502   if (!CombinedInfo.Names.empty()) {
7503     std::string MapnamesName = createPlatformSpecificName({"offload_mapnames"});
7504     auto *MapNamesArrayGbl =
7505         createOffloadMapnames(CombinedInfo.Names, MapnamesName);
7506     Info.RTArgs.MapNamesArray = MapNamesArrayGbl;
7507   } else {
7508     Info.RTArgs.MapNamesArray =
7509         Constant::getNullValue(PointerType::getUnqual(Builder.getContext()));
7510   }
7511 
7512   // If there's a present map type modifier, it must not be applied to the end
7513   // of a region, so generate a separate map type array in that case.
7514   if (Info.separateBeginEndCalls()) {
7515     bool EndMapTypesDiffer = false;
7516     for (uint64_t &Type : Mapping) {
7517       if (Type & static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
7518                      OpenMPOffloadMappingFlags::OMP_MAP_PRESENT)) {
7519         Type &= ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
7520             OpenMPOffloadMappingFlags::OMP_MAP_PRESENT);
7521         EndMapTypesDiffer = true;
7522       }
7523     }
7524     if (EndMapTypesDiffer) {
7525       MapTypesArrayGbl = createOffloadMaptypes(Mapping, MaptypesName);
7526       Info.RTArgs.MapTypesArrayEnd = MapTypesArrayGbl;
7527     }
7528   }
7529 
7530   PointerType *PtrTy = Builder.getPtrTy();
7531   for (unsigned I = 0; I < Info.NumberOfPtrs; ++I) {
7532     Value *BPVal = CombinedInfo.BasePointers[I];
7533     Value *BP = Builder.CreateConstInBoundsGEP2_32(
7534         ArrayType::get(PtrTy, Info.NumberOfPtrs), Info.RTArgs.BasePointersArray,
7535         0, I);
7536     Builder.CreateAlignedStore(BPVal, BP,
7537                                M.getDataLayout().getPrefTypeAlign(PtrTy));
7538 
7539     if (Info.requiresDevicePointerInfo()) {
7540       if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Pointer) {
7541         CodeGenIP = Builder.saveIP();
7542         Builder.restoreIP(AllocaIP);
7543         Info.DevicePtrInfoMap[BPVal] = {BP, Builder.CreateAlloca(PtrTy)};
7544         Builder.restoreIP(CodeGenIP);
7545         if (DeviceAddrCB)
7546           DeviceAddrCB(I, Info.DevicePtrInfoMap[BPVal].second);
7547       } else if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Address) {
7548         Info.DevicePtrInfoMap[BPVal] = {BP, BP};
7549         if (DeviceAddrCB)
7550           DeviceAddrCB(I, BP);
7551       }
7552     }
7553 
7554     Value *PVal = CombinedInfo.Pointers[I];
7555     Value *P = Builder.CreateConstInBoundsGEP2_32(
7556         ArrayType::get(PtrTy, Info.NumberOfPtrs), Info.RTArgs.PointersArray, 0,
7557         I);
7558     // TODO: Check alignment correct.
7559     Builder.CreateAlignedStore(PVal, P,
7560                                M.getDataLayout().getPrefTypeAlign(PtrTy));
7561 
7562     if (RuntimeSizes.test(I)) {
7563       Value *S = Builder.CreateConstInBoundsGEP2_32(
7564           ArrayType::get(Int64Ty, Info.NumberOfPtrs), Info.RTArgs.SizesArray,
7565           /*Idx0=*/0,
7566           /*Idx1=*/I);
7567       Builder.CreateAlignedStore(Builder.CreateIntCast(CombinedInfo.Sizes[I],
7568                                                        Int64Ty,
7569                                                        /*isSigned=*/true),
7570                                  S, M.getDataLayout().getPrefTypeAlign(PtrTy));
7571     }
7572     // Fill up the mapper array.
7573     unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(0);
7574     Value *MFunc = ConstantPointerNull::get(PtrTy);
7575     if (CustomMapperCB)
7576       if (Value *CustomMFunc = CustomMapperCB(I))
7577         MFunc = Builder.CreatePointerCast(CustomMFunc, PtrTy);
7578     Value *MAddr = Builder.CreateInBoundsGEP(
7579         MappersArray->getAllocatedType(), MappersArray,
7580         {Builder.getIntN(IndexSize, 0), Builder.getIntN(IndexSize, I)});
7581     Builder.CreateAlignedStore(
7582         MFunc, MAddr, M.getDataLayout().getPrefTypeAlign(MAddr->getType()));
7583   }
7584 
7585   if (!IsNonContiguous || CombinedInfo.NonContigInfo.Offsets.empty() ||
7586       Info.NumberOfPtrs == 0)
7587     return;
7588   emitNonContiguousDescriptor(AllocaIP, CodeGenIP, CombinedInfo, Info);
7589 }
7590 
7591 void OpenMPIRBuilder::emitBranch(BasicBlock *Target) {
7592   BasicBlock *CurBB = Builder.GetInsertBlock();
7593 
7594   if (!CurBB || CurBB->getTerminator()) {
7595     // If there is no insert point or the previous block is already
7596     // terminated, don't touch it.
7597   } else {
7598     // Otherwise, create a fall-through branch.
7599     Builder.CreateBr(Target);
7600   }
7601 
7602   Builder.ClearInsertionPoint();
7603 }
7604 
7605 void OpenMPIRBuilder::emitBlock(BasicBlock *BB, Function *CurFn,
7606                                 bool IsFinished) {
7607   BasicBlock *CurBB = Builder.GetInsertBlock();
7608 
7609   // Fall out of the current block (if necessary).
7610   emitBranch(BB);
7611 
7612   if (IsFinished && BB->use_empty()) {
7613     BB->eraseFromParent();
7614     return;
7615   }
7616 
7617   // Place the block after the current block, if possible, or else at
7618   // the end of the function.
7619   if (CurBB && CurBB->getParent())
7620     CurFn->insert(std::next(CurBB->getIterator()), BB);
7621   else
7622     CurFn->insert(CurFn->end(), BB);
7623   Builder.SetInsertPoint(BB);
7624 }
7625 
7626 void OpenMPIRBuilder::emitIfClause(Value *Cond, BodyGenCallbackTy ThenGen,
7627                                    BodyGenCallbackTy ElseGen,
7628                                    InsertPointTy AllocaIP) {
7629   // If the condition constant folds and can be elided, try to avoid emitting
7630   // the condition and the dead arm of the if/else.
7631   if (auto *CI = dyn_cast<ConstantInt>(Cond)) {
7632     auto CondConstant = CI->getSExtValue();
7633     if (CondConstant)
7634       ThenGen(AllocaIP, Builder.saveIP());
7635     else
7636       ElseGen(AllocaIP, Builder.saveIP());
7637     return;
7638   }
7639 
7640   Function *CurFn = Builder.GetInsertBlock()->getParent();
7641 
7642   // Otherwise, the condition did not fold, or we couldn't elide it.  Just
7643   // emit the conditional branch.
7644   BasicBlock *ThenBlock = BasicBlock::Create(M.getContext(), "omp_if.then");
7645   BasicBlock *ElseBlock = BasicBlock::Create(M.getContext(), "omp_if.else");
7646   BasicBlock *ContBlock = BasicBlock::Create(M.getContext(), "omp_if.end");
7647   Builder.CreateCondBr(Cond, ThenBlock, ElseBlock);
7648   // Emit the 'then' code.
7649   emitBlock(ThenBlock, CurFn);
7650   ThenGen(AllocaIP, Builder.saveIP());
7651   emitBranch(ContBlock);
7652   // Emit the 'else' code if present.
7653   // There is no need to emit line number for unconditional branch.
7654   emitBlock(ElseBlock, CurFn);
7655   ElseGen(AllocaIP, Builder.saveIP());
7656   // There is no need to emit line number for unconditional branch.
7657   emitBranch(ContBlock);
7658   // Emit the continuation block for code after the if.
7659   emitBlock(ContBlock, CurFn, /*IsFinished=*/true);
7660 }
7661 
7662 bool OpenMPIRBuilder::checkAndEmitFlushAfterAtomic(
7663     const LocationDescription &Loc, llvm::AtomicOrdering AO, AtomicKind AK) {
7664   assert(!(AO == AtomicOrdering::NotAtomic ||
7665            AO == llvm::AtomicOrdering::Unordered) &&
7666          "Unexpected Atomic Ordering.");
7667 
7668   bool Flush = false;
7669   llvm::AtomicOrdering FlushAO = AtomicOrdering::Monotonic;
7670 
7671   switch (AK) {
7672   case Read:
7673     if (AO == AtomicOrdering::Acquire || AO == AtomicOrdering::AcquireRelease ||
7674         AO == AtomicOrdering::SequentiallyConsistent) {
7675       FlushAO = AtomicOrdering::Acquire;
7676       Flush = true;
7677     }
7678     break;
7679   case Write:
7680   case Compare:
7681   case Update:
7682     if (AO == AtomicOrdering::Release || AO == AtomicOrdering::AcquireRelease ||
7683         AO == AtomicOrdering::SequentiallyConsistent) {
7684       FlushAO = AtomicOrdering::Release;
7685       Flush = true;
7686     }
7687     break;
7688   case Capture:
7689     switch (AO) {
7690     case AtomicOrdering::Acquire:
7691       FlushAO = AtomicOrdering::Acquire;
7692       Flush = true;
7693       break;
7694     case AtomicOrdering::Release:
7695       FlushAO = AtomicOrdering::Release;
7696       Flush = true;
7697       break;
7698     case AtomicOrdering::AcquireRelease:
7699     case AtomicOrdering::SequentiallyConsistent:
7700       FlushAO = AtomicOrdering::AcquireRelease;
7701       Flush = true;
7702       break;
7703     default:
7704       // do nothing - leave silently.
7705       break;
7706     }
7707   }
7708 
7709   if (Flush) {
7710     // Currently Flush RT call still doesn't take memory_ordering, so for when
7711     // that happens, this tries to do the resolution of which atomic ordering
7712     // to use with but issue the flush call
7713     // TODO: pass `FlushAO` after memory ordering support is added
7714     (void)FlushAO;
7715     emitFlush(Loc);
7716   }
7717 
7718   // for AO == AtomicOrdering::Monotonic and  all other case combinations
7719   // do nothing
7720   return Flush;
7721 }
7722 
7723 OpenMPIRBuilder::InsertPointTy
7724 OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
7725                                   AtomicOpValue &X, AtomicOpValue &V,
7726                                   AtomicOrdering AO) {
7727   if (!updateToLocation(Loc))
7728     return Loc.IP;
7729 
7730   assert(X.Var->getType()->isPointerTy() &&
7731          "OMP Atomic expects a pointer to target memory");
7732   Type *XElemTy = X.ElemTy;
7733   assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
7734           XElemTy->isPointerTy()) &&
7735          "OMP atomic read expected a scalar type");
7736 
7737   Value *XRead = nullptr;
7738 
7739   if (XElemTy->isIntegerTy()) {
7740     LoadInst *XLD =
7741         Builder.CreateLoad(XElemTy, X.Var, X.IsVolatile, "omp.atomic.read");
7742     XLD->setAtomic(AO);
7743     XRead = cast<Value>(XLD);
7744   } else {
7745     // We need to perform atomic op as integer
7746     IntegerType *IntCastTy =
7747         IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
7748     LoadInst *XLoad =
7749         Builder.CreateLoad(IntCastTy, X.Var, X.IsVolatile, "omp.atomic.load");
7750     XLoad->setAtomic(AO);
7751     if (XElemTy->isFloatingPointTy()) {
7752       XRead = Builder.CreateBitCast(XLoad, XElemTy, "atomic.flt.cast");
7753     } else {
7754       XRead = Builder.CreateIntToPtr(XLoad, XElemTy, "atomic.ptr.cast");
7755     }
7756   }
7757   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Read);
7758   Builder.CreateStore(XRead, V.Var, V.IsVolatile);
7759   return Builder.saveIP();
7760 }
7761 
7762 OpenMPIRBuilder::InsertPointTy
7763 OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
7764                                    AtomicOpValue &X, Value *Expr,
7765                                    AtomicOrdering AO) {
7766   if (!updateToLocation(Loc))
7767     return Loc.IP;
7768 
7769   assert(X.Var->getType()->isPointerTy() &&
7770          "OMP Atomic expects a pointer to target memory");
7771   Type *XElemTy = X.ElemTy;
7772   assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
7773           XElemTy->isPointerTy()) &&
7774          "OMP atomic write expected a scalar type");
7775 
7776   if (XElemTy->isIntegerTy()) {
7777     StoreInst *XSt = Builder.CreateStore(Expr, X.Var, X.IsVolatile);
7778     XSt->setAtomic(AO);
7779   } else {
7780     // We need to bitcast and perform atomic op as integers
7781     IntegerType *IntCastTy =
7782         IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
7783     Value *ExprCast =
7784         Builder.CreateBitCast(Expr, IntCastTy, "atomic.src.int.cast");
7785     StoreInst *XSt = Builder.CreateStore(ExprCast, X.Var, X.IsVolatile);
7786     XSt->setAtomic(AO);
7787   }
7788 
7789   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Write);
7790   return Builder.saveIP();
7791 }
7792 
7793 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicUpdate(
7794     const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
7795     Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
7796     AtomicUpdateCallbackTy &UpdateOp, bool IsXBinopExpr) {
7797   assert(!isConflictIP(Loc.IP, AllocaIP) && "IPs must not be ambiguous");
7798   if (!updateToLocation(Loc))
7799     return Loc.IP;
7800 
7801   LLVM_DEBUG({
7802     Type *XTy = X.Var->getType();
7803     assert(XTy->isPointerTy() &&
7804            "OMP Atomic expects a pointer to target memory");
7805     Type *XElemTy = X.ElemTy;
7806     assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
7807             XElemTy->isPointerTy()) &&
7808            "OMP atomic update expected a scalar type");
7809     assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
7810            (RMWOp != AtomicRMWInst::UMax) && (RMWOp != AtomicRMWInst::UMin) &&
7811            "OpenMP atomic does not support LT or GT operations");
7812   });
7813 
7814   emitAtomicUpdate(AllocaIP, X.Var, X.ElemTy, Expr, AO, RMWOp, UpdateOp,
7815                    X.IsVolatile, IsXBinopExpr);
7816   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Update);
7817   return Builder.saveIP();
7818 }
7819 
7820 // FIXME: Duplicating AtomicExpand
7821 Value *OpenMPIRBuilder::emitRMWOpAsInstruction(Value *Src1, Value *Src2,
7822                                                AtomicRMWInst::BinOp RMWOp) {
7823   switch (RMWOp) {
7824   case AtomicRMWInst::Add:
7825     return Builder.CreateAdd(Src1, Src2);
7826   case AtomicRMWInst::Sub:
7827     return Builder.CreateSub(Src1, Src2);
7828   case AtomicRMWInst::And:
7829     return Builder.CreateAnd(Src1, Src2);
7830   case AtomicRMWInst::Nand:
7831     return Builder.CreateNeg(Builder.CreateAnd(Src1, Src2));
7832   case AtomicRMWInst::Or:
7833     return Builder.CreateOr(Src1, Src2);
7834   case AtomicRMWInst::Xor:
7835     return Builder.CreateXor(Src1, Src2);
7836   case AtomicRMWInst::Xchg:
7837   case AtomicRMWInst::FAdd:
7838   case AtomicRMWInst::FSub:
7839   case AtomicRMWInst::BAD_BINOP:
7840   case AtomicRMWInst::Max:
7841   case AtomicRMWInst::Min:
7842   case AtomicRMWInst::UMax:
7843   case AtomicRMWInst::UMin:
7844   case AtomicRMWInst::FMax:
7845   case AtomicRMWInst::FMin:
7846   case AtomicRMWInst::UIncWrap:
7847   case AtomicRMWInst::UDecWrap:
7848     llvm_unreachable("Unsupported atomic update operation");
7849   }
7850   llvm_unreachable("Unsupported atomic update operation");
7851 }
7852 
7853 std::pair<Value *, Value *> OpenMPIRBuilder::emitAtomicUpdate(
7854     InsertPointTy AllocaIP, Value *X, Type *XElemTy, Value *Expr,
7855     AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
7856     AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, bool IsXBinopExpr) {
7857   // TODO: handle the case where XElemTy is not byte-sized or not a power of 2
7858   // or a complex datatype.
7859   bool emitRMWOp = false;
7860   switch (RMWOp) {
7861   case AtomicRMWInst::Add:
7862   case AtomicRMWInst::And:
7863   case AtomicRMWInst::Nand:
7864   case AtomicRMWInst::Or:
7865   case AtomicRMWInst::Xor:
7866   case AtomicRMWInst::Xchg:
7867     emitRMWOp = XElemTy;
7868     break;
7869   case AtomicRMWInst::Sub:
7870     emitRMWOp = (IsXBinopExpr && XElemTy);
7871     break;
7872   default:
7873     emitRMWOp = false;
7874   }
7875   emitRMWOp &= XElemTy->isIntegerTy();
7876 
7877   std::pair<Value *, Value *> Res;
7878   if (emitRMWOp) {
7879     Res.first = Builder.CreateAtomicRMW(RMWOp, X, Expr, llvm::MaybeAlign(), AO);
7880     // not needed except in case of postfix captures. Generate anyway for
7881     // consistency with the else part. Will be removed with any DCE pass.
7882     // AtomicRMWInst::Xchg does not have a coressponding instruction.
7883     if (RMWOp == AtomicRMWInst::Xchg)
7884       Res.second = Res.first;
7885     else
7886       Res.second = emitRMWOpAsInstruction(Res.first, Expr, RMWOp);
7887   } else {
7888     IntegerType *IntCastTy =
7889         IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
7890     LoadInst *OldVal =
7891         Builder.CreateLoad(IntCastTy, X, X->getName() + ".atomic.load");
7892     OldVal->setAtomic(AO);
7893     // CurBB
7894     // |     /---\
7895 		// ContBB    |
7896     // |     \---/
7897     // ExitBB
7898     BasicBlock *CurBB = Builder.GetInsertBlock();
7899     Instruction *CurBBTI = CurBB->getTerminator();
7900     CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
7901     BasicBlock *ExitBB =
7902         CurBB->splitBasicBlock(CurBBTI, X->getName() + ".atomic.exit");
7903     BasicBlock *ContBB = CurBB->splitBasicBlock(CurBB->getTerminator(),
7904                                                 X->getName() + ".atomic.cont");
7905     ContBB->getTerminator()->eraseFromParent();
7906     Builder.restoreIP(AllocaIP);
7907     AllocaInst *NewAtomicAddr = Builder.CreateAlloca(XElemTy);
7908     NewAtomicAddr->setName(X->getName() + "x.new.val");
7909     Builder.SetInsertPoint(ContBB);
7910     llvm::PHINode *PHI = Builder.CreatePHI(OldVal->getType(), 2);
7911     PHI->addIncoming(OldVal, CurBB);
7912     bool IsIntTy = XElemTy->isIntegerTy();
7913     Value *OldExprVal = PHI;
7914     if (!IsIntTy) {
7915       if (XElemTy->isFloatingPointTy()) {
7916         OldExprVal = Builder.CreateBitCast(PHI, XElemTy,
7917                                            X->getName() + ".atomic.fltCast");
7918       } else {
7919         OldExprVal = Builder.CreateIntToPtr(PHI, XElemTy,
7920                                             X->getName() + ".atomic.ptrCast");
7921       }
7922     }
7923 
7924     Value *Upd = UpdateOp(OldExprVal, Builder);
7925     Builder.CreateStore(Upd, NewAtomicAddr);
7926     LoadInst *DesiredVal = Builder.CreateLoad(IntCastTy, NewAtomicAddr);
7927     AtomicOrdering Failure =
7928         llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
7929     AtomicCmpXchgInst *Result = Builder.CreateAtomicCmpXchg(
7930         X, PHI, DesiredVal, llvm::MaybeAlign(), AO, Failure);
7931     Result->setVolatile(VolatileX);
7932     Value *PreviousVal = Builder.CreateExtractValue(Result, /*Idxs=*/0);
7933     Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
7934     PHI->addIncoming(PreviousVal, Builder.GetInsertBlock());
7935     Builder.CreateCondBr(SuccessFailureVal, ExitBB, ContBB);
7936 
7937     Res.first = OldExprVal;
7938     Res.second = Upd;
7939 
7940     // set Insertion point in exit block
7941     if (UnreachableInst *ExitTI =
7942             dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
7943       CurBBTI->eraseFromParent();
7944       Builder.SetInsertPoint(ExitBB);
7945     } else {
7946       Builder.SetInsertPoint(ExitTI);
7947     }
7948   }
7949 
7950   return Res;
7951 }
7952 
7953 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCapture(
7954     const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
7955     AtomicOpValue &V, Value *Expr, AtomicOrdering AO,
7956     AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp,
7957     bool UpdateExpr, bool IsPostfixUpdate, bool IsXBinopExpr) {
7958   if (!updateToLocation(Loc))
7959     return Loc.IP;
7960 
7961   LLVM_DEBUG({
7962     Type *XTy = X.Var->getType();
7963     assert(XTy->isPointerTy() &&
7964            "OMP Atomic expects a pointer to target memory");
7965     Type *XElemTy = X.ElemTy;
7966     assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
7967             XElemTy->isPointerTy()) &&
7968            "OMP atomic capture expected a scalar type");
7969     assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
7970            "OpenMP atomic does not support LT or GT operations");
7971   });
7972 
7973   // If UpdateExpr is 'x' updated with some `expr` not based on 'x',
7974   // 'x' is simply atomically rewritten with 'expr'.
7975   AtomicRMWInst::BinOp AtomicOp = (UpdateExpr ? RMWOp : AtomicRMWInst::Xchg);
7976   std::pair<Value *, Value *> Result =
7977       emitAtomicUpdate(AllocaIP, X.Var, X.ElemTy, Expr, AO, AtomicOp, UpdateOp,
7978                        X.IsVolatile, IsXBinopExpr);
7979 
7980   Value *CapturedVal = (IsPostfixUpdate ? Result.first : Result.second);
7981   Builder.CreateStore(CapturedVal, V.Var, V.IsVolatile);
7982 
7983   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Capture);
7984   return Builder.saveIP();
7985 }
7986 
7987 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
7988     const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
7989     AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
7990     omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
7991     bool IsFailOnly) {
7992 
7993   AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
7994   return createAtomicCompare(Loc, X, V, R, E, D, AO, Op, IsXBinopExpr,
7995                              IsPostfixUpdate, IsFailOnly, Failure);
7996 }
7997 
7998 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
7999     const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
8000     AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
8001     omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
8002     bool IsFailOnly, AtomicOrdering Failure) {
8003 
8004   if (!updateToLocation(Loc))
8005     return Loc.IP;
8006 
8007   assert(X.Var->getType()->isPointerTy() &&
8008          "OMP atomic expects a pointer to target memory");
8009   // compare capture
8010   if (V.Var) {
8011     assert(V.Var->getType()->isPointerTy() && "v.var must be of pointer type");
8012     assert(V.ElemTy == X.ElemTy && "x and v must be of same type");
8013   }
8014 
8015   bool IsInteger = E->getType()->isIntegerTy();
8016 
8017   if (Op == OMPAtomicCompareOp::EQ) {
8018     AtomicCmpXchgInst *Result = nullptr;
8019     if (!IsInteger) {
8020       IntegerType *IntCastTy =
8021           IntegerType::get(M.getContext(), X.ElemTy->getScalarSizeInBits());
8022       Value *EBCast = Builder.CreateBitCast(E, IntCastTy);
8023       Value *DBCast = Builder.CreateBitCast(D, IntCastTy);
8024       Result = Builder.CreateAtomicCmpXchg(X.Var, EBCast, DBCast, MaybeAlign(),
8025                                            AO, Failure);
8026     } else {
8027       Result =
8028           Builder.CreateAtomicCmpXchg(X.Var, E, D, MaybeAlign(), AO, Failure);
8029     }
8030 
8031     if (V.Var) {
8032       Value *OldValue = Builder.CreateExtractValue(Result, /*Idxs=*/0);
8033       if (!IsInteger)
8034         OldValue = Builder.CreateBitCast(OldValue, X.ElemTy);
8035       assert(OldValue->getType() == V.ElemTy &&
8036              "OldValue and V must be of same type");
8037       if (IsPostfixUpdate) {
8038         Builder.CreateStore(OldValue, V.Var, V.IsVolatile);
8039       } else {
8040         Value *SuccessOrFail = Builder.CreateExtractValue(Result, /*Idxs=*/1);
8041         if (IsFailOnly) {
8042           // CurBB----
8043           //   |     |
8044           //   v     |
8045           // ContBB  |
8046           //   |     |
8047           //   v     |
8048           // ExitBB <-
8049           //
8050           // where ContBB only contains the store of old value to 'v'.
8051           BasicBlock *CurBB = Builder.GetInsertBlock();
8052           Instruction *CurBBTI = CurBB->getTerminator();
8053           CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
8054           BasicBlock *ExitBB = CurBB->splitBasicBlock(
8055               CurBBTI, X.Var->getName() + ".atomic.exit");
8056           BasicBlock *ContBB = CurBB->splitBasicBlock(
8057               CurBB->getTerminator(), X.Var->getName() + ".atomic.cont");
8058           ContBB->getTerminator()->eraseFromParent();
8059           CurBB->getTerminator()->eraseFromParent();
8060 
8061           Builder.CreateCondBr(SuccessOrFail, ExitBB, ContBB);
8062 
8063           Builder.SetInsertPoint(ContBB);
8064           Builder.CreateStore(OldValue, V.Var);
8065           Builder.CreateBr(ExitBB);
8066 
8067           if (UnreachableInst *ExitTI =
8068                   dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
8069             CurBBTI->eraseFromParent();
8070             Builder.SetInsertPoint(ExitBB);
8071           } else {
8072             Builder.SetInsertPoint(ExitTI);
8073           }
8074         } else {
8075           Value *CapturedValue =
8076               Builder.CreateSelect(SuccessOrFail, E, OldValue);
8077           Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
8078         }
8079       }
8080     }
8081     // The comparison result has to be stored.
8082     if (R.Var) {
8083       assert(R.Var->getType()->isPointerTy() &&
8084              "r.var must be of pointer type");
8085       assert(R.ElemTy->isIntegerTy() && "r must be of integral type");
8086 
8087       Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
8088       Value *ResultCast = R.IsSigned
8089                               ? Builder.CreateSExt(SuccessFailureVal, R.ElemTy)
8090                               : Builder.CreateZExt(SuccessFailureVal, R.ElemTy);
8091       Builder.CreateStore(ResultCast, R.Var, R.IsVolatile);
8092     }
8093   } else {
8094     assert((Op == OMPAtomicCompareOp::MAX || Op == OMPAtomicCompareOp::MIN) &&
8095            "Op should be either max or min at this point");
8096     assert(!IsFailOnly && "IsFailOnly is only valid when the comparison is ==");
8097 
8098     // Reverse the ordop as the OpenMP forms are different from LLVM forms.
8099     // Let's take max as example.
8100     // OpenMP form:
8101     // x = x > expr ? expr : x;
8102     // LLVM form:
8103     // *ptr = *ptr > val ? *ptr : val;
8104     // We need to transform to LLVM form.
8105     // x = x <= expr ? x : expr;
8106     AtomicRMWInst::BinOp NewOp;
8107     if (IsXBinopExpr) {
8108       if (IsInteger) {
8109         if (X.IsSigned)
8110           NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
8111                                                 : AtomicRMWInst::Max;
8112         else
8113           NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
8114                                                 : AtomicRMWInst::UMax;
8115       } else {
8116         NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMin
8117                                               : AtomicRMWInst::FMax;
8118       }
8119     } else {
8120       if (IsInteger) {
8121         if (X.IsSigned)
8122           NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
8123                                                 : AtomicRMWInst::Min;
8124         else
8125           NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
8126                                                 : AtomicRMWInst::UMin;
8127       } else {
8128         NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMax
8129                                               : AtomicRMWInst::FMin;
8130       }
8131     }
8132 
8133     AtomicRMWInst *OldValue =
8134         Builder.CreateAtomicRMW(NewOp, X.Var, E, MaybeAlign(), AO);
8135     if (V.Var) {
8136       Value *CapturedValue = nullptr;
8137       if (IsPostfixUpdate) {
8138         CapturedValue = OldValue;
8139       } else {
8140         CmpInst::Predicate Pred;
8141         switch (NewOp) {
8142         case AtomicRMWInst::Max:
8143           Pred = CmpInst::ICMP_SGT;
8144           break;
8145         case AtomicRMWInst::UMax:
8146           Pred = CmpInst::ICMP_UGT;
8147           break;
8148         case AtomicRMWInst::FMax:
8149           Pred = CmpInst::FCMP_OGT;
8150           break;
8151         case AtomicRMWInst::Min:
8152           Pred = CmpInst::ICMP_SLT;
8153           break;
8154         case AtomicRMWInst::UMin:
8155           Pred = CmpInst::ICMP_ULT;
8156           break;
8157         case AtomicRMWInst::FMin:
8158           Pred = CmpInst::FCMP_OLT;
8159           break;
8160         default:
8161           llvm_unreachable("unexpected comparison op");
8162         }
8163         Value *NonAtomicCmp = Builder.CreateCmp(Pred, OldValue, E);
8164         CapturedValue = Builder.CreateSelect(NonAtomicCmp, E, OldValue);
8165       }
8166       Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
8167     }
8168   }
8169 
8170   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Compare);
8171 
8172   return Builder.saveIP();
8173 }
8174 
// Emit an OpenMP `teams` region at \p Loc.
//
// The current block is split into teams.alloca / teams.body / teams.exit,
// \p BodyGenCB fills the body, and the region is registered for outlining.
// On the host (!Config.isTargetDevice()):
//  - if any clause value is given, a __kmpc_push_num_teams_51 call is
//    emitted with defaults: NumTeamsUpper = 0, NumTeamsLower = NumTeamsUpper,
//    ThreadLimit = 0; a false IfExpr forces both bounds to 1;
//  - after outlining, the placeholder call to the outlined function is
//    replaced by a __kmpc_fork_teams call.
// Returns the insertion point at the start of teams.exit, or an empty
// InsertPointTy if updateToLocation(Loc) fails.
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
                             BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower,
                             Value *NumTeamsUpper, Value *ThreadLimit,
                             Value *IfExpr) {
  if (!updateToLocation(Loc))
    return InsertPointTy();

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Function *CurrentFunction = Builder.GetInsertBlock()->getParent();

  // Outer allocation basicblock is the entry block of the current function.
  // If we are still emitting into it, move to a fresh "teams.entry" block so
  // the splits below do not start inside OuterAllocaBB.
  BasicBlock &OuterAllocaBB = CurrentFunction->getEntryBlock();
  if (&OuterAllocaBB == Builder.GetInsertBlock()) {
    BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.entry");
    Builder.SetInsertPoint(BodyBB, BodyBB->begin());
  }

  // The current basic block is split into four basic blocks. After outlining,
  // they will be mapped as follows:
  // ```
  // def current_fn() {
  //   current_basic_block:
  //     br label %teams.exit
  //   teams.exit:
  //     ; instructions after teams
  // }
  //
  // def outlined_fn() {
  //   teams.alloca:
  //     br label %teams.body
  //   teams.body:
  //     ; instructions within teams body
  // }
  // ```
  BasicBlock *ExitBB = splitBB(Builder, /*CreateBranch=*/true, "teams.exit");
  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.body");
  BasicBlock *AllocaBB =
      splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");

  bool SubClausesPresent =
      (NumTeamsLower || NumTeamsUpper || ThreadLimit || IfExpr);
  // Push num_teams
  if (!Config.isTargetDevice() && SubClausesPresent) {
    assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) &&
           "if lowerbound is non-null, then upperbound must also be non-null "
           "for bounds on num_teams");

    // Missing values are passed to the runtime as 0 (upper bound, thread
    // limit) or copied from the upper bound (lower bound).
    if (NumTeamsUpper == nullptr)
      NumTeamsUpper = Builder.getInt32(0);

    if (NumTeamsLower == nullptr)
      NumTeamsLower = NumTeamsUpper;

    if (IfExpr) {
      assert(IfExpr->getType()->isIntegerTy() &&
             "argument to if clause must be an integer value");

      // upper = ifexpr ? upper : 1
      // (a non-i1 condition is first normalized to `IfExpr != 0`)
      if (IfExpr->getType() != Int1)
        IfExpr = Builder.CreateICmpNE(IfExpr,
                                      ConstantInt::get(IfExpr->getType(), 0));
      NumTeamsUpper = Builder.CreateSelect(
          IfExpr, NumTeamsUpper, Builder.getInt32(1), "numTeamsUpper");

      // lower = ifexpr ? lower : 1
      NumTeamsLower = Builder.CreateSelect(
          IfExpr, NumTeamsLower, Builder.getInt32(1), "numTeamsLower");
    }

    if (ThreadLimit == nullptr)
      ThreadLimit = Builder.getInt32(0);

    Value *ThreadNum = getOrCreateThreadID(Ident);
    Builder.CreateCall(
        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams_51),
        {Ident, ThreadNum, NumTeamsLower, NumTeamsUpper, ThreadLimit});
  }
  // Generate the body of teams.
  InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
  InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
  BodyGenCB(AllocaIP, CodeGenIP);

  OutlineInfo OI;
  OI.EntryBB = AllocaBB;
  OI.ExitBB = ExitBB;
  OI.OuterAllocaBB = &OuterAllocaBB;

  // Insert fake values for global tid and bound tid.
  // They are excluded from the argument aggregate, so the outlined function
  // receives them as its first two parameters; the fake instructions are
  // collected in ToBeDeleted and erased in the post-outline callback.
  SmallVector<Instruction *, 8> ToBeDeleted;
  InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
      Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "gid", true));
  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
      Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "tid", true));

  // Host-only: rewrite the single call to the outlined function into a
  // __kmpc_fork_teams(Ident, <argc>, OutlinedFn[, data]) call.
  auto HostPostOutlineCB = [this, Ident,
                            ToBeDeleted](Function &OutlinedFn) mutable {
    // The stale call instruction will be replaced with a new call instruction
    // for runtime call with the outlined function.

    assert(OutlinedFn.getNumUses() == 1 &&
           "there must be a single user for the outlined function");
    CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
    ToBeDeleted.push_back(StaleCI);

    assert((OutlinedFn.arg_size() == 2 || OutlinedFn.arg_size() == 3) &&
           "Outlined function must have two or three arguments only");

    // A third parameter carries the aggregated shared data, if any.
    bool HasShared = OutlinedFn.arg_size() == 3;

    OutlinedFn.getArg(0)->setName("global.tid.ptr");
    OutlinedFn.getArg(1)->setName("bound.tid.ptr");
    if (HasShared)
      OutlinedFn.getArg(2)->setName("data");

    // Call to the runtime function for teams in the current function.
    assert(StaleCI && "Error while outlining - no CallInst user found for the "
                      "outlined function.");
    Builder.SetInsertPoint(StaleCI);
    // The argument count excludes the two tid arguments.
    SmallVector<Value *> Args = {
        Ident, Builder.getInt32(StaleCI->arg_size() - 2), &OutlinedFn};
    if (HasShared)
      Args.push_back(StaleCI->getArgOperand(2));
    Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
                           omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
                       Args);

    // Erase in reverse creation order (the stale call, added last, goes first).
    llvm::for_each(llvm::reverse(ToBeDeleted),
                   [](Instruction *I) { I->eraseFromParent(); });

  };

  if (!Config.isTargetDevice())
    OI.PostOutlineCB = HostPostOutlineCB;

  addOutlineInfo(std::move(OI));

  Builder.SetInsertPoint(ExitBB, ExitBB->begin());

  return Builder.saveIP();
}
8319 
8320 GlobalVariable *
8321 OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
8322                                        std::string VarName) {
8323   llvm::Constant *MapNamesArrayInit = llvm::ConstantArray::get(
8324       llvm::ArrayType::get(llvm::PointerType::getUnqual(M.getContext()),
8325                            Names.size()),
8326       Names);
8327   auto *MapNamesArrayGlobal = new llvm::GlobalVariable(
8328       M, MapNamesArrayInit->getType(),
8329       /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MapNamesArrayInit,
8330       VarName);
8331   return MapNamesArrayGlobal;
8332 }
8333 
8334 // Create all simple and struct types exposed by the runtime and remember
8335 // the llvm::PointerTypes of them for easy access later.
8336 void OpenMPIRBuilder::initializeTypes(Module &M) {
8337   LLVMContext &Ctx = M.getContext();
8338   StructType *T;
8339 #define OMP_TYPE(VarName, InitValue) VarName = InitValue;
8340 #define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize)                             \
8341   VarName##Ty = ArrayType::get(ElemTy, ArraySize);                             \
8342   VarName##PtrTy = PointerType::getUnqual(VarName##Ty);
8343 #define OMP_FUNCTION_TYPE(VarName, IsVarArg, ReturnType, ...)                  \
8344   VarName = FunctionType::get(ReturnType, {__VA_ARGS__}, IsVarArg);            \
8345   VarName##Ptr = PointerType::getUnqual(VarName);
8346 #define OMP_STRUCT_TYPE(VarName, StructName, Packed, ...)                      \
8347   T = StructType::getTypeByName(Ctx, StructName);                              \
8348   if (!T)                                                                      \
8349     T = StructType::create(Ctx, {__VA_ARGS__}, StructName, Packed);            \
8350   VarName = T;                                                                 \
8351   VarName##Ptr = PointerType::getUnqual(T);
8352 #include "llvm/Frontend/OpenMP/OMPKinds.def"
8353 }
8354 
8355 void OpenMPIRBuilder::OutlineInfo::collectBlocks(
8356     SmallPtrSetImpl<BasicBlock *> &BlockSet,
8357     SmallVectorImpl<BasicBlock *> &BlockVector) {
8358   SmallVector<BasicBlock *, 32> Worklist;
8359   BlockSet.insert(EntryBB);
8360   BlockSet.insert(ExitBB);
8361 
8362   Worklist.push_back(EntryBB);
8363   while (!Worklist.empty()) {
8364     BasicBlock *BB = Worklist.pop_back_val();
8365     BlockVector.push_back(BB);
8366     for (BasicBlock *SuccBB : successors(BB))
8367       if (BlockSet.insert(SuccBB).second)
8368         Worklist.push_back(SuccBB);
8369   }
8370 }
8371 
8372 void OpenMPIRBuilder::createOffloadEntry(Constant *ID, Constant *Addr,
8373                                          uint64_t Size, int32_t Flags,
8374                                          GlobalValue::LinkageTypes,
8375                                          StringRef Name) {
8376   if (!Config.isGPU()) {
8377     llvm::offloading::emitOffloadingEntry(
8378         M, ID, Name.empty() ? Addr->getName() : Name, Size, Flags, /*Data=*/0,
8379         "omp_offloading_entries");
8380     return;
8381   }
8382   // TODO: Add support for global variables on the device after declare target
8383   // support.
8384   Function *Fn = dyn_cast<Function>(Addr);
8385   if (!Fn)
8386     return;
8387 
8388   Module &M = *(Fn->getParent());
8389   LLVMContext &Ctx = M.getContext();
8390 
8391   // Get "nvvm.annotations" metadata node.
8392   NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
8393 
8394   Metadata *MDVals[] = {
8395       ConstantAsMetadata::get(Fn), MDString::get(Ctx, "kernel"),
8396       ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(Ctx), 1))};
8397   // Append metadata to nvvm.annotations.
8398   MD->addOperand(MDNode::get(Ctx, MDVals));
8399 
8400   // Add a function attribute for the kernel.
8401   Fn->addFnAttr(Attribute::get(Ctx, "kernel"));
8402   if (T.isAMDGCN())
8403     Fn->addFnAttr("uniform-work-group-size", "true");
8404   Fn->addFnAttr(Attribute::MustProgress);
8405 }
8406 
8407 // We only generate metadata for function that contain target regions.
8408 void OpenMPIRBuilder::createOffloadEntriesAndInfoMetadata(
8409     EmitMetadataErrorReportFunctionTy &ErrorFn) {
8410 
8411   // If there are no entries, we don't need to do anything.
8412   if (OffloadInfoManager.empty())
8413     return;
8414 
8415   LLVMContext &C = M.getContext();
8416   SmallVector<std::pair<const OffloadEntriesInfoManager::OffloadEntryInfo *,
8417                         TargetRegionEntryInfo>,
8418               16>
8419       OrderedEntries(OffloadInfoManager.size());
8420 
8421   // Auxiliary methods to create metadata values and strings.
8422   auto &&GetMDInt = [this](unsigned V) {
8423     return ConstantAsMetadata::get(ConstantInt::get(Builder.getInt32Ty(), V));
8424   };
8425 
8426   auto &&GetMDString = [&C](StringRef V) { return MDString::get(C, V); };
8427 
8428   // Create the offloading info metadata node.
8429   NamedMDNode *MD = M.getOrInsertNamedMetadata("omp_offload.info");
8430   auto &&TargetRegionMetadataEmitter =
8431       [&C, MD, &OrderedEntries, &GetMDInt, &GetMDString](
8432           const TargetRegionEntryInfo &EntryInfo,
8433           const OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion &E) {
8434         // Generate metadata for target regions. Each entry of this metadata
8435         // contains:
8436         // - Entry 0 -> Kind of this type of metadata (0).
8437         // - Entry 1 -> Device ID of the file where the entry was identified.
8438         // - Entry 2 -> File ID of the file where the entry was identified.
8439         // - Entry 3 -> Mangled name of the function where the entry was
8440         // identified.
8441         // - Entry 4 -> Line in the file where the entry was identified.
8442         // - Entry 5 -> Count of regions at this DeviceID/FilesID/Line.
8443         // - Entry 6 -> Order the entry was created.
8444         // The first element of the metadata node is the kind.
8445         Metadata *Ops[] = {
8446             GetMDInt(E.getKind()),      GetMDInt(EntryInfo.DeviceID),
8447             GetMDInt(EntryInfo.FileID), GetMDString(EntryInfo.ParentName),
8448             GetMDInt(EntryInfo.Line),   GetMDInt(EntryInfo.Count),
8449             GetMDInt(E.getOrder())};
8450 
8451         // Save this entry in the right position of the ordered entries array.
8452         OrderedEntries[E.getOrder()] = std::make_pair(&E, EntryInfo);
8453 
8454         // Add metadata to the named metadata node.
8455         MD->addOperand(MDNode::get(C, Ops));
8456       };
8457 
8458   OffloadInfoManager.actOnTargetRegionEntriesInfo(TargetRegionMetadataEmitter);
8459 
8460   // Create function that emits metadata for each device global variable entry;
8461   auto &&DeviceGlobalVarMetadataEmitter =
8462       [&C, &OrderedEntries, &GetMDInt, &GetMDString, MD](
8463           StringRef MangledName,
8464           const OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar &E) {
8465         // Generate metadata for global variables. Each entry of this metadata
8466         // contains:
8467         // - Entry 0 -> Kind of this type of metadata (1).
8468         // - Entry 1 -> Mangled name of the variable.
8469         // - Entry 2 -> Declare target kind.
8470         // - Entry 3 -> Order the entry was created.
8471         // The first element of the metadata node is the kind.
8472         Metadata *Ops[] = {GetMDInt(E.getKind()), GetMDString(MangledName),
8473                            GetMDInt(E.getFlags()), GetMDInt(E.getOrder())};
8474 
8475         // Save this entry in the right position of the ordered entries array.
8476         TargetRegionEntryInfo varInfo(MangledName, 0, 0, 0);
8477         OrderedEntries[E.getOrder()] = std::make_pair(&E, varInfo);
8478 
8479         // Add metadata to the named metadata node.
8480         MD->addOperand(MDNode::get(C, Ops));
8481       };
8482 
8483   OffloadInfoManager.actOnDeviceGlobalVarEntriesInfo(
8484       DeviceGlobalVarMetadataEmitter);
8485 
8486   for (const auto &E : OrderedEntries) {
8487     assert(E.first && "All ordered entries must exist!");
8488     if (const auto *CE =
8489             dyn_cast<OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion>(
8490                 E.first)) {
8491       if (!CE->getID() || !CE->getAddress()) {
8492         // Do not blame the entry if the parent funtion is not emitted.
8493         TargetRegionEntryInfo EntryInfo = E.second;
8494         StringRef FnName = EntryInfo.ParentName;
8495         if (!M.getNamedValue(FnName))
8496           continue;
8497         ErrorFn(EMIT_MD_TARGET_REGION_ERROR, EntryInfo);
8498         continue;
8499       }
8500       createOffloadEntry(CE->getID(), CE->getAddress(),
8501                          /*Size=*/0, CE->getFlags(),
8502                          GlobalValue::WeakAnyLinkage);
8503     } else if (const auto *CE = dyn_cast<
8504                    OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar>(
8505                    E.first)) {
8506       OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags =
8507           static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
8508               CE->getFlags());
8509       switch (Flags) {
8510       case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter:
8511       case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo:
8512         if (Config.isTargetDevice() && Config.hasRequiresUnifiedSharedMemory())
8513           continue;
8514         if (!CE->getAddress()) {
8515           ErrorFn(EMIT_MD_DECLARE_TARGET_ERROR, E.second);
8516           continue;
8517         }
8518         // The vaiable has no definition - no need to add the entry.
8519         if (CE->getVarSize() == 0)
8520           continue;
8521         break;
8522       case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink:
8523         assert(((Config.isTargetDevice() && !CE->getAddress()) ||
8524                 (!Config.isTargetDevice() && CE->getAddress())) &&
8525                "Declaret target link address is set.");
8526         if (Config.isTargetDevice())
8527           continue;
8528         if (!CE->getAddress()) {
8529           ErrorFn(EMIT_MD_GLOBAL_VAR_LINK_ERROR, TargetRegionEntryInfo());
8530           continue;
8531         }
8532         break;
8533       default:
8534         break;
8535       }
8536 
8537       // Hidden or internal symbols on the device are not externally visible.
8538       // We should not attempt to register them by creating an offloading
8539       // entry. Indirect variables are handled separately on the device.
8540       if (auto *GV = dyn_cast<GlobalValue>(CE->getAddress()))
8541         if ((GV->hasLocalLinkage() || GV->hasHiddenVisibility()) &&
8542             Flags != OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
8543           continue;
8544 
8545       // Indirect globals need to use a special name that doesn't match the name
8546       // of the associated host global.
8547       if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
8548         createOffloadEntry(CE->getAddress(), CE->getAddress(), CE->getVarSize(),
8549                            Flags, CE->getLinkage(), CE->getVarName());
8550       else
8551         createOffloadEntry(CE->getAddress(), CE->getAddress(), CE->getVarSize(),
8552                            Flags, CE->getLinkage());
8553 
8554     } else {
8555       llvm_unreachable("Unsupported entry kind.");
8556     }
8557   }
8558 
8559   // Emit requires directive globals to a special entry so the runtime can
8560   // register them when the device image is loaded.
8561   // TODO: This reduces the offloading entries to a 32-bit integer. Offloading
8562   //       entries should be redesigned to better suit this use-case.
8563   if (Config.hasRequiresFlags() && !Config.isTargetDevice())
8564     offloading::emitOffloadingEntry(
8565         M, Constant::getNullValue(PointerType::getUnqual(M.getContext())),
8566         /*Name=*/"",
8567         /*Size=*/0, OffloadEntriesInfoManager::OMPTargetGlobalRegisterRequires,
8568         Config.getRequiresFlags(), "omp_offloading_entries");
8569 }
8570 
8571 void TargetRegionEntryInfo::getTargetRegionEntryFnName(
8572     SmallVectorImpl<char> &Name, StringRef ParentName, unsigned DeviceID,
8573     unsigned FileID, unsigned Line, unsigned Count) {
8574   raw_svector_ostream OS(Name);
8575   OS << "__omp_offloading" << llvm::format("_%x", DeviceID)
8576      << llvm::format("_%x_", FileID) << ParentName << "_l" << Line;
8577   if (Count)
8578     OS << "_" << Count;
8579 }
8580 
8581 void OffloadEntriesInfoManager::getTargetRegionEntryFnName(
8582     SmallVectorImpl<char> &Name, const TargetRegionEntryInfo &EntryInfo) {
8583   unsigned NewCount = getTargetRegionEntryInfoCount(EntryInfo);
8584   TargetRegionEntryInfo::getTargetRegionEntryFnName(
8585       Name, EntryInfo.ParentName, EntryInfo.DeviceID, EntryInfo.FileID,
8586       EntryInfo.Line, NewCount);
8587 }
8588 
8589 TargetRegionEntryInfo
8590 OpenMPIRBuilder::getTargetEntryUniqueInfo(FileIdentifierInfoCallbackTy CallBack,
8591                                           StringRef ParentName) {
8592   sys::fs::UniqueID ID;
8593   auto FileIDInfo = CallBack();
8594   if (auto EC = sys::fs::getUniqueID(std::get<0>(FileIDInfo), ID)) {
8595     report_fatal_error(("Unable to get unique ID for file, during "
8596                         "getTargetEntryUniqueInfo, error message: " +
8597                         EC.message())
8598                            .c_str());
8599   }
8600 
8601   return TargetRegionEntryInfo(ParentName, ID.getDevice(), ID.getFile(),
8602                                std::get<1>(FileIDInfo));
8603 }
8604 
8605 unsigned OpenMPIRBuilder::getFlagMemberOffset() {
8606   unsigned Offset = 0;
8607   for (uint64_t Remain =
8608            static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
8609                omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF);
8610        !(Remain & 1); Remain = Remain >> 1)
8611     Offset++;
8612   return Offset;
8613 }
8614 
8615 omp::OpenMPOffloadMappingFlags
8616 OpenMPIRBuilder::getMemberOfFlag(unsigned Position) {
8617   // Rotate by getFlagMemberOffset() bits.
8618   return static_cast<omp::OpenMPOffloadMappingFlags>(((uint64_t)Position + 1)
8619                                                      << getFlagMemberOffset());
8620 }
8621 
8622 void OpenMPIRBuilder::setCorrectMemberOfFlag(
8623     omp::OpenMPOffloadMappingFlags &Flags,
8624     omp::OpenMPOffloadMappingFlags MemberOfFlag) {
8625   // If the entry is PTR_AND_OBJ but has not been marked with the special
8626   // placeholder value 0xFFFF in the MEMBER_OF field, then it should not be
8627   // marked as MEMBER_OF.
8628   if (static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
8629           Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ) &&
8630       static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
8631           (Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF) !=
8632           omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF))
8633     return;
8634 
8635   // Reset the placeholder value to prepare the flag for the assignment of the
8636   // proper MEMBER_OF value.
8637   Flags &= ~omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
8638   Flags |= MemberOfFlag;
8639 }
8640 
8641 Constant *OpenMPIRBuilder::getAddrOfDeclareTargetVar(
8642     OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
8643     OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
8644     bool IsDeclaration, bool IsExternallyVisible,
8645     TargetRegionEntryInfo EntryInfo, StringRef MangledName,
8646     std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
8647     std::vector<Triple> TargetTriple, Type *LlvmPtrTy,
8648     std::function<Constant *()> GlobalInitializer,
8649     std::function<GlobalValue::LinkageTypes()> VariableLinkage) {
8650   // TODO: convert this to utilise the IRBuilder Config rather than
8651   // a passed down argument.
8652   if (OpenMPSIMD)
8653     return nullptr;
8654 
8655   if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink ||
8656       ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
8657         CaptureClause ==
8658             OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
8659        Config.hasRequiresUnifiedSharedMemory())) {
8660     SmallString<64> PtrName;
8661     {
8662       raw_svector_ostream OS(PtrName);
8663       OS << MangledName;
8664       if (!IsExternallyVisible)
8665         OS << format("_%x", EntryInfo.FileID);
8666       OS << "_decl_tgt_ref_ptr";
8667     }
8668 
8669     Value *Ptr = M.getNamedValue(PtrName);
8670 
8671     if (!Ptr) {
8672       GlobalValue *GlobalValue = M.getNamedValue(MangledName);
8673       Ptr = getOrCreateInternalVariable(LlvmPtrTy, PtrName);
8674 
8675       auto *GV = cast<GlobalVariable>(Ptr);
8676       GV->setLinkage(GlobalValue::WeakAnyLinkage);
8677 
8678       if (!Config.isTargetDevice()) {
8679         if (GlobalInitializer)
8680           GV->setInitializer(GlobalInitializer());
8681         else
8682           GV->setInitializer(GlobalValue);
8683       }
8684 
8685       registerTargetGlobalVariable(
8686           CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
8687           EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
8688           GlobalInitializer, VariableLinkage, LlvmPtrTy, cast<Constant>(Ptr));
8689     }
8690 
8691     return cast<Constant>(Ptr);
8692   }
8693 
8694   return nullptr;
8695 }
8696 
8697 void OpenMPIRBuilder::registerTargetGlobalVariable(
8698     OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
8699     OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
8700     bool IsDeclaration, bool IsExternallyVisible,
8701     TargetRegionEntryInfo EntryInfo, StringRef MangledName,
8702     std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
8703     std::vector<Triple> TargetTriple,
8704     std::function<Constant *()> GlobalInitializer,
8705     std::function<GlobalValue::LinkageTypes()> VariableLinkage, Type *LlvmPtrTy,
8706     Constant *Addr) {
8707   if (DeviceClause != OffloadEntriesInfoManager::OMPTargetDeviceClauseAny ||
8708       (TargetTriple.empty() && !Config.isTargetDevice()))
8709     return;
8710 
8711   OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags;
8712   StringRef VarName;
8713   int64_t VarSize;
8714   GlobalValue::LinkageTypes Linkage;
8715 
8716   if ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
8717        CaptureClause ==
8718            OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
8719       !Config.hasRequiresUnifiedSharedMemory()) {
8720     Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
8721     VarName = MangledName;
8722     GlobalValue *LlvmVal = M.getNamedValue(VarName);
8723 
8724     if (!IsDeclaration)
8725       VarSize = divideCeil(
8726           M.getDataLayout().getTypeSizeInBits(LlvmVal->getValueType()), 8);
8727     else
8728       VarSize = 0;
8729     Linkage = (VariableLinkage) ? VariableLinkage() : LlvmVal->getLinkage();
8730 
8731     // This is a workaround carried over from Clang which prevents undesired
8732     // optimisation of internal variables.
8733     if (Config.isTargetDevice() &&
8734         (!IsExternallyVisible || Linkage == GlobalValue::LinkOnceODRLinkage)) {
8735       // Do not create a "ref-variable" if the original is not also available
8736       // on the host.
8737       if (!OffloadInfoManager.hasDeviceGlobalVarEntryInfo(VarName))
8738         return;
8739 
8740       std::string RefName = createPlatformSpecificName({VarName, "ref"});
8741 
8742       if (!M.getNamedValue(RefName)) {
8743         Constant *AddrRef =
8744             getOrCreateInternalVariable(Addr->getType(), RefName);
8745         auto *GvAddrRef = cast<GlobalVariable>(AddrRef);
8746         GvAddrRef->setConstant(true);
8747         GvAddrRef->setLinkage(GlobalValue::InternalLinkage);
8748         GvAddrRef->setInitializer(Addr);
8749         GeneratedRefs.push_back(GvAddrRef);
8750       }
8751     }
8752   } else {
8753     if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink)
8754       Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
8755     else
8756       Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
8757 
8758     if (Config.isTargetDevice()) {
8759       VarName = (Addr) ? Addr->getName() : "";
8760       Addr = nullptr;
8761     } else {
8762       Addr = getAddrOfDeclareTargetVar(
8763           CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
8764           EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
8765           LlvmPtrTy, GlobalInitializer, VariableLinkage);
8766       VarName = (Addr) ? Addr->getName() : "";
8767     }
8768     VarSize = M.getDataLayout().getPointerSize();
8769     Linkage = GlobalValue::WeakAnyLinkage;
8770   }
8771 
8772   OffloadInfoManager.registerDeviceGlobalVarEntryInfo(VarName, Addr, VarSize,
8773                                                       Flags, Linkage);
8774 }
8775 
8776 /// Loads all the offload entries information from the host IR
8777 /// metadata.
8778 void OpenMPIRBuilder::loadOffloadInfoMetadata(Module &M) {
8779   // If we are in target mode, load the metadata from the host IR. This code has
8780   // to match the metadata creation in createOffloadEntriesAndInfoMetadata().
8781 
8782   NamedMDNode *MD = M.getNamedMetadata(ompOffloadInfoName);
8783   if (!MD)
8784     return;
8785 
8786   for (MDNode *MN : MD->operands()) {
8787     auto &&GetMDInt = [MN](unsigned Idx) {
8788       auto *V = cast<ConstantAsMetadata>(MN->getOperand(Idx));
8789       return cast<ConstantInt>(V->getValue())->getZExtValue();
8790     };
8791 
8792     auto &&GetMDString = [MN](unsigned Idx) {
8793       auto *V = cast<MDString>(MN->getOperand(Idx));
8794       return V->getString();
8795     };
8796 
8797     switch (GetMDInt(0)) {
8798     default:
8799       llvm_unreachable("Unexpected metadata!");
8800       break;
8801     case OffloadEntriesInfoManager::OffloadEntryInfo::
8802         OffloadingEntryInfoTargetRegion: {
8803       TargetRegionEntryInfo EntryInfo(/*ParentName=*/GetMDString(3),
8804                                       /*DeviceID=*/GetMDInt(1),
8805                                       /*FileID=*/GetMDInt(2),
8806                                       /*Line=*/GetMDInt(4),
8807                                       /*Count=*/GetMDInt(5));
8808       OffloadInfoManager.initializeTargetRegionEntryInfo(EntryInfo,
8809                                                          /*Order=*/GetMDInt(6));
8810       break;
8811     }
8812     case OffloadEntriesInfoManager::OffloadEntryInfo::
8813         OffloadingEntryInfoDeviceGlobalVar:
8814       OffloadInfoManager.initializeDeviceGlobalVarEntryInfo(
8815           /*MangledName=*/GetMDString(1),
8816           static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
8817               /*Flags=*/GetMDInt(2)),
8818           /*Order=*/GetMDInt(3));
8819       break;
8820     }
8821   }
8822 }
8823 
8824 void OpenMPIRBuilder::loadOffloadInfoMetadata(StringRef HostFilePath) {
8825   if (HostFilePath.empty())
8826     return;
8827 
8828   auto Buf = MemoryBuffer::getFile(HostFilePath);
8829   if (std::error_code Err = Buf.getError()) {
8830     report_fatal_error(("error opening host file from host file path inside of "
8831                         "OpenMPIRBuilder: " +
8832                         Err.message())
8833                            .c_str());
8834   }
8835 
8836   LLVMContext Ctx;
8837   auto M = expectedToErrorOrAndEmitErrors(
8838       Ctx, parseBitcodeFile(Buf.get()->getMemBufferRef(), Ctx));
8839   if (std::error_code Err = M.getError()) {
8840     report_fatal_error(
8841         ("error parsing host file inside of OpenMPIRBuilder: " + Err.message())
8842             .c_str());
8843   }
8844 
8845   loadOffloadInfoMetadata(*M.get());
8846 }
8847 
8848 //===----------------------------------------------------------------------===//
8849 // OffloadEntriesInfoManager
8850 //===----------------------------------------------------------------------===//
8851 
8852 bool OffloadEntriesInfoManager::empty() const {
8853   return OffloadEntriesTargetRegion.empty() &&
8854          OffloadEntriesDeviceGlobalVar.empty();
8855 }
8856 
8857 unsigned OffloadEntriesInfoManager::getTargetRegionEntryInfoCount(
8858     const TargetRegionEntryInfo &EntryInfo) const {
8859   auto It = OffloadEntriesTargetRegionCount.find(
8860       getTargetRegionEntryCountKey(EntryInfo));
8861   if (It == OffloadEntriesTargetRegionCount.end())
8862     return 0;
8863   return It->second;
8864 }
8865 
8866 void OffloadEntriesInfoManager::incrementTargetRegionEntryInfoCount(
8867     const TargetRegionEntryInfo &EntryInfo) {
8868   OffloadEntriesTargetRegionCount[getTargetRegionEntryCountKey(EntryInfo)] =
8869       EntryInfo.Count + 1;
8870 }
8871 
8872 /// Initialize target region entry.
8873 void OffloadEntriesInfoManager::initializeTargetRegionEntryInfo(
8874     const TargetRegionEntryInfo &EntryInfo, unsigned Order) {
8875   OffloadEntriesTargetRegion[EntryInfo] =
8876       OffloadEntryInfoTargetRegion(Order, /*Addr=*/nullptr, /*ID=*/nullptr,
8877                                    OMPTargetRegionEntryTargetRegion);
8878   ++OffloadingEntriesNum;
8879 }
8880 
8881 void OffloadEntriesInfoManager::registerTargetRegionEntryInfo(
8882     TargetRegionEntryInfo EntryInfo, Constant *Addr, Constant *ID,
8883     OMPTargetRegionEntryKind Flags) {
8884   assert(EntryInfo.Count == 0 && "expected default EntryInfo");
8885 
8886   // Update the EntryInfo with the next available count for this location.
8887   EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);
8888 
8889   // If we are emitting code for a target, the entry is already initialized,
8890   // only has to be registered.
8891   if (OMPBuilder->Config.isTargetDevice()) {
8892     // This could happen if the device compilation is invoked standalone.
8893     if (!hasTargetRegionEntryInfo(EntryInfo)) {
8894       return;
8895     }
8896     auto &Entry = OffloadEntriesTargetRegion[EntryInfo];
8897     Entry.setAddress(Addr);
8898     Entry.setID(ID);
8899     Entry.setFlags(Flags);
8900   } else {
8901     if (Flags == OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion &&
8902         hasTargetRegionEntryInfo(EntryInfo, /*IgnoreAddressId*/ true))
8903       return;
8904     assert(!hasTargetRegionEntryInfo(EntryInfo) &&
8905            "Target region entry already registered!");
8906     OffloadEntryInfoTargetRegion Entry(OffloadingEntriesNum, Addr, ID, Flags);
8907     OffloadEntriesTargetRegion[EntryInfo] = Entry;
8908     ++OffloadingEntriesNum;
8909   }
8910   incrementTargetRegionEntryInfoCount(EntryInfo);
8911 }
8912 
8913 bool OffloadEntriesInfoManager::hasTargetRegionEntryInfo(
8914     TargetRegionEntryInfo EntryInfo, bool IgnoreAddressId) const {
8915 
8916   // Update the EntryInfo with the next available count for this location.
8917   EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);
8918 
8919   auto It = OffloadEntriesTargetRegion.find(EntryInfo);
8920   if (It == OffloadEntriesTargetRegion.end()) {
8921     return false;
8922   }
8923   // Fail if this entry is already registered.
8924   if (!IgnoreAddressId && (It->second.getAddress() || It->second.getID()))
8925     return false;
8926   return true;
8927 }
8928 
8929 void OffloadEntriesInfoManager::actOnTargetRegionEntriesInfo(
8930     const OffloadTargetRegionEntryInfoActTy &Action) {
8931   // Scan all target region entries and perform the provided action.
8932   for (const auto &It : OffloadEntriesTargetRegion) {
8933     Action(It.first, It.second);
8934   }
8935 }
8936 
8937 void OffloadEntriesInfoManager::initializeDeviceGlobalVarEntryInfo(
8938     StringRef Name, OMPTargetGlobalVarEntryKind Flags, unsigned Order) {
8939   OffloadEntriesDeviceGlobalVar.try_emplace(Name, Order, Flags);
8940   ++OffloadingEntriesNum;
8941 }
8942 
8943 void OffloadEntriesInfoManager::registerDeviceGlobalVarEntryInfo(
8944     StringRef VarName, Constant *Addr, int64_t VarSize,
8945     OMPTargetGlobalVarEntryKind Flags, GlobalValue::LinkageTypes Linkage) {
8946   if (OMPBuilder->Config.isTargetDevice()) {
8947     // This could happen if the device compilation is invoked standalone.
8948     if (!hasDeviceGlobalVarEntryInfo(VarName))
8949       return;
8950     auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
8951     if (Entry.getAddress() && hasDeviceGlobalVarEntryInfo(VarName)) {
8952       if (Entry.getVarSize() == 0) {
8953         Entry.setVarSize(VarSize);
8954         Entry.setLinkage(Linkage);
8955       }
8956       return;
8957     }
8958     Entry.setVarSize(VarSize);
8959     Entry.setLinkage(Linkage);
8960     Entry.setAddress(Addr);
8961   } else {
8962     if (hasDeviceGlobalVarEntryInfo(VarName)) {
8963       auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
8964       assert(Entry.isValid() && Entry.getFlags() == Flags &&
8965              "Entry not initialized!");
8966       if (Entry.getVarSize() == 0) {
8967         Entry.setVarSize(VarSize);
8968         Entry.setLinkage(Linkage);
8969       }
8970       return;
8971     }
8972     if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
8973       OffloadEntriesDeviceGlobalVar.try_emplace(VarName, OffloadingEntriesNum,
8974                                                 Addr, VarSize, Flags, Linkage,
8975                                                 VarName.str());
8976     else
8977       OffloadEntriesDeviceGlobalVar.try_emplace(
8978           VarName, OffloadingEntriesNum, Addr, VarSize, Flags, Linkage, "");
8979     ++OffloadingEntriesNum;
8980   }
8981 }
8982 
8983 void OffloadEntriesInfoManager::actOnDeviceGlobalVarEntriesInfo(
8984     const OffloadDeviceGlobalVarEntryInfoActTy &Action) {
8985   // Scan all target region entries and perform the provided action.
8986   for (const auto &E : OffloadEntriesDeviceGlobalVar)
8987     Action(E.getKey(), E.getValue());
8988 }
8989 
8990 //===----------------------------------------------------------------------===//
8991 // CanonicalLoopInfo
8992 //===----------------------------------------------------------------------===//
8993 
8994 void CanonicalLoopInfo::collectControlBlocks(
8995     SmallVectorImpl<BasicBlock *> &BBs) {
8996   // We only count those BBs as control block for which we do not need to
8997   // reverse the CFG, i.e. not the loop body which can contain arbitrary control
8998   // flow. For consistency, this also means we do not add the Body block, which
8999   // is just the entry to the body code.
9000   BBs.reserve(BBs.size() + 6);
9001   BBs.append({getPreheader(), Header, Cond, Latch, Exit, getAfter()});
9002 }
9003 
9004 BasicBlock *CanonicalLoopInfo::getPreheader() const {
9005   assert(isValid() && "Requires a valid canonical loop");
9006   for (BasicBlock *Pred : predecessors(Header)) {
9007     if (Pred != Latch)
9008       return Pred;
9009   }
9010   llvm_unreachable("Missing preheader");
9011 }
9012 
9013 void CanonicalLoopInfo::setTripCount(Value *TripCount) {
9014   assert(isValid() && "Requires a valid canonical loop");
9015 
9016   Instruction *CmpI = &getCond()->front();
9017   assert(isa<CmpInst>(CmpI) && "First inst must compare IV with TripCount");
9018   CmpI->setOperand(1, TripCount);
9019 
9020 #ifndef NDEBUG
9021   assertOK();
9022 #endif
9023 }
9024 
9025 void CanonicalLoopInfo::mapIndVar(
9026     llvm::function_ref<Value *(Instruction *)> Updater) {
9027   assert(isValid() && "Requires a valid canonical loop");
9028 
9029   Instruction *OldIV = getIndVar();
9030 
9031   // Record all uses excluding those introduced by the updater. Uses by the
9032   // CanonicalLoopInfo itself to keep track of the number of iterations are
9033   // excluded.
9034   SmallVector<Use *> ReplacableUses;
9035   for (Use &U : OldIV->uses()) {
9036     auto *User = dyn_cast<Instruction>(U.getUser());
9037     if (!User)
9038       continue;
9039     if (User->getParent() == getCond())
9040       continue;
9041     if (User->getParent() == getLatch())
9042       continue;
9043     ReplacableUses.push_back(&U);
9044   }
9045 
9046   // Run the updater that may introduce new uses
9047   Value *NewIV = Updater(OldIV);
9048 
9049   // Replace the old uses with the value returned by the updater.
9050   for (Use *U : ReplacableUses)
9051     U->set(NewIV);
9052 
9053 #ifndef NDEBUG
9054   assertOK();
9055 #endif
9056 }
9057 
9058 void CanonicalLoopInfo::assertOK() const {
9059 #ifndef NDEBUG
9060   // No constraints if this object currently does not describe a loop.
9061   if (!isValid())
9062     return;
9063 
9064   BasicBlock *Preheader = getPreheader();
9065   BasicBlock *Body = getBody();
9066   BasicBlock *After = getAfter();
9067 
9068   // Verify standard control-flow we use for OpenMP loops.
9069   assert(Preheader);
9070   assert(isa<BranchInst>(Preheader->getTerminator()) &&
9071          "Preheader must terminate with unconditional branch");
9072   assert(Preheader->getSingleSuccessor() == Header &&
9073          "Preheader must jump to header");
9074 
9075   assert(Header);
9076   assert(isa<BranchInst>(Header->getTerminator()) &&
9077          "Header must terminate with unconditional branch");
9078   assert(Header->getSingleSuccessor() == Cond &&
9079          "Header must jump to exiting block");
9080 
9081   assert(Cond);
9082   assert(Cond->getSinglePredecessor() == Header &&
9083          "Exiting block only reachable from header");
9084 
9085   assert(isa<BranchInst>(Cond->getTerminator()) &&
9086          "Exiting block must terminate with conditional branch");
9087   assert(size(successors(Cond)) == 2 &&
9088          "Exiting block must have two successors");
9089   assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(0) == Body &&
9090          "Exiting block's first successor jump to the body");
9091   assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(1) == Exit &&
9092          "Exiting block's second successor must exit the loop");
9093 
9094   assert(Body);
9095   assert(Body->getSinglePredecessor() == Cond &&
9096          "Body only reachable from exiting block");
9097   assert(!isa<PHINode>(Body->front()));
9098 
9099   assert(Latch);
9100   assert(isa<BranchInst>(Latch->getTerminator()) &&
9101          "Latch must terminate with unconditional branch");
9102   assert(Latch->getSingleSuccessor() == Header && "Latch must jump to header");
9103   // TODO: To support simple redirecting of the end of the body code that has
9104   // multiple; introduce another auxiliary basic block like preheader and after.
9105   assert(Latch->getSinglePredecessor() != nullptr);
9106   assert(!isa<PHINode>(Latch->front()));
9107 
9108   assert(Exit);
9109   assert(isa<BranchInst>(Exit->getTerminator()) &&
9110          "Exit block must terminate with unconditional branch");
9111   assert(Exit->getSingleSuccessor() == After &&
9112          "Exit block must jump to after block");
9113 
9114   assert(After);
9115   assert(After->getSinglePredecessor() == Exit &&
9116          "After block only reachable from exit block");
9117   assert(After->empty() || !isa<PHINode>(After->front()));
9118 
9119   Instruction *IndVar = getIndVar();
9120   assert(IndVar && "Canonical induction variable not found?");
9121   assert(isa<IntegerType>(IndVar->getType()) &&
9122          "Induction variable must be an integer");
9123   assert(cast<PHINode>(IndVar)->getParent() == Header &&
9124          "Induction variable must be a PHI in the loop header");
9125   assert(cast<PHINode>(IndVar)->getIncomingBlock(0) == Preheader);
9126   assert(
9127       cast<ConstantInt>(cast<PHINode>(IndVar)->getIncomingValue(0))->isZero());
9128   assert(cast<PHINode>(IndVar)->getIncomingBlock(1) == Latch);
9129 
9130   auto *NextIndVar = cast<PHINode>(IndVar)->getIncomingValue(1);
9131   assert(cast<Instruction>(NextIndVar)->getParent() == Latch);
9132   assert(cast<BinaryOperator>(NextIndVar)->getOpcode() == BinaryOperator::Add);
9133   assert(cast<BinaryOperator>(NextIndVar)->getOperand(0) == IndVar);
9134   assert(cast<ConstantInt>(cast<BinaryOperator>(NextIndVar)->getOperand(1))
9135              ->isOne());
9136 
9137   Value *TripCount = getTripCount();
9138   assert(TripCount && "Loop trip count not found?");
9139   assert(IndVar->getType() == TripCount->getType() &&
9140          "Trip count and induction variable must have the same type");
9141 
9142   auto *CmpI = cast<CmpInst>(&Cond->front());
9143   assert(CmpI->getPredicate() == CmpInst::ICMP_ULT &&
9144          "Exit condition must be a signed less-than comparison");
9145   assert(CmpI->getOperand(0) == IndVar &&
9146          "Exit condition must compare the induction variable");
9147   assert(CmpI->getOperand(1) == TripCount &&
9148          "Exit condition must compare with the trip count");
9149 #endif
9150 }
9151 
9152 void CanonicalLoopInfo::invalidate() {
9153   Header = nullptr;
9154   Cond = nullptr;
9155   Latch = nullptr;
9156   Exit = nullptr;
9157 }
9158