xref: /freebsd/contrib/llvm-project/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (revision a90b9d0159070121c221b966469c3e36d912bf82)
1 //===- OpenMPIRBuilder.cpp - Builder for LLVM-IR for OpenMP directives ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 ///
10 /// This file implements the OpenMPIRBuilder class, which is used as a
11 /// convenient way to create LLVM instructions for OpenMP directives.
12 ///
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
16 #include "llvm/ADT/SmallSet.h"
17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/StringRef.h"
19 #include "llvm/Analysis/AssumptionCache.h"
20 #include "llvm/Analysis/CodeMetrics.h"
21 #include "llvm/Analysis/LoopInfo.h"
22 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
23 #include "llvm/Analysis/ScalarEvolution.h"
24 #include "llvm/Analysis/TargetLibraryInfo.h"
25 #include "llvm/Bitcode/BitcodeReader.h"
26 #include "llvm/Frontend/Offloading/Utility.h"
27 #include "llvm/Frontend/OpenMP/OMPGridValues.h"
28 #include "llvm/IR/Attributes.h"
29 #include "llvm/IR/BasicBlock.h"
30 #include "llvm/IR/CFG.h"
31 #include "llvm/IR/CallingConv.h"
32 #include "llvm/IR/Constant.h"
33 #include "llvm/IR/Constants.h"
34 #include "llvm/IR/DebugInfoMetadata.h"
35 #include "llvm/IR/DerivedTypes.h"
36 #include "llvm/IR/Function.h"
37 #include "llvm/IR/GlobalVariable.h"
38 #include "llvm/IR/IRBuilder.h"
39 #include "llvm/IR/LLVMContext.h"
40 #include "llvm/IR/MDBuilder.h"
41 #include "llvm/IR/Metadata.h"
42 #include "llvm/IR/PassManager.h"
43 #include "llvm/IR/Value.h"
44 #include "llvm/MC/TargetRegistry.h"
45 #include "llvm/Support/CommandLine.h"
46 #include "llvm/Support/ErrorHandling.h"
47 #include "llvm/Support/FileSystem.h"
48 #include "llvm/Target/TargetMachine.h"
49 #include "llvm/Target/TargetOptions.h"
50 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
51 #include "llvm/Transforms/Utils/Cloning.h"
52 #include "llvm/Transforms/Utils/CodeExtractor.h"
53 #include "llvm/Transforms/Utils/LoopPeel.h"
54 #include "llvm/Transforms/Utils/UnrollLoop.h"
55 
56 #include <cstdint>
57 #include <optional>
58 
59 #define DEBUG_TYPE "openmp-ir-builder"
60 
61 using namespace llvm;
62 using namespace omp;
63 
64 static cl::opt<bool>
65     OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
66                          cl::desc("Use optimistic attributes describing "
67                                   "'as-if' properties of runtime calls."),
68                          cl::init(false));
69 
70 static cl::opt<double> UnrollThresholdFactor(
71     "openmp-ir-builder-unroll-threshold-factor", cl::Hidden,
72     cl::desc("Factor for the unroll threshold to account for code "
73              "simplifications still taking place"),
74     cl::init(1.5));
75 
76 #ifndef NDEBUG
77 /// Return whether IP1 and IP2 are ambiguous, i.e. that inserting instructions
78 /// at position IP1 may change the meaning of IP2 or vice-versa. This is because
79 /// an InsertPoint stores the instruction before something is inserted. For
80 /// instance, if both point to the same instruction, two IRBuilders alternating
81 /// creating instruction will cause the instructions to be interleaved.
82 static bool isConflictIP(IRBuilder<>::InsertPoint IP1,
83                          IRBuilder<>::InsertPoint IP2) {
84   if (!IP1.isSet() || !IP2.isSet())
85     return false;
86   return IP1.getBlock() == IP2.getBlock() && IP1.getPoint() == IP2.getPoint();
87 }
88 
89 static bool isValidWorkshareLoopScheduleType(OMPScheduleType SchedType) {
90   // Valid ordered/unordered and base algorithm combinations.
91   switch (SchedType & ~OMPScheduleType::MonotonicityMask) {
92   case OMPScheduleType::UnorderedStaticChunked:
93   case OMPScheduleType::UnorderedStatic:
94   case OMPScheduleType::UnorderedDynamicChunked:
95   case OMPScheduleType::UnorderedGuidedChunked:
96   case OMPScheduleType::UnorderedRuntime:
97   case OMPScheduleType::UnorderedAuto:
98   case OMPScheduleType::UnorderedTrapezoidal:
99   case OMPScheduleType::UnorderedGreedy:
100   case OMPScheduleType::UnorderedBalanced:
101   case OMPScheduleType::UnorderedGuidedIterativeChunked:
102   case OMPScheduleType::UnorderedGuidedAnalyticalChunked:
103   case OMPScheduleType::UnorderedSteal:
104   case OMPScheduleType::UnorderedStaticBalancedChunked:
105   case OMPScheduleType::UnorderedGuidedSimd:
106   case OMPScheduleType::UnorderedRuntimeSimd:
107   case OMPScheduleType::OrderedStaticChunked:
108   case OMPScheduleType::OrderedStatic:
109   case OMPScheduleType::OrderedDynamicChunked:
110   case OMPScheduleType::OrderedGuidedChunked:
111   case OMPScheduleType::OrderedRuntime:
112   case OMPScheduleType::OrderedAuto:
113   case OMPScheduleType::OrderdTrapezoidal:
114   case OMPScheduleType::NomergeUnorderedStaticChunked:
115   case OMPScheduleType::NomergeUnorderedStatic:
116   case OMPScheduleType::NomergeUnorderedDynamicChunked:
117   case OMPScheduleType::NomergeUnorderedGuidedChunked:
118   case OMPScheduleType::NomergeUnorderedRuntime:
119   case OMPScheduleType::NomergeUnorderedAuto:
120   case OMPScheduleType::NomergeUnorderedTrapezoidal:
121   case OMPScheduleType::NomergeUnorderedGreedy:
122   case OMPScheduleType::NomergeUnorderedBalanced:
123   case OMPScheduleType::NomergeUnorderedGuidedIterativeChunked:
124   case OMPScheduleType::NomergeUnorderedGuidedAnalyticalChunked:
125   case OMPScheduleType::NomergeUnorderedSteal:
126   case OMPScheduleType::NomergeOrderedStaticChunked:
127   case OMPScheduleType::NomergeOrderedStatic:
128   case OMPScheduleType::NomergeOrderedDynamicChunked:
129   case OMPScheduleType::NomergeOrderedGuidedChunked:
130   case OMPScheduleType::NomergeOrderedRuntime:
131   case OMPScheduleType::NomergeOrderedAuto:
132   case OMPScheduleType::NomergeOrderedTrapezoidal:
133     break;
134   default:
135     return false;
136   }
137 
138   // Must not set both monotonicity modifiers at the same time.
139   OMPScheduleType MonotonicityFlags =
140       SchedType & OMPScheduleType::MonotonicityMask;
141   if (MonotonicityFlags == OMPScheduleType::MonotonicityMask)
142     return false;
143 
144   return true;
145 }
146 #endif
147 
148 static const omp::GV &getGridValue(const Triple &T, Function *Kernel) {
149   if (T.isAMDGPU()) {
150     StringRef Features =
151         Kernel->getFnAttribute("target-features").getValueAsString();
152     if (Features.count("+wavefrontsize64"))
153       return omp::getAMDGPUGridValues<64>();
154     return omp::getAMDGPUGridValues<32>();
155   }
156   if (T.isNVPTX())
157     return omp::NVPTXGridValues;
158   llvm_unreachable("No grid value available for this architecture!");
159 }
160 
161 /// Determine which scheduling algorithm to use, determined from schedule clause
162 /// arguments.
163 static OMPScheduleType
164 getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
165                           bool HasSimdModifier) {
166   // Currently, the default schedule it static.
167   switch (ClauseKind) {
168   case OMP_SCHEDULE_Default:
169   case OMP_SCHEDULE_Static:
170     return HasChunks ? OMPScheduleType::BaseStaticChunked
171                      : OMPScheduleType::BaseStatic;
172   case OMP_SCHEDULE_Dynamic:
173     return OMPScheduleType::BaseDynamicChunked;
174   case OMP_SCHEDULE_Guided:
175     return HasSimdModifier ? OMPScheduleType::BaseGuidedSimd
176                            : OMPScheduleType::BaseGuidedChunked;
177   case OMP_SCHEDULE_Auto:
178     return llvm::omp::OMPScheduleType::BaseAuto;
179   case OMP_SCHEDULE_Runtime:
180     return HasSimdModifier ? OMPScheduleType::BaseRuntimeSimd
181                            : OMPScheduleType::BaseRuntime;
182   }
183   llvm_unreachable("unhandled schedule clause argument");
184 }
185 
186 /// Adds ordering modifier flags to schedule type.
187 static OMPScheduleType
188 getOpenMPOrderingScheduleType(OMPScheduleType BaseScheduleType,
189                               bool HasOrderedClause) {
190   assert((BaseScheduleType & OMPScheduleType::ModifierMask) ==
191              OMPScheduleType::None &&
192          "Must not have ordering nor monotonicity flags already set");
193 
194   OMPScheduleType OrderingModifier = HasOrderedClause
195                                          ? OMPScheduleType::ModifierOrdered
196                                          : OMPScheduleType::ModifierUnordered;
197   OMPScheduleType OrderingScheduleType = BaseScheduleType | OrderingModifier;
198 
199   // Unsupported combinations
200   if (OrderingScheduleType ==
201       (OMPScheduleType::BaseGuidedSimd | OMPScheduleType::ModifierOrdered))
202     return OMPScheduleType::OrderedGuidedChunked;
203   else if (OrderingScheduleType == (OMPScheduleType::BaseRuntimeSimd |
204                                     OMPScheduleType::ModifierOrdered))
205     return OMPScheduleType::OrderedRuntime;
206 
207   return OrderingScheduleType;
208 }
209 
210 /// Adds monotonicity modifier flags to schedule type.
211 static OMPScheduleType
212 getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,
213                                   bool HasSimdModifier, bool HasMonotonic,
214                                   bool HasNonmonotonic, bool HasOrderedClause) {
215   assert((ScheduleType & OMPScheduleType::MonotonicityMask) ==
216              OMPScheduleType::None &&
217          "Must not have monotonicity flags already set");
218   assert((!HasMonotonic || !HasNonmonotonic) &&
219          "Monotonic and Nonmonotonic are contradicting each other");
220 
221   if (HasMonotonic) {
222     return ScheduleType | OMPScheduleType::ModifierMonotonic;
223   } else if (HasNonmonotonic) {
224     return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
225   } else {
226     // OpenMP 5.1, 2.11.4 Worksharing-Loop Construct, Description.
227     // If the static schedule kind is specified or if the ordered clause is
228     // specified, and if the nonmonotonic modifier is not specified, the
229     // effect is as if the monotonic modifier is specified. Otherwise, unless
230     // the monotonic modifier is specified, the effect is as if the
231     // nonmonotonic modifier is specified.
232     OMPScheduleType BaseScheduleType =
233         ScheduleType & ~OMPScheduleType::ModifierMask;
234     if ((BaseScheduleType == OMPScheduleType::BaseStatic) ||
235         (BaseScheduleType == OMPScheduleType::BaseStaticChunked) ||
236         HasOrderedClause) {
237       // The monotonic is used by default in openmp runtime library, so no need
238       // to set it.
239       return ScheduleType;
240     } else {
241       return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
242     }
243   }
244 }
245 
246 /// Determine the schedule type using schedule and ordering clause arguments.
247 static OMPScheduleType
248 computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
249                           bool HasSimdModifier, bool HasMonotonicModifier,
250                           bool HasNonmonotonicModifier, bool HasOrderedClause) {
251   OMPScheduleType BaseSchedule =
252       getOpenMPBaseScheduleType(ClauseKind, HasChunks, HasSimdModifier);
253   OMPScheduleType OrderedSchedule =
254       getOpenMPOrderingScheduleType(BaseSchedule, HasOrderedClause);
255   OMPScheduleType Result = getOpenMPMonotonicityScheduleType(
256       OrderedSchedule, HasSimdModifier, HasMonotonicModifier,
257       HasNonmonotonicModifier, HasOrderedClause);
258 
259   assert(isValidWorkshareLoopScheduleType(Result));
260   return Result;
261 }
262 
263 /// Make \p Source branch to \p Target.
264 ///
265 /// Handles two situations:
266 /// * \p Source already has an unconditional branch.
267 /// * \p Source is a degenerate block (no terminator because the BB is
268 ///             the current head of the IR construction).
269 static void redirectTo(BasicBlock *Source, BasicBlock *Target, DebugLoc DL) {
270   if (Instruction *Term = Source->getTerminator()) {
271     auto *Br = cast<BranchInst>(Term);
272     assert(!Br->isConditional() &&
273            "BB's terminator must be an unconditional branch (or degenerate)");
274     BasicBlock *Succ = Br->getSuccessor(0);
275     Succ->removePredecessor(Source, /*KeepOneInputPHIs=*/true);
276     Br->setSuccessor(0, Target);
277     return;
278   }
279 
280   auto *NewBr = BranchInst::Create(Target, Source);
281   NewBr->setDebugLoc(DL);
282 }
283 
284 void llvm::spliceBB(IRBuilderBase::InsertPoint IP, BasicBlock *New,
285                     bool CreateBranch) {
286   assert(New->getFirstInsertionPt() == New->begin() &&
287          "Target BB must not have PHI nodes");
288 
289   // Move instructions to new block.
290   BasicBlock *Old = IP.getBlock();
291   New->splice(New->begin(), Old, IP.getPoint(), Old->end());
292 
293   if (CreateBranch)
294     BranchInst::Create(New, Old);
295 }
296 
297 void llvm::spliceBB(IRBuilder<> &Builder, BasicBlock *New, bool CreateBranch) {
298   DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
299   BasicBlock *Old = Builder.GetInsertBlock();
300 
301   spliceBB(Builder.saveIP(), New, CreateBranch);
302   if (CreateBranch)
303     Builder.SetInsertPoint(Old->getTerminator());
304   else
305     Builder.SetInsertPoint(Old);
306 
307   // SetInsertPoint also updates the Builder's debug location, but we want to
308   // keep the one the Builder was configured to use.
309   Builder.SetCurrentDebugLocation(DebugLoc);
310 }
311 
312 BasicBlock *llvm::splitBB(IRBuilderBase::InsertPoint IP, bool CreateBranch,
313                           llvm::Twine Name) {
314   BasicBlock *Old = IP.getBlock();
315   BasicBlock *New = BasicBlock::Create(
316       Old->getContext(), Name.isTriviallyEmpty() ? Old->getName() : Name,
317       Old->getParent(), Old->getNextNode());
318   spliceBB(IP, New, CreateBranch);
319   New->replaceSuccessorsPhiUsesWith(Old, New);
320   return New;
321 }
322 
323 BasicBlock *llvm::splitBB(IRBuilderBase &Builder, bool CreateBranch,
324                           llvm::Twine Name) {
325   DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
326   BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, Name);
327   if (CreateBranch)
328     Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
329   else
330     Builder.SetInsertPoint(Builder.GetInsertBlock());
331   // SetInsertPoint also updates the Builder's debug location, but we want to
332   // keep the one the Builder was configured to use.
333   Builder.SetCurrentDebugLocation(DebugLoc);
334   return New;
335 }
336 
337 BasicBlock *llvm::splitBB(IRBuilder<> &Builder, bool CreateBranch,
338                           llvm::Twine Name) {
339   DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
340   BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, Name);
341   if (CreateBranch)
342     Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
343   else
344     Builder.SetInsertPoint(Builder.GetInsertBlock());
345   // SetInsertPoint also updates the Builder's debug location, but we want to
346   // keep the one the Builder was configured to use.
347   Builder.SetCurrentDebugLocation(DebugLoc);
348   return New;
349 }
350 
351 BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch,
352                                     llvm::Twine Suffix) {
353   BasicBlock *Old = Builder.GetInsertBlock();
354   return splitBB(Builder, CreateBranch, Old->getName() + Suffix);
355 }
356 
357 // This function creates a fake integer value and a fake use for the integer
358 // value. It returns the fake value created. This is useful in modeling the
359 // extra arguments to the outlined functions.
360 Value *createFakeIntVal(IRBuilder<> &Builder,
361                         OpenMPIRBuilder::InsertPointTy OuterAllocaIP,
362                         std::stack<Instruction *> &ToBeDeleted,
363                         OpenMPIRBuilder::InsertPointTy InnerAllocaIP,
364                         const Twine &Name = "", bool AsPtr = true) {
365   Builder.restoreIP(OuterAllocaIP);
366   Instruction *FakeVal;
367   AllocaInst *FakeValAddr =
368       Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, Name + ".addr");
369   ToBeDeleted.push(FakeValAddr);
370 
371   if (AsPtr) {
372     FakeVal = FakeValAddr;
373   } else {
374     FakeVal =
375         Builder.CreateLoad(Builder.getInt32Ty(), FakeValAddr, Name + ".val");
376     ToBeDeleted.push(FakeVal);
377   }
378 
379   // Generate a fake use of this value
380   Builder.restoreIP(InnerAllocaIP);
381   Instruction *UseFakeVal;
382   if (AsPtr) {
383     UseFakeVal =
384         Builder.CreateLoad(Builder.getInt32Ty(), FakeVal, Name + ".use");
385   } else {
386     UseFakeVal =
387         cast<BinaryOperator>(Builder.CreateAdd(FakeVal, Builder.getInt32(10)));
388   }
389   ToBeDeleted.push(UseFakeVal);
390   return FakeVal;
391 }
392 
393 //===----------------------------------------------------------------------===//
394 // OpenMPIRBuilderConfig
395 //===----------------------------------------------------------------------===//
396 
397 namespace {
398 LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
399 /// Values for bit flags for marking which requires clauses have been used.
400 enum OpenMPOffloadingRequiresDirFlags {
401   /// flag undefined.
402   OMP_REQ_UNDEFINED = 0x000,
403   /// no requires directive present.
404   OMP_REQ_NONE = 0x001,
405   /// reverse_offload clause.
406   OMP_REQ_REVERSE_OFFLOAD = 0x002,
407   /// unified_address clause.
408   OMP_REQ_UNIFIED_ADDRESS = 0x004,
409   /// unified_shared_memory clause.
410   OMP_REQ_UNIFIED_SHARED_MEMORY = 0x008,
411   /// dynamic_allocators clause.
412   OMP_REQ_DYNAMIC_ALLOCATORS = 0x010,
413   LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/OMP_REQ_DYNAMIC_ALLOCATORS)
414 };
415 
416 } // anonymous namespace
417 
418 OpenMPIRBuilderConfig::OpenMPIRBuilderConfig()
419     : RequiresFlags(OMP_REQ_UNDEFINED) {}
420 
421 OpenMPIRBuilderConfig::OpenMPIRBuilderConfig(
422     bool IsTargetDevice, bool IsGPU, bool OpenMPOffloadMandatory,
423     bool HasRequiresReverseOffload, bool HasRequiresUnifiedAddress,
424     bool HasRequiresUnifiedSharedMemory, bool HasRequiresDynamicAllocators)
425     : IsTargetDevice(IsTargetDevice), IsGPU(IsGPU),
426       OpenMPOffloadMandatory(OpenMPOffloadMandatory),
427       RequiresFlags(OMP_REQ_UNDEFINED) {
428   if (HasRequiresReverseOffload)
429     RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
430   if (HasRequiresUnifiedAddress)
431     RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
432   if (HasRequiresUnifiedSharedMemory)
433     RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
434   if (HasRequiresDynamicAllocators)
435     RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
436 }
437 
438 bool OpenMPIRBuilderConfig::hasRequiresReverseOffload() const {
439   return RequiresFlags & OMP_REQ_REVERSE_OFFLOAD;
440 }
441 
442 bool OpenMPIRBuilderConfig::hasRequiresUnifiedAddress() const {
443   return RequiresFlags & OMP_REQ_UNIFIED_ADDRESS;
444 }
445 
446 bool OpenMPIRBuilderConfig::hasRequiresUnifiedSharedMemory() const {
447   return RequiresFlags & OMP_REQ_UNIFIED_SHARED_MEMORY;
448 }
449 
450 bool OpenMPIRBuilderConfig::hasRequiresDynamicAllocators() const {
451   return RequiresFlags & OMP_REQ_DYNAMIC_ALLOCATORS;
452 }
453 
454 int64_t OpenMPIRBuilderConfig::getRequiresFlags() const {
455   return hasRequiresFlags() ? RequiresFlags
456                             : static_cast<int64_t>(OMP_REQ_NONE);
457 }
458 
459 void OpenMPIRBuilderConfig::setHasRequiresReverseOffload(bool Value) {
460   if (Value)
461     RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
462   else
463     RequiresFlags &= ~OMP_REQ_REVERSE_OFFLOAD;
464 }
465 
466 void OpenMPIRBuilderConfig::setHasRequiresUnifiedAddress(bool Value) {
467   if (Value)
468     RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
469   else
470     RequiresFlags &= ~OMP_REQ_UNIFIED_ADDRESS;
471 }
472 
473 void OpenMPIRBuilderConfig::setHasRequiresUnifiedSharedMemory(bool Value) {
474   if (Value)
475     RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
476   else
477     RequiresFlags &= ~OMP_REQ_UNIFIED_SHARED_MEMORY;
478 }
479 
480 void OpenMPIRBuilderConfig::setHasRequiresDynamicAllocators(bool Value) {
481   if (Value)
482     RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
483   else
484     RequiresFlags &= ~OMP_REQ_DYNAMIC_ALLOCATORS;
485 }
486 
487 //===----------------------------------------------------------------------===//
488 // OpenMPIRBuilder
489 //===----------------------------------------------------------------------===//
490 
491 void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
492                                           IRBuilderBase &Builder,
493                                           SmallVector<Value *> &ArgsVector) {
494   Value *Version = Builder.getInt32(OMP_KERNEL_ARG_VERSION);
495   Value *PointerNum = Builder.getInt32(KernelArgs.NumTargetItems);
496   auto Int32Ty = Type::getInt32Ty(Builder.getContext());
497   Value *ZeroArray = Constant::getNullValue(ArrayType::get(Int32Ty, 3));
498   Value *Flags = Builder.getInt64(KernelArgs.HasNoWait);
499 
500   Value *NumTeams3D =
501       Builder.CreateInsertValue(ZeroArray, KernelArgs.NumTeams, {0});
502   Value *NumThreads3D =
503       Builder.CreateInsertValue(ZeroArray, KernelArgs.NumThreads, {0});
504 
505   ArgsVector = {Version,
506                 PointerNum,
507                 KernelArgs.RTArgs.BasePointersArray,
508                 KernelArgs.RTArgs.PointersArray,
509                 KernelArgs.RTArgs.SizesArray,
510                 KernelArgs.RTArgs.MapTypesArray,
511                 KernelArgs.RTArgs.MapNamesArray,
512                 KernelArgs.RTArgs.MappersArray,
513                 KernelArgs.NumIterations,
514                 Flags,
515                 NumTeams3D,
516                 NumThreads3D,
517                 KernelArgs.DynCGGroupMem};
518 }
519 
520 void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
521   LLVMContext &Ctx = Fn.getContext();
522 
523   // Get the function's current attributes.
524   auto Attrs = Fn.getAttributes();
525   auto FnAttrs = Attrs.getFnAttrs();
526   auto RetAttrs = Attrs.getRetAttrs();
527   SmallVector<AttributeSet, 4> ArgAttrs;
528   for (size_t ArgNo = 0; ArgNo < Fn.arg_size(); ++ArgNo)
529     ArgAttrs.emplace_back(Attrs.getParamAttrs(ArgNo));
530 
531   // Add AS to FnAS while taking special care with integer extensions.
532   auto addAttrSet = [&](AttributeSet &FnAS, const AttributeSet &AS,
533                         bool Param = true) -> void {
534     bool HasSignExt = AS.hasAttribute(Attribute::SExt);
535     bool HasZeroExt = AS.hasAttribute(Attribute::ZExt);
536     if (HasSignExt || HasZeroExt) {
537       assert(AS.getNumAttributes() == 1 &&
538              "Currently not handling extension attr combined with others.");
539       if (Param) {
540         if (auto AK = TargetLibraryInfo::getExtAttrForI32Param(T, HasSignExt))
541           FnAS = FnAS.addAttribute(Ctx, AK);
542       } else if (auto AK =
543                      TargetLibraryInfo::getExtAttrForI32Return(T, HasSignExt))
544         FnAS = FnAS.addAttribute(Ctx, AK);
545     } else {
546       FnAS = FnAS.addAttributes(Ctx, AS);
547     }
548   };
549 
550 #define OMP_ATTRS_SET(VarName, AttrSet) AttributeSet VarName = AttrSet;
551 #include "llvm/Frontend/OpenMP/OMPKinds.def"
552 
553   // Add attributes to the function declaration.
554   switch (FnID) {
555 #define OMP_RTL_ATTRS(Enum, FnAttrSet, RetAttrSet, ArgAttrSets)                \
556   case Enum:                                                                   \
557     FnAttrs = FnAttrs.addAttributes(Ctx, FnAttrSet);                           \
558     addAttrSet(RetAttrs, RetAttrSet, /*Param*/ false);                         \
559     for (size_t ArgNo = 0; ArgNo < ArgAttrSets.size(); ++ArgNo)                \
560       addAttrSet(ArgAttrs[ArgNo], ArgAttrSets[ArgNo]);                         \
561     Fn.setAttributes(AttributeList::get(Ctx, FnAttrs, RetAttrs, ArgAttrs));    \
562     break;
563 #include "llvm/Frontend/OpenMP/OMPKinds.def"
564   default:
565     // Attributes are optional.
566     break;
567   }
568 }
569 
570 FunctionCallee
571 OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
572   FunctionType *FnTy = nullptr;
573   Function *Fn = nullptr;
574 
575   // Try to find the declation in the module first.
576   switch (FnID) {
577 #define OMP_RTL(Enum, Str, IsVarArg, ReturnType, ...)                          \
578   case Enum:                                                                   \
579     FnTy = FunctionType::get(ReturnType, ArrayRef<Type *>{__VA_ARGS__},        \
580                              IsVarArg);                                        \
581     Fn = M.getFunction(Str);                                                   \
582     break;
583 #include "llvm/Frontend/OpenMP/OMPKinds.def"
584   }
585 
586   if (!Fn) {
587     // Create a new declaration if we need one.
588     switch (FnID) {
589 #define OMP_RTL(Enum, Str, ...)                                                \
590   case Enum:                                                                   \
591     Fn = Function::Create(FnTy, GlobalValue::ExternalLinkage, Str, M);         \
592     break;
593 #include "llvm/Frontend/OpenMP/OMPKinds.def"
594     }
595 
596     // Add information if the runtime function takes a callback function
597     if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) {
598       if (!Fn->hasMetadata(LLVMContext::MD_callback)) {
599         LLVMContext &Ctx = Fn->getContext();
600         MDBuilder MDB(Ctx);
601         // Annotate the callback behavior of the runtime function:
602         //  - The callback callee is argument number 2 (microtask).
603         //  - The first two arguments of the callback callee are unknown (-1).
604         //  - All variadic arguments to the runtime function are passed to the
605         //    callback callee.
606         Fn->addMetadata(
607             LLVMContext::MD_callback,
608             *MDNode::get(Ctx, {MDB.createCallbackEncoding(
609                                   2, {-1, -1}, /* VarArgsArePassed */ true)}));
610       }
611     }
612 
613     LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()
614                       << " with type " << *Fn->getFunctionType() << "\n");
615     addAttributes(FnID, *Fn);
616 
617   } else {
618     LLVM_DEBUG(dbgs() << "Found OpenMP runtime function " << Fn->getName()
619                       << " with type " << *Fn->getFunctionType() << "\n");
620   }
621 
622   assert(Fn && "Failed to create OpenMP runtime function");
623 
624   return {FnTy, Fn};
625 }
626 
627 Function *OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID) {
628   FunctionCallee RTLFn = getOrCreateRuntimeFunction(M, FnID);
629   auto *Fn = dyn_cast<llvm::Function>(RTLFn.getCallee());
630   assert(Fn && "Failed to create OpenMP runtime function pointer");
631   return Fn;
632 }
633 
634 void OpenMPIRBuilder::initialize() { initializeTypes(M); }
635 
636 void OpenMPIRBuilder::finalize(Function *Fn) {
637   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
638   SmallVector<BasicBlock *, 32> Blocks;
639   SmallVector<OutlineInfo, 16> DeferredOutlines;
640   for (OutlineInfo &OI : OutlineInfos) {
641     // Skip functions that have not finalized yet; may happen with nested
642     // function generation.
643     if (Fn && OI.getFunction() != Fn) {
644       DeferredOutlines.push_back(OI);
645       continue;
646     }
647 
648     ParallelRegionBlockSet.clear();
649     Blocks.clear();
650     OI.collectBlocks(ParallelRegionBlockSet, Blocks);
651 
652     Function *OuterFn = OI.getFunction();
653     CodeExtractorAnalysisCache CEAC(*OuterFn);
654     // If we generate code for the target device, we need to allocate
655     // struct for aggregate params in the device default alloca address space.
656     // OpenMP runtime requires that the params of the extracted functions are
657     // passed as zero address space pointers. This flag ensures that
658     // CodeExtractor generates correct code for extracted functions
659     // which are used by OpenMP runtime.
660     bool ArgsInZeroAddressSpace = Config.isTargetDevice();
661     CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
662                             /* AggregateArgs */ true,
663                             /* BlockFrequencyInfo */ nullptr,
664                             /* BranchProbabilityInfo */ nullptr,
665                             /* AssumptionCache */ nullptr,
666                             /* AllowVarArgs */ true,
667                             /* AllowAlloca */ true,
668                             /* AllocaBlock*/ OI.OuterAllocaBB,
669                             /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
670 
671     LLVM_DEBUG(dbgs() << "Before     outlining: " << *OuterFn << "\n");
672     LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
673                       << " Exit: " << OI.ExitBB->getName() << "\n");
674     assert(Extractor.isEligible() &&
675            "Expected OpenMP outlining to be possible!");
676 
677     for (auto *V : OI.ExcludeArgsFromAggregate)
678       Extractor.excludeArgFromAggregate(V);
679 
680     Function *OutlinedFn = Extractor.extractCodeRegion(CEAC);
681 
682     LLVM_DEBUG(dbgs() << "After      outlining: " << *OuterFn << "\n");
683     LLVM_DEBUG(dbgs() << "   Outlined function: " << *OutlinedFn << "\n");
684     assert(OutlinedFn->getReturnType()->isVoidTy() &&
685            "OpenMP outlined functions should not return a value!");
686 
687     // For compability with the clang CG we move the outlined function after the
688     // one with the parallel region.
689     OutlinedFn->removeFromParent();
690     M.getFunctionList().insertAfter(OuterFn->getIterator(), OutlinedFn);
691 
692     // Remove the artificial entry introduced by the extractor right away, we
693     // made our own entry block after all.
694     {
695       BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
696       assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
697       assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
698       // Move instructions from the to-be-deleted ArtificialEntry to the entry
699       // basic block of the parallel region. CodeExtractor generates
700       // instructions to unwrap the aggregate argument and may sink
701       // allocas/bitcasts for values that are solely used in the outlined region
702       // and do not escape.
703       assert(!ArtificialEntry.empty() &&
704              "Expected instructions to add in the outlined region entry");
705       for (BasicBlock::reverse_iterator It = ArtificialEntry.rbegin(),
706                                         End = ArtificialEntry.rend();
707            It != End;) {
708         Instruction &I = *It;
709         It++;
710 
711         if (I.isTerminator())
712           continue;
713 
714         I.moveBeforePreserving(*OI.EntryBB, OI.EntryBB->getFirstInsertionPt());
715       }
716 
717       OI.EntryBB->moveBefore(&ArtificialEntry);
718       ArtificialEntry.eraseFromParent();
719     }
720     assert(&OutlinedFn->getEntryBlock() == OI.EntryBB);
721     assert(OutlinedFn && OutlinedFn->getNumUses() == 1);
722 
723     // Run a user callback, e.g. to add attributes.
724     if (OI.PostOutlineCB)
725       OI.PostOutlineCB(*OutlinedFn);
726   }
727 
728   // Remove work items that have been completed.
729   OutlineInfos = std::move(DeferredOutlines);
730 
731   EmitMetadataErrorReportFunctionTy &&ErrorReportFn =
732       [](EmitMetadataErrorKind Kind,
733          const TargetRegionEntryInfo &EntryInfo) -> void {
734     errs() << "Error of kind: " << Kind
735            << " when emitting offload entries and metadata during "
736               "OMPIRBuilder finalization \n";
737   };
738 
739   if (!OffloadInfoManager.empty())
740     createOffloadEntriesAndInfoMetadata(ErrorReportFn);
741 }
742 
743 OpenMPIRBuilder::~OpenMPIRBuilder() {
744   assert(OutlineInfos.empty() && "There must be no outstanding outlinings");
745 }
746 
747 GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
748   IntegerType *I32Ty = Type::getInt32Ty(M.getContext());
749   auto *GV =
750       new GlobalVariable(M, I32Ty,
751                          /* isConstant = */ true, GlobalValue::WeakODRLinkage,
752                          ConstantInt::get(I32Ty, Value), Name);
753   GV->setVisibility(GlobalValue::HiddenVisibility);
754 
755   return GV;
756 }
757 
758 Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
759                                             uint32_t SrcLocStrSize,
760                                             IdentFlag LocFlags,
761                                             unsigned Reserve2Flags) {
762   // Enable "C-mode".
763   LocFlags |= OMP_IDENT_FLAG_KMPC;
764 
765   Constant *&Ident =
766       IdentMap[{SrcLocStr, uint64_t(LocFlags) << 31 | Reserve2Flags}];
767   if (!Ident) {
768     Constant *I32Null = ConstantInt::getNullValue(Int32);
769     Constant *IdentData[] = {I32Null,
770                              ConstantInt::get(Int32, uint32_t(LocFlags)),
771                              ConstantInt::get(Int32, Reserve2Flags),
772                              ConstantInt::get(Int32, SrcLocStrSize), SrcLocStr};
773     Constant *Initializer =
774         ConstantStruct::get(OpenMPIRBuilder::Ident, IdentData);
775 
776     // Look for existing encoding of the location + flags, not needed but
777     // minimizes the difference to the existing solution while we transition.
778     for (GlobalVariable &GV : M.globals())
779       if (GV.getValueType() == OpenMPIRBuilder::Ident && GV.hasInitializer())
780         if (GV.getInitializer() == Initializer)
781           Ident = &GV;
782 
783     if (!Ident) {
784       auto *GV = new GlobalVariable(
785           M, OpenMPIRBuilder::Ident,
786           /* isConstant = */ true, GlobalValue::PrivateLinkage, Initializer, "",
787           nullptr, GlobalValue::NotThreadLocal,
788           M.getDataLayout().getDefaultGlobalsAddressSpace());
789       GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
790       GV->setAlignment(Align(8));
791       Ident = GV;
792     }
793   }
794 
795   return ConstantExpr::getPointerBitCastOrAddrSpaceCast(Ident, IdentPtr);
796 }
797 
798 Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef LocStr,
799                                                 uint32_t &SrcLocStrSize) {
800   SrcLocStrSize = LocStr.size();
801   Constant *&SrcLocStr = SrcLocStrMap[LocStr];
802   if (!SrcLocStr) {
803     Constant *Initializer =
804         ConstantDataArray::getString(M.getContext(), LocStr);
805 
806     // Look for existing encoding of the location, not needed but minimizes the
807     // difference to the existing solution while we transition.
808     for (GlobalVariable &GV : M.globals())
809       if (GV.isConstant() && GV.hasInitializer() &&
810           GV.getInitializer() == Initializer)
811         return SrcLocStr = ConstantExpr::getPointerCast(&GV, Int8Ptr);
812 
813     SrcLocStr = Builder.CreateGlobalStringPtr(LocStr, /* Name */ "",
814                                               /* AddressSpace */ 0, &M);
815   }
816   return SrcLocStr;
817 }
818 
819 Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef FunctionName,
820                                                 StringRef FileName,
821                                                 unsigned Line, unsigned Column,
822                                                 uint32_t &SrcLocStrSize) {
823   SmallString<128> Buffer;
824   Buffer.push_back(';');
825   Buffer.append(FileName);
826   Buffer.push_back(';');
827   Buffer.append(FunctionName);
828   Buffer.push_back(';');
829   Buffer.append(std::to_string(Line));
830   Buffer.push_back(';');
831   Buffer.append(std::to_string(Column));
832   Buffer.push_back(';');
833   Buffer.push_back(';');
834   return getOrCreateSrcLocStr(Buffer.str(), SrcLocStrSize);
835 }
836 
837 Constant *
838 OpenMPIRBuilder::getOrCreateDefaultSrcLocStr(uint32_t &SrcLocStrSize) {
839   StringRef UnknownLoc = ";unknown;unknown;0;0;;";
840   return getOrCreateSrcLocStr(UnknownLoc, SrcLocStrSize);
841 }
842 
843 Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(DebugLoc DL,
844                                                 uint32_t &SrcLocStrSize,
845                                                 Function *F) {
846   DILocation *DIL = DL.get();
847   if (!DIL)
848     return getOrCreateDefaultSrcLocStr(SrcLocStrSize);
849   StringRef FileName = M.getName();
850   if (DIFile *DIF = DIL->getFile())
851     if (std::optional<StringRef> Source = DIF->getSource())
852       FileName = *Source;
853   StringRef Function = DIL->getScope()->getSubprogram()->getName();
854   if (Function.empty() && F)
855     Function = F->getName();
856   return getOrCreateSrcLocStr(Function, FileName, DIL->getLine(),
857                               DIL->getColumn(), SrcLocStrSize);
858 }
859 
860 Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(const LocationDescription &Loc,
861                                                 uint32_t &SrcLocStrSize) {
862   return getOrCreateSrcLocStr(Loc.DL, SrcLocStrSize,
863                               Loc.IP.getBlock()->getParent());
864 }
865 
866 Value *OpenMPIRBuilder::getOrCreateThreadID(Value *Ident) {
867   return Builder.CreateCall(
868       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num), Ident,
869       "omp_global_thread_num");
870 }
871 
872 OpenMPIRBuilder::InsertPointTy
873 OpenMPIRBuilder::createBarrier(const LocationDescription &Loc, Directive DK,
874                                bool ForceSimpleCall, bool CheckCancelFlag) {
875   if (!updateToLocation(Loc))
876     return Loc.IP;
877   return emitBarrierImpl(Loc, DK, ForceSimpleCall, CheckCancelFlag);
878 }
879 
880 OpenMPIRBuilder::InsertPointTy
881 OpenMPIRBuilder::emitBarrierImpl(const LocationDescription &Loc, Directive Kind,
882                                  bool ForceSimpleCall, bool CheckCancelFlag) {
883   // Build call __kmpc_cancel_barrier(loc, thread_id) or
884   //            __kmpc_barrier(loc, thread_id);
885 
886   IdentFlag BarrierLocFlags;
887   switch (Kind) {
888   case OMPD_for:
889     BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_FOR;
890     break;
891   case OMPD_sections:
892     BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SECTIONS;
893     break;
894   case OMPD_single:
895     BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SINGLE;
896     break;
897   case OMPD_barrier:
898     BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_EXPL;
899     break;
900   default:
901     BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL;
902     break;
903   }
904 
905   uint32_t SrcLocStrSize;
906   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
907   Value *Args[] = {
908       getOrCreateIdent(SrcLocStr, SrcLocStrSize, BarrierLocFlags),
909       getOrCreateThreadID(getOrCreateIdent(SrcLocStr, SrcLocStrSize))};
910 
911   // If we are in a cancellable parallel region, barriers are cancellation
912   // points.
913   // TODO: Check why we would force simple calls or to ignore the cancel flag.
914   bool UseCancelBarrier =
915       !ForceSimpleCall && isLastFinalizationInfoCancellable(OMPD_parallel);
916 
917   Value *Result =
918       Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
919                              UseCancelBarrier ? OMPRTL___kmpc_cancel_barrier
920                                               : OMPRTL___kmpc_barrier),
921                          Args);
922 
923   if (UseCancelBarrier && CheckCancelFlag)
924     emitCancelationCheckImpl(Result, OMPD_parallel);
925 
926   return Builder.saveIP();
927 }
928 
929 OpenMPIRBuilder::InsertPointTy
930 OpenMPIRBuilder::createCancel(const LocationDescription &Loc,
931                               Value *IfCondition,
932                               omp::Directive CanceledDirective) {
933   if (!updateToLocation(Loc))
934     return Loc.IP;
935 
936   // LLVM utilities like blocks with terminators.
937   auto *UI = Builder.CreateUnreachable();
938 
939   Instruction *ThenTI = UI, *ElseTI = nullptr;
940   if (IfCondition)
941     SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI);
942   Builder.SetInsertPoint(ThenTI);
943 
944   Value *CancelKind = nullptr;
945   switch (CanceledDirective) {
946 #define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value)                       \
947   case DirectiveEnum:                                                          \
948     CancelKind = Builder.getInt32(Value);                                      \
949     break;
950 #include "llvm/Frontend/OpenMP/OMPKinds.def"
951   default:
952     llvm_unreachable("Unknown cancel kind!");
953   }
954 
955   uint32_t SrcLocStrSize;
956   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
957   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
958   Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
959   Value *Result = Builder.CreateCall(
960       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_cancel), Args);
961   auto ExitCB = [this, CanceledDirective, Loc](InsertPointTy IP) {
962     if (CanceledDirective == OMPD_parallel) {
963       IRBuilder<>::InsertPointGuard IPG(Builder);
964       Builder.restoreIP(IP);
965       createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
966                     omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
967                     /* CheckCancelFlag */ false);
968     }
969   };
970 
971   // The actual cancel logic is shared with others, e.g., cancel_barriers.
972   emitCancelationCheckImpl(Result, CanceledDirective, ExitCB);
973 
974   // Update the insertion point and remove the terminator we introduced.
975   Builder.SetInsertPoint(UI->getParent());
976   UI->eraseFromParent();
977 
978   return Builder.saveIP();
979 }
980 
981 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
982     const LocationDescription &Loc, InsertPointTy AllocaIP, Value *&Return,
983     Value *Ident, Value *DeviceID, Value *NumTeams, Value *NumThreads,
984     Value *HostPtr, ArrayRef<Value *> KernelArgs) {
985   if (!updateToLocation(Loc))
986     return Loc.IP;
987 
988   Builder.restoreIP(AllocaIP);
989   auto *KernelArgsPtr =
990       Builder.CreateAlloca(OpenMPIRBuilder::KernelArgs, nullptr, "kernel_args");
991   Builder.restoreIP(Loc.IP);
992 
993   for (unsigned I = 0, Size = KernelArgs.size(); I != Size; ++I) {
994     llvm::Value *Arg =
995         Builder.CreateStructGEP(OpenMPIRBuilder::KernelArgs, KernelArgsPtr, I);
996     Builder.CreateAlignedStore(
997         KernelArgs[I], Arg,
998         M.getDataLayout().getPrefTypeAlign(KernelArgs[I]->getType()));
999   }
1000 
1001   SmallVector<Value *> OffloadingArgs{Ident,      DeviceID, NumTeams,
1002                                       NumThreads, HostPtr,  KernelArgsPtr};
1003 
1004   Return = Builder.CreateCall(
1005       getOrCreateRuntimeFunction(M, OMPRTL___tgt_target_kernel),
1006       OffloadingArgs);
1007 
1008   return Builder.saveIP();
1009 }
1010 
1011 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(
1012     const LocationDescription &Loc, Function *OutlinedFn, Value *OutlinedFnID,
1013     EmitFallbackCallbackTy emitTargetCallFallbackCB, TargetKernelArgs &Args,
1014     Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP) {
1015 
1016   if (!updateToLocation(Loc))
1017     return Loc.IP;
1018 
1019   Builder.restoreIP(Loc.IP);
1020   // On top of the arrays that were filled up, the target offloading call
1021   // takes as arguments the device id as well as the host pointer. The host
1022   // pointer is used by the runtime library to identify the current target
1023   // region, so it only has to be unique and not necessarily point to
1024   // anything. It could be the pointer to the outlined function that
1025   // implements the target region, but we aren't using that so that the
1026   // compiler doesn't need to keep that, and could therefore inline the host
1027   // function if proven worthwhile during optimization.
1028 
1029   // From this point on, we need to have an ID of the target region defined.
1030   assert(OutlinedFnID && "Invalid outlined function ID!");
1031   (void)OutlinedFnID;
1032 
1033   // Return value of the runtime offloading call.
1034   Value *Return = nullptr;
1035 
1036   // Arguments for the target kernel.
1037   SmallVector<Value *> ArgsVector;
1038   getKernelArgsVector(Args, Builder, ArgsVector);
1039 
1040   // The target region is an outlined function launched by the runtime
1041   // via calls to __tgt_target_kernel().
1042   //
1043   // Note that on the host and CPU targets, the runtime implementation of
1044   // these calls simply call the outlined function without forking threads.
1045   // The outlined functions themselves have runtime calls to
1046   // __kmpc_fork_teams() and __kmpc_fork() for this purpose, codegen'd by
1047   // the compiler in emitTeamsCall() and emitParallelCall().
1048   //
1049   // In contrast, on the NVPTX target, the implementation of
1050   // __tgt_target_teams() launches a GPU kernel with the requested number
1051   // of teams and threads so no additional calls to the runtime are required.
1052   // Check the error code and execute the host version if required.
1053   Builder.restoreIP(emitTargetKernel(Builder, AllocaIP, Return, RTLoc, DeviceID,
1054                                      Args.NumTeams, Args.NumThreads,
1055                                      OutlinedFnID, ArgsVector));
1056 
1057   BasicBlock *OffloadFailedBlock =
1058       BasicBlock::Create(Builder.getContext(), "omp_offload.failed");
1059   BasicBlock *OffloadContBlock =
1060       BasicBlock::Create(Builder.getContext(), "omp_offload.cont");
1061   Value *Failed = Builder.CreateIsNotNull(Return);
1062   Builder.CreateCondBr(Failed, OffloadFailedBlock, OffloadContBlock);
1063 
1064   auto CurFn = Builder.GetInsertBlock()->getParent();
1065   emitBlock(OffloadFailedBlock, CurFn);
1066   Builder.restoreIP(emitTargetCallFallbackCB(Builder.saveIP()));
1067   emitBranch(OffloadContBlock);
1068   emitBlock(OffloadContBlock, CurFn, /*IsFinished=*/true);
1069   return Builder.saveIP();
1070 }
1071 
/// Branch on the result of a cancellation runtime call.
///
/// - If \p CancelFlag is zero, control continues in a new "<bb>.cont" block,
///   and the builder is left at its start.
/// - Otherwise control goes to a new "<bb>.cncl" block. That block first runs
///   \p ExitCB (if provided) and then the FiniCB of the innermost entry on
///   FinalizationStack.
///
/// The innermost finalization entry must be cancellable for
/// \p CanceledDirective (asserted).
void OpenMPIRBuilder::emitCancelationCheckImpl(Value *CancelFlag,
                                               omp::Directive CanceledDirective,
                                               FinalizeCallbackTy ExitCB) {
  assert(isLastFinalizationInfoCancellable(CanceledDirective) &&
         "Unexpected cancellation!");

  // For a cancel barrier we create two new blocks.
  BasicBlock *BB = Builder.GetInsertBlock();
  BasicBlock *NonCancellationBlock;
  if (Builder.GetInsertPoint() == BB->end()) {
    // TODO: This branch will not be needed once we moved to the
    // OpenMPIRBuilder codegen completely.
    NonCancellationBlock = BasicBlock::Create(
        BB->getContext(), BB->getName() + ".cont", BB->getParent());
  } else {
    // Split at the insertion point. SplitBlock leaves BB ending in an
    // unconditional branch to the new block; that branch is removed so the
    // conditional branch below can terminate BB instead.
    NonCancellationBlock = SplitBlock(BB, &*Builder.GetInsertPoint());
    BB->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(BB);
  }
  BasicBlock *CancellationBlock = BasicBlock::Create(
      BB->getContext(), BB->getName() + ".cncl", BB->getParent());

  // Jump to them based on the return value.
  Value *Cmp = Builder.CreateIsNull(CancelFlag);
  Builder.CreateCondBr(Cmp, NonCancellationBlock, CancellationBlock,
                       /* TODO weight */ nullptr, nullptr);

  // From the cancellation block we finalize all variables and go to the
  // post finalization block that is known to the FiniCB callback.
  Builder.SetInsertPoint(CancellationBlock);
  if (ExitCB)
    ExitCB(Builder.saveIP());
  auto &FI = FinalizationStack.back();
  FI.FiniCB(Builder.saveIP());

  // The continuation block is where code generation continues.
  Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
}
1110 
1111 // Callback used to create OpenMP runtime calls to support
1112 // omp parallel clause for the device.
1113 // We need to use this callback to replace call to the OutlinedFn in OuterFn
1114 // by the call to the OpenMP DeviceRTL runtime function (kmpc_parallel_51)
1115 static void targetParallelCallback(
1116     OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn, Function *OuterFn,
1117     BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
1118     Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
1119     Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
1120   // Add some known attributes.
1121   IRBuilder<> &Builder = OMPIRBuilder->Builder;
1122   OutlinedFn.addParamAttr(0, Attribute::NoAlias);
1123   OutlinedFn.addParamAttr(1, Attribute::NoAlias);
1124   OutlinedFn.addParamAttr(0, Attribute::NoUndef);
1125   OutlinedFn.addParamAttr(1, Attribute::NoUndef);
1126   OutlinedFn.addFnAttr(Attribute::NoUnwind);
1127 
1128   assert(OutlinedFn.arg_size() >= 2 &&
1129          "Expected at least tid and bounded tid as arguments");
1130   unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1131 
1132   CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
1133   assert(CI && "Expected call instruction to outlined function");
1134   CI->getParent()->setName("omp_parallel");
1135 
1136   Builder.SetInsertPoint(CI);
1137   Type *PtrTy = OMPIRBuilder->VoidPtr;
1138   Value *NullPtrValue = Constant::getNullValue(PtrTy);
1139 
1140   // Add alloca for kernel args
1141   OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP();
1142   Builder.SetInsertPoint(OuterAllocaBB, OuterAllocaBB->getFirstInsertionPt());
1143   AllocaInst *ArgsAlloca =
1144       Builder.CreateAlloca(ArrayType::get(PtrTy, NumCapturedVars));
1145   Value *Args = ArgsAlloca;
1146   // Add address space cast if array for storing arguments is not allocated
1147   // in address space 0
1148   if (ArgsAlloca->getAddressSpace())
1149     Args = Builder.CreatePointerCast(ArgsAlloca, PtrTy);
1150   Builder.restoreIP(CurrentIP);
1151 
1152   // Store captured vars which are used by kmpc_parallel_51
1153   for (unsigned Idx = 0; Idx < NumCapturedVars; Idx++) {
1154     Value *V = *(CI->arg_begin() + 2 + Idx);
1155     Value *StoreAddress = Builder.CreateConstInBoundsGEP2_64(
1156         ArrayType::get(PtrTy, NumCapturedVars), Args, 0, Idx);
1157     Builder.CreateStore(V, StoreAddress);
1158   }
1159 
1160   Value *Cond =
1161       IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32)
1162                   : Builder.getInt32(1);
1163 
1164   // Build kmpc_parallel_51 call
1165   Value *Parallel51CallArgs[] = {
1166       /* identifier*/ Ident,
1167       /* global thread num*/ ThreadID,
1168       /* if expression */ Cond,
1169       /* number of threads */ NumThreads ? NumThreads : Builder.getInt32(-1),
1170       /* Proc bind */ Builder.getInt32(-1),
1171       /* outlined function */
1172       Builder.CreateBitCast(&OutlinedFn, OMPIRBuilder->ParallelTaskPtr),
1173       /* wrapper function */ NullPtrValue,
1174       /* arguments of the outlined funciton*/ Args,
1175       /* number of arguments */ Builder.getInt64(NumCapturedVars)};
1176 
1177   FunctionCallee RTLFn =
1178       OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_parallel_51);
1179 
1180   Builder.CreateCall(RTLFn, Parallel51CallArgs);
1181 
1182   LLVM_DEBUG(dbgs() << "With kmpc_parallel_51 placed: "
1183                     << *Builder.GetInsertBlock()->getParent() << "\n");
1184 
1185   // Initialize the local TID stack location with the argument value.
1186   Builder.SetInsertPoint(PrivTID);
1187   Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
1188   Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
1189                       PrivTIDAddr);
1190 
1191   // Remove redundant call to the outlined function.
1192   CI->eraseFromParent();
1193 
1194   for (Instruction *I : ToBeDeleted) {
1195     I->eraseFromParent();
1196   }
1197 }
1198 
// Callback used to create OpenMP runtime calls to support
// omp parallel clause for the host.
// We need to use this callback to replace the call to the OutlinedFn in
// OuterFn by a call to the OpenMP host runtime function
// (__kmpc_fork_call or __kmpc_fork_call_if).
//
// Parameters, as used below:
//  - OutlinedFn: the extracted region. Its first two parameters are the tid
//    and bound-tid arguments; the remaining ones are the captured variables.
//  - Ident: forwarded as the first runtime-call argument.
//  - IfCondition: if non-null, __kmpc_fork_call_if is used and the condition
//    (sign-extended/truncated to i32) is passed before the payload.
//  - PrivTID, PrivTIDAddr: a store initializing PrivTIDAddr from the outlined
//    function's first argument is inserted before PrivTID.
//  - ToBeDeleted: modeling-only instructions that are erased at the end.
//  - OuterFn: currently unused.
static void
hostParallelCallback(OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn,
                     Function *OuterFn, Value *Ident, Value *IfCondition,
                     Instruction *PrivTID, AllocaInst *PrivTIDAddr,
                     const SmallVector<Instruction *, 4> &ToBeDeleted) {
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  FunctionCallee RTLFn;
  if (IfCondition) {
    RTLFn =
        OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call_if);
  } else {
    RTLFn =
        OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call);
  }
  // Attach !callback metadata once per runtime declaration, so that IPO can
  // see the microtask as a callback callee.
  if (auto *F = dyn_cast<Function>(RTLFn.getCallee())) {
    if (!F->hasMetadata(LLVMContext::MD_callback)) {
      LLVMContext &Ctx = F->getContext();
      MDBuilder MDB(Ctx);
      // Annotate the callback behavior of the __kmpc_fork_call:
      //  - The callback callee is argument number 2 (microtask).
      //  - The first two arguments of the callback callee are unknown (-1).
      //  - All variadic arguments to the __kmpc_fork_call are passed to the
      //    callback callee.
      // NOTE(review): the same encoding is applied to __kmpc_fork_call_if;
      // confirm that it is meaningful for that declaration's signature.
      F->addMetadata(LLVMContext::MD_callback,
                     *MDNode::get(Ctx, {MDB.createCallbackEncoding(
                                           2, {-1, -1},
                                           /* VarArgsArePassed */ true)}));
    }
  }
  // Add some known attributes.
  OutlinedFn.addParamAttr(0, Attribute::NoAlias);
  OutlinedFn.addParamAttr(1, Attribute::NoAlias);
  OutlinedFn.addFnAttr(Attribute::NoUnwind);

  assert(OutlinedFn.arg_size() >= 2 &&
         "Expected at least tid and bounded tid as arguments");
  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;

  // The single user of the outlined function is the call the extractor left
  // in OuterFn; the runtime call is emitted right before it.
  CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
  CI->getParent()->setName("omp_parallel");
  Builder.SetInsertPoint(CI);

  // Argument layouts built below:
  //   __kmpc_fork_call(Ident, n, microtask, var1, .., varn)
  //   __kmpc_fork_call_if(Ident, n, microtask, cond, <captured vars, or null>)
  Value *ForkCallArgs[] = {
      Ident, Builder.getInt32(NumCapturedVars),
      Builder.CreateBitCast(&OutlinedFn, OMPIRBuilder->ParallelTaskPtr)};

  SmallVector<Value *, 16> RealArgs;
  RealArgs.append(std::begin(ForkCallArgs), std::end(ForkCallArgs));
  if (IfCondition) {
    Value *Cond = Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32);
    RealArgs.push_back(Cond);
  }
  // Forward the captured variables, skipping the tid / bound-tid arguments.
  RealArgs.append(CI->arg_begin() + /* tid & bound tid */ 2, CI->arg_end());

  // __kmpc_fork_call_if always expects a void ptr as the last argument.
  // If there are no arguments, pass a null pointer.
  auto PtrTy = OMPIRBuilder->VoidPtr;
  if (IfCondition && NumCapturedVars == 0) {
    Value *NullPtrValue = Constant::getNullValue(PtrTy);
    RealArgs.push_back(NullPtrValue);
  }
  // NOTE(review): only the last argument is adjusted, and only by bitcast;
  // this assumes the captured payload is a single pointer-typed value.
  if (IfCondition && RealArgs.back()->getType() != PtrTy)
    RealArgs.back() = Builder.CreateBitCast(RealArgs.back(), PtrTy);

  Builder.CreateCall(RTLFn, RealArgs);

  LLVM_DEBUG(dbgs() << "With fork_call placed: "
                    << *Builder.GetInsertBlock()->getParent() << "\n");

  // Initialize the local TID stack location with the argument value.
  Builder.SetInsertPoint(PrivTID);
  Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
  Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
                      PrivTIDAddr);

  // Remove redundant call to the outlined function.
  CI->eraseFromParent();

  // Drop the instructions that only existed to model the region.
  for (Instruction *I : ToBeDeleted) {
    I->eraseFromParent();
  }
}
1286 
1287 IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
1288     const LocationDescription &Loc, InsertPointTy OuterAllocaIP,
1289     BodyGenCallbackTy BodyGenCB, PrivatizeCallbackTy PrivCB,
1290     FinalizeCallbackTy FiniCB, Value *IfCondition, Value *NumThreads,
1291     omp::ProcBindKind ProcBind, bool IsCancellable) {
1292   assert(!isConflictIP(Loc.IP, OuterAllocaIP) && "IPs must not be ambiguous");
1293 
1294   if (!updateToLocation(Loc))
1295     return Loc.IP;
1296 
1297   uint32_t SrcLocStrSize;
1298   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1299   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1300   Value *ThreadID = getOrCreateThreadID(Ident);
1301   // If we generate code for the target device, we need to allocate
1302   // struct for aggregate params in the device default alloca address space.
1303   // OpenMP runtime requires that the params of the extracted functions are
1304   // passed as zero address space pointers. This flag ensures that extracted
1305   // function arguments are declared in zero address space
1306   bool ArgsInZeroAddressSpace = Config.isTargetDevice();
1307 
1308   // Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
1309   // only if we compile for host side.
1310   if (NumThreads && !Config.isTargetDevice()) {
1311     Value *Args[] = {
1312         Ident, ThreadID,
1313         Builder.CreateIntCast(NumThreads, Int32, /*isSigned*/ false)};
1314     Builder.CreateCall(
1315         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_threads), Args);
1316   }
1317 
1318   if (ProcBind != OMP_PROC_BIND_default) {
1319     // Build call __kmpc_push_proc_bind(&Ident, global_tid, proc_bind)
1320     Value *Args[] = {
1321         Ident, ThreadID,
1322         ConstantInt::get(Int32, unsigned(ProcBind), /*isSigned=*/true)};
1323     Builder.CreateCall(
1324         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_proc_bind), Args);
1325   }
1326 
1327   BasicBlock *InsertBB = Builder.GetInsertBlock();
1328   Function *OuterFn = InsertBB->getParent();
1329 
1330   // Save the outer alloca block because the insertion iterator may get
1331   // invalidated and we still need this later.
1332   BasicBlock *OuterAllocaBlock = OuterAllocaIP.getBlock();
1333 
1334   // Vector to remember instructions we used only during the modeling but which
1335   // we want to delete at the end.
1336   SmallVector<Instruction *, 4> ToBeDeleted;
1337 
1338   // Change the location to the outer alloca insertion point to create and
1339   // initialize the allocas we pass into the parallel region.
1340   Builder.restoreIP(OuterAllocaIP);
1341   AllocaInst *TIDAddrAlloca = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
1342   AllocaInst *ZeroAddrAlloca =
1343       Builder.CreateAlloca(Int32, nullptr, "zero.addr");
1344   Instruction *TIDAddr = TIDAddrAlloca;
1345   Instruction *ZeroAddr = ZeroAddrAlloca;
1346   if (ArgsInZeroAddressSpace && M.getDataLayout().getAllocaAddrSpace() != 0) {
1347     // Add additional casts to enforce pointers in zero address space
1348     TIDAddr = new AddrSpaceCastInst(
1349         TIDAddrAlloca, PointerType ::get(M.getContext(), 0), "tid.addr.ascast");
1350     TIDAddr->insertAfter(TIDAddrAlloca);
1351     ToBeDeleted.push_back(TIDAddr);
1352     ZeroAddr = new AddrSpaceCastInst(ZeroAddrAlloca,
1353                                      PointerType ::get(M.getContext(), 0),
1354                                      "zero.addr.ascast");
1355     ZeroAddr->insertAfter(ZeroAddrAlloca);
1356     ToBeDeleted.push_back(ZeroAddr);
1357   }
1358 
1359   // We only need TIDAddr and ZeroAddr for modeling purposes to get the
1360   // associated arguments in the outlined function, so we delete them later.
1361   ToBeDeleted.push_back(TIDAddrAlloca);
1362   ToBeDeleted.push_back(ZeroAddrAlloca);
1363 
1364   // Create an artificial insertion point that will also ensure the blocks we
1365   // are about to split are not degenerated.
1366   auto *UI = new UnreachableInst(Builder.getContext(), InsertBB);
1367 
1368   BasicBlock *EntryBB = UI->getParent();
1369   BasicBlock *PRegEntryBB = EntryBB->splitBasicBlock(UI, "omp.par.entry");
1370   BasicBlock *PRegBodyBB = PRegEntryBB->splitBasicBlock(UI, "omp.par.region");
1371   BasicBlock *PRegPreFiniBB =
1372       PRegBodyBB->splitBasicBlock(UI, "omp.par.pre_finalize");
1373   BasicBlock *PRegExitBB = PRegPreFiniBB->splitBasicBlock(UI, "omp.par.exit");
1374 
1375   auto FiniCBWrapper = [&](InsertPointTy IP) {
1376     // Hide "open-ended" blocks from the given FiniCB by setting the right jump
1377     // target to the region exit block.
1378     if (IP.getBlock()->end() == IP.getPoint()) {
1379       IRBuilder<>::InsertPointGuard IPG(Builder);
1380       Builder.restoreIP(IP);
1381       Instruction *I = Builder.CreateBr(PRegExitBB);
1382       IP = InsertPointTy(I->getParent(), I->getIterator());
1383     }
1384     assert(IP.getBlock()->getTerminator()->getNumSuccessors() == 1 &&
1385            IP.getBlock()->getTerminator()->getSuccessor(0) == PRegExitBB &&
1386            "Unexpected insertion point for finalization call!");
1387     return FiniCB(IP);
1388   };
1389 
1390   FinalizationStack.push_back({FiniCBWrapper, OMPD_parallel, IsCancellable});
1391 
1392   // Generate the privatization allocas in the block that will become the entry
1393   // of the outlined function.
1394   Builder.SetInsertPoint(PRegEntryBB->getTerminator());
1395   InsertPointTy InnerAllocaIP = Builder.saveIP();
1396 
1397   AllocaInst *PrivTIDAddr =
1398       Builder.CreateAlloca(Int32, nullptr, "tid.addr.local");
1399   Instruction *PrivTID = Builder.CreateLoad(Int32, PrivTIDAddr, "tid");
1400 
1401   // Add some fake uses for OpenMP provided arguments.
1402   ToBeDeleted.push_back(Builder.CreateLoad(Int32, TIDAddr, "tid.addr.use"));
1403   Instruction *ZeroAddrUse =
1404       Builder.CreateLoad(Int32, ZeroAddr, "zero.addr.use");
1405   ToBeDeleted.push_back(ZeroAddrUse);
1406 
1407   // EntryBB
1408   //   |
1409   //   V
1410   // PRegionEntryBB         <- Privatization allocas are placed here.
1411   //   |
1412   //   V
1413   // PRegionBodyBB          <- BodeGen is invoked here.
1414   //   |
1415   //   V
1416   // PRegPreFiniBB          <- The block we will start finalization from.
1417   //   |
1418   //   V
1419   // PRegionExitBB          <- A common exit to simplify block collection.
1420   //
1421 
1422   LLVM_DEBUG(dbgs() << "Before body codegen: " << *OuterFn << "\n");
1423 
1424   // Let the caller create the body.
1425   assert(BodyGenCB && "Expected body generation callback!");
1426   InsertPointTy CodeGenIP(PRegBodyBB, PRegBodyBB->begin());
1427   BodyGenCB(InnerAllocaIP, CodeGenIP);
1428 
1429   LLVM_DEBUG(dbgs() << "After  body codegen: " << *OuterFn << "\n");
1430 
1431   OutlineInfo OI;
1432   if (Config.isTargetDevice()) {
1433     // Generate OpenMP target specific runtime call
1434     OI.PostOutlineCB = [=, ToBeDeletedVec =
1435                                std::move(ToBeDeleted)](Function &OutlinedFn) {
1436       targetParallelCallback(this, OutlinedFn, OuterFn, OuterAllocaBlock, Ident,
1437                              IfCondition, NumThreads, PrivTID, PrivTIDAddr,
1438                              ThreadID, ToBeDeletedVec);
1439     };
1440   } else {
1441     // Generate OpenMP host runtime call
1442     OI.PostOutlineCB = [=, ToBeDeletedVec =
1443                                std::move(ToBeDeleted)](Function &OutlinedFn) {
1444       hostParallelCallback(this, OutlinedFn, OuterFn, Ident, IfCondition,
1445                            PrivTID, PrivTIDAddr, ToBeDeletedVec);
1446     };
1447   }
1448 
1449   // Adjust the finalization stack, verify the adjustment, and call the
1450   // finalize function a last time to finalize values between the pre-fini
1451   // block and the exit block if we left the parallel "the normal way".
1452   auto FiniInfo = FinalizationStack.pop_back_val();
1453   (void)FiniInfo;
1454   assert(FiniInfo.DK == OMPD_parallel &&
1455          "Unexpected finalization stack state!");
1456 
1457   Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator();
1458 
1459   InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
1460   FiniCB(PreFiniIP);
1461 
1462   OI.OuterAllocaBB = OuterAllocaBlock;
1463   OI.EntryBB = PRegEntryBB;
1464   OI.ExitBB = PRegExitBB;
1465 
1466   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
1467   SmallVector<BasicBlock *, 32> Blocks;
1468   OI.collectBlocks(ParallelRegionBlockSet, Blocks);
1469 
1470   // Ensure a single exit node for the outlined region by creating one.
1471   // We might have multiple incoming edges to the exit now due to finalizations,
1472   // e.g., cancel calls that cause the control flow to leave the region.
1473   BasicBlock *PRegOutlinedExitBB = PRegExitBB;
1474   PRegExitBB = SplitBlock(PRegExitBB, &*PRegExitBB->getFirstInsertionPt());
1475   PRegOutlinedExitBB->setName("omp.par.outlined.exit");
1476   Blocks.push_back(PRegOutlinedExitBB);
1477 
1478   CodeExtractorAnalysisCache CEAC(*OuterFn);
1479   CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
1480                           /* AggregateArgs */ false,
1481                           /* BlockFrequencyInfo */ nullptr,
1482                           /* BranchProbabilityInfo */ nullptr,
1483                           /* AssumptionCache */ nullptr,
1484                           /* AllowVarArgs */ true,
1485                           /* AllowAlloca */ true,
1486                           /* AllocationBlock */ OuterAllocaBlock,
1487                           /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
1488 
1489   // Find inputs to, outputs from the code region.
1490   BasicBlock *CommonExit = nullptr;
1491   SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
1492   Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
1493   Extractor.findInputsOutputs(Inputs, Outputs, SinkingCands);
1494 
1495   LLVM_DEBUG(dbgs() << "Before privatization: " << *OuterFn << "\n");
1496 
1497   FunctionCallee TIDRTLFn =
1498       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num);
1499 
1500   auto PrivHelper = [&](Value &V) {
1501     if (&V == TIDAddr || &V == ZeroAddr) {
1502       OI.ExcludeArgsFromAggregate.push_back(&V);
1503       return;
1504     }
1505 
1506     SetVector<Use *> Uses;
1507     for (Use &U : V.uses())
1508       if (auto *UserI = dyn_cast<Instruction>(U.getUser()))
1509         if (ParallelRegionBlockSet.count(UserI->getParent()))
1510           Uses.insert(&U);
1511 
1512     // __kmpc_fork_call expects extra arguments as pointers. If the input
1513     // already has a pointer type, everything is fine. Otherwise, store the
1514     // value onto stack and load it back inside the to-be-outlined region. This
1515     // will ensure only the pointer will be passed to the function.
1516     // FIXME: if there are more than 15 trailing arguments, they must be
1517     // additionally packed in a struct.
1518     Value *Inner = &V;
1519     if (!V.getType()->isPointerTy()) {
1520       IRBuilder<>::InsertPointGuard Guard(Builder);
1521       LLVM_DEBUG(llvm::dbgs() << "Forwarding input as pointer: " << V << "\n");
1522 
1523       Builder.restoreIP(OuterAllocaIP);
1524       Value *Ptr =
1525           Builder.CreateAlloca(V.getType(), nullptr, V.getName() + ".reloaded");
1526 
1527       // Store to stack at end of the block that currently branches to the entry
1528       // block of the to-be-outlined region.
1529       Builder.SetInsertPoint(InsertBB,
1530                              InsertBB->getTerminator()->getIterator());
1531       Builder.CreateStore(&V, Ptr);
1532 
1533       // Load back next to allocations in the to-be-outlined region.
1534       Builder.restoreIP(InnerAllocaIP);
1535       Inner = Builder.CreateLoad(V.getType(), Ptr);
1536     }
1537 
1538     Value *ReplacementValue = nullptr;
1539     CallInst *CI = dyn_cast<CallInst>(&V);
1540     if (CI && CI->getCalledFunction() == TIDRTLFn.getCallee()) {
1541       ReplacementValue = PrivTID;
1542     } else {
1543       Builder.restoreIP(
1544           PrivCB(InnerAllocaIP, Builder.saveIP(), V, *Inner, ReplacementValue));
1545       assert(ReplacementValue &&
1546              "Expected copy/create callback to set replacement value!");
1547       if (ReplacementValue == &V)
1548         return;
1549     }
1550 
1551     for (Use *UPtr : Uses)
1552       UPtr->set(ReplacementValue);
1553   };
1554 
1555   // Reset the inner alloca insertion as it will be used for loading the values
1556   // wrapped into pointers before passing them into the to-be-outlined region.
1557   // Configure it to insert immediately after the fake use of zero address so
1558   // that they are available in the generated body and so that the
1559   // OpenMP-related values (thread ID and zero address pointers) remain leading
1560   // in the argument list.
1561   InnerAllocaIP = IRBuilder<>::InsertPoint(
1562       ZeroAddrUse->getParent(), ZeroAddrUse->getNextNode()->getIterator());
1563 
1564   // Reset the outer alloca insertion point to the entry of the relevant block
1565   // in case it was invalidated.
1566   OuterAllocaIP = IRBuilder<>::InsertPoint(
1567       OuterAllocaBlock, OuterAllocaBlock->getFirstInsertionPt());
1568 
1569   for (Value *Input : Inputs) {
1570     LLVM_DEBUG(dbgs() << "Captured input: " << *Input << "\n");
1571     PrivHelper(*Input);
1572   }
1573   LLVM_DEBUG({
1574     for (Value *Output : Outputs)
1575       LLVM_DEBUG(dbgs() << "Captured output: " << *Output << "\n");
1576   });
1577   assert(Outputs.empty() &&
1578          "OpenMP outlining should not produce live-out values!");
1579 
1580   LLVM_DEBUG(dbgs() << "After  privatization: " << *OuterFn << "\n");
1581   LLVM_DEBUG({
1582     for (auto *BB : Blocks)
1583       dbgs() << " PBR: " << BB->getName() << "\n";
1584   });
1585 
1586   // Register the outlined info.
1587   addOutlineInfo(std::move(OI));
1588 
1589   InsertPointTy AfterIP(UI->getParent(), UI->getParent()->end());
1590   UI->eraseFromParent();
1591 
1592   return AfterIP;
1593 }
1594 
1595 void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) {
1596   // Build call void __kmpc_flush(ident_t *loc)
1597   uint32_t SrcLocStrSize;
1598   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1599   Value *Args[] = {getOrCreateIdent(SrcLocStr, SrcLocStrSize)};
1600 
1601   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_flush), Args);
1602 }
1603 
1604 void OpenMPIRBuilder::createFlush(const LocationDescription &Loc) {
1605   if (!updateToLocation(Loc))
1606     return;
1607   emitFlush(Loc);
1608 }
1609 
1610 void OpenMPIRBuilder::emitTaskwaitImpl(const LocationDescription &Loc) {
1611   // Build call kmp_int32 __kmpc_omp_taskwait(ident_t *loc, kmp_int32
1612   // global_tid);
1613   uint32_t SrcLocStrSize;
1614   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1615   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1616   Value *Args[] = {Ident, getOrCreateThreadID(Ident)};
1617 
1618   // Ignore return result until untied tasks are supported.
1619   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_taskwait),
1620                      Args);
1621 }
1622 
1623 void OpenMPIRBuilder::createTaskwait(const LocationDescription &Loc) {
1624   if (!updateToLocation(Loc))
1625     return;
1626   emitTaskwaitImpl(Loc);
1627 }
1628 
1629 void OpenMPIRBuilder::emitTaskyieldImpl(const LocationDescription &Loc) {
1630   // Build call __kmpc_omp_taskyield(loc, thread_id, 0);
1631   uint32_t SrcLocStrSize;
1632   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1633   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1634   Constant *I32Null = ConstantInt::getNullValue(Int32);
1635   Value *Args[] = {Ident, getOrCreateThreadID(Ident), I32Null};
1636 
1637   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_taskyield),
1638                      Args);
1639 }
1640 
1641 void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
1642   if (!updateToLocation(Loc))
1643     return;
1644   emitTaskyieldImpl(Loc);
1645 }
1646 
1647 OpenMPIRBuilder::InsertPointTy
1648 OpenMPIRBuilder::createTask(const LocationDescription &Loc,
1649                             InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
1650                             bool Tied, Value *Final, Value *IfCondition,
1651                             SmallVector<DependData> Dependencies) {
1652 
1653   if (!updateToLocation(Loc))
1654     return InsertPointTy();
1655 
1656   uint32_t SrcLocStrSize;
1657   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1658   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1659   // The current basic block is split into four basic blocks. After outlining,
1660   // they will be mapped as follows:
1661   // ```
1662   // def current_fn() {
1663   //   current_basic_block:
1664   //     br label %task.exit
1665   //   task.exit:
1666   //     ; instructions after task
1667   // }
1668   // def outlined_fn() {
1669   //   task.alloca:
1670   //     br label %task.body
1671   //   task.body:
1672   //     ret void
1673   // }
1674   // ```
1675   BasicBlock *TaskExitBB = splitBB(Builder, /*CreateBranch=*/true, "task.exit");
1676   BasicBlock *TaskBodyBB = splitBB(Builder, /*CreateBranch=*/true, "task.body");
1677   BasicBlock *TaskAllocaBB =
1678       splitBB(Builder, /*CreateBranch=*/true, "task.alloca");
1679 
1680   InsertPointTy TaskAllocaIP =
1681       InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
1682   InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
1683   BodyGenCB(TaskAllocaIP, TaskBodyIP);
1684 
1685   OutlineInfo OI;
1686   OI.EntryBB = TaskAllocaBB;
1687   OI.OuterAllocaBB = AllocaIP.getBlock();
1688   OI.ExitBB = TaskExitBB;
1689 
1690   // Add the thread ID argument.
1691   std::stack<Instruction *> ToBeDeleted;
1692   OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
1693       Builder, AllocaIP, ToBeDeleted, TaskAllocaIP, "global.tid", false));
1694 
1695   OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
1696                       TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) mutable {
1697     // Replace the Stale CI by appropriate RTL function call.
1698     assert(OutlinedFn.getNumUses() == 1 &&
1699            "there must be a single user for the outlined function");
1700     CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
1701 
1702     // HasShareds is true if any variables are captured in the outlined region,
1703     // false otherwise.
1704     bool HasShareds = StaleCI->arg_size() > 1;
1705     Builder.SetInsertPoint(StaleCI);
1706 
1707     // Gather the arguments for emitting the runtime call for
1708     // @__kmpc_omp_task_alloc
1709     Function *TaskAllocFn =
1710         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);
1711 
1712     // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID)
1713     // call.
1714     Value *ThreadID = getOrCreateThreadID(Ident);
1715 
1716     // Argument - `flags`
1717     // Task is tied iff (Flags & 1) == 1.
1718     // Task is untied iff (Flags & 1) == 0.
1719     // Task is final iff (Flags & 2) == 2.
1720     // Task is not final iff (Flags & 2) == 0.
1721     // TODO: Handle the other flags.
1722     Value *Flags = Builder.getInt32(Tied);
1723     if (Final) {
1724       Value *FinalFlag =
1725           Builder.CreateSelect(Final, Builder.getInt32(2), Builder.getInt32(0));
1726       Flags = Builder.CreateOr(FinalFlag, Flags);
1727     }
1728 
1729     // Argument - `sizeof_kmp_task_t` (TaskSize)
1730     // Tasksize refers to the size in bytes of kmp_task_t data structure
1731     // including private vars accessed in task.
1732     // TODO: add kmp_task_t_with_privates (privates)
1733     Value *TaskSize = Builder.getInt64(
1734         divideCeil(M.getDataLayout().getTypeSizeInBits(Task), 8));
1735 
1736     // Argument - `sizeof_shareds` (SharedsSize)
1737     // SharedsSize refers to the shareds array size in the kmp_task_t data
1738     // structure.
1739     Value *SharedsSize = Builder.getInt64(0);
1740     if (HasShareds) {
1741       AllocaInst *ArgStructAlloca =
1742           dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
1743       assert(ArgStructAlloca &&
1744              "Unable to find the alloca instruction corresponding to arguments "
1745              "for extracted function");
1746       StructType *ArgStructType =
1747           dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
1748       assert(ArgStructType && "Unable to find struct type corresponding to "
1749                               "arguments for extracted function");
1750       SharedsSize =
1751           Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
1752     }
1753     // Emit the @__kmpc_omp_task_alloc runtime call
1754     // The runtime call returns a pointer to an area where the task captured
1755     // variables must be copied before the task is run (TaskData)
1756     CallInst *TaskData = Builder.CreateCall(
1757         TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
1758                       /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
1759                       /*task_func=*/&OutlinedFn});
1760 
1761     // Copy the arguments for outlined function
1762     if (HasShareds) {
1763       Value *Shareds = StaleCI->getArgOperand(1);
1764       Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
1765       Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData);
1766       Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
1767                            SharedsSize);
1768     }
1769 
1770     Value *DepArray = nullptr;
1771     if (Dependencies.size()) {
1772       InsertPointTy OldIP = Builder.saveIP();
1773       Builder.SetInsertPoint(
1774           &OldIP.getBlock()->getParent()->getEntryBlock().back());
1775 
1776       Type *DepArrayTy = ArrayType::get(DependInfo, Dependencies.size());
1777       DepArray = Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");
1778 
1779       unsigned P = 0;
1780       for (const DependData &Dep : Dependencies) {
1781         Value *Base =
1782             Builder.CreateConstInBoundsGEP2_64(DepArrayTy, DepArray, 0, P);
1783         // Store the pointer to the variable
1784         Value *Addr = Builder.CreateStructGEP(
1785             DependInfo, Base,
1786             static_cast<unsigned int>(RTLDependInfoFields::BaseAddr));
1787         Value *DepValPtr =
1788             Builder.CreatePtrToInt(Dep.DepVal, Builder.getInt64Ty());
1789         Builder.CreateStore(DepValPtr, Addr);
1790         // Store the size of the variable
1791         Value *Size = Builder.CreateStructGEP(
1792             DependInfo, Base,
1793             static_cast<unsigned int>(RTLDependInfoFields::Len));
1794         Builder.CreateStore(Builder.getInt64(M.getDataLayout().getTypeStoreSize(
1795                                 Dep.DepValueType)),
1796                             Size);
1797         // Store the dependency kind
1798         Value *Flags = Builder.CreateStructGEP(
1799             DependInfo, Base,
1800             static_cast<unsigned int>(RTLDependInfoFields::Flags));
1801         Builder.CreateStore(
1802             ConstantInt::get(Builder.getInt8Ty(),
1803                              static_cast<unsigned int>(Dep.DepKind)),
1804             Flags);
1805         ++P;
1806       }
1807 
1808       Builder.restoreIP(OldIP);
1809     }
1810 
1811     // In the presence of the `if` clause, the following IR is generated:
1812     //    ...
1813     //    %data = call @__kmpc_omp_task_alloc(...)
1814     //    br i1 %if_condition, label %then, label %else
1815     //  then:
1816     //    call @__kmpc_omp_task(...)
1817     //    br label %exit
1818     //  else:
1819     //    call @__kmpc_omp_task_begin_if0(...)
1820     //    call @outlined_fn(...)
1821     //    call @__kmpc_omp_task_complete_if0(...)
1822     //    br label %exit
1823     //  exit:
1824     //    ...
1825     if (IfCondition) {
1826       // `SplitBlockAndInsertIfThenElse` requires the block to have a
1827       // terminator.
1828       splitBB(Builder, /*CreateBranch=*/true, "if.end");
1829       Instruction *IfTerminator =
1830           Builder.GetInsertPoint()->getParent()->getTerminator();
1831       Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
1832       Builder.SetInsertPoint(IfTerminator);
1833       SplitBlockAndInsertIfThenElse(IfCondition, IfTerminator, &ThenTI,
1834                                     &ElseTI);
1835       Builder.SetInsertPoint(ElseTI);
1836       Function *TaskBeginFn =
1837           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0);
1838       Function *TaskCompleteFn =
1839           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
1840       Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
1841       CallInst *CI = nullptr;
1842       if (HasShareds)
1843         CI = Builder.CreateCall(&OutlinedFn, {ThreadID, TaskData});
1844       else
1845         CI = Builder.CreateCall(&OutlinedFn, {ThreadID});
1846       CI->setDebugLoc(StaleCI->getDebugLoc());
1847       Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
1848       Builder.SetInsertPoint(ThenTI);
1849     }
1850 
1851     if (Dependencies.size()) {
1852       Function *TaskFn =
1853           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
1854       Builder.CreateCall(
1855           TaskFn,
1856           {Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()),
1857            DepArray, ConstantInt::get(Builder.getInt32Ty(), 0),
1858            ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
1859 
1860     } else {
1861       // Emit the @__kmpc_omp_task runtime call to spawn the task
1862       Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
1863       Builder.CreateCall(TaskFn, {Ident, ThreadID, TaskData});
1864     }
1865 
1866     StaleCI->eraseFromParent();
1867 
1868     Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin());
1869     if (HasShareds) {
1870       LoadInst *Shareds = Builder.CreateLoad(VoidPtr, OutlinedFn.getArg(1));
1871       OutlinedFn.getArg(1)->replaceUsesWithIf(
1872           Shareds, [Shareds](Use &U) { return U.getUser() != Shareds; });
1873     }
1874 
1875     while (!ToBeDeleted.empty()) {
1876       ToBeDeleted.top()->eraseFromParent();
1877       ToBeDeleted.pop();
1878     }
1879   };
1880 
1881   addOutlineInfo(std::move(OI));
1882   Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin());
1883 
1884   return Builder.saveIP();
1885 }
1886 
1887 OpenMPIRBuilder::InsertPointTy
1888 OpenMPIRBuilder::createTaskgroup(const LocationDescription &Loc,
1889                                  InsertPointTy AllocaIP,
1890                                  BodyGenCallbackTy BodyGenCB) {
1891   if (!updateToLocation(Loc))
1892     return InsertPointTy();
1893 
1894   uint32_t SrcLocStrSize;
1895   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1896   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1897   Value *ThreadID = getOrCreateThreadID(Ident);
1898 
1899   // Emit the @__kmpc_taskgroup runtime call to start the taskgroup
1900   Function *TaskgroupFn =
1901       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_taskgroup);
1902   Builder.CreateCall(TaskgroupFn, {Ident, ThreadID});
1903 
1904   BasicBlock *TaskgroupExitBB = splitBB(Builder, true, "taskgroup.exit");
1905   BodyGenCB(AllocaIP, Builder.saveIP());
1906 
1907   Builder.SetInsertPoint(TaskgroupExitBB);
1908   // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup
1909   Function *EndTaskgroupFn =
1910       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_taskgroup);
1911   Builder.CreateCall(EndTaskgroupFn, {Ident, ThreadID});
1912 
1913   return Builder.saveIP();
1914 }
1915 
1916 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSections(
1917     const LocationDescription &Loc, InsertPointTy AllocaIP,
1918     ArrayRef<StorableBodyGenCallbackTy> SectionCBs, PrivatizeCallbackTy PrivCB,
1919     FinalizeCallbackTy FiniCB, bool IsCancellable, bool IsNowait) {
1920   assert(!isConflictIP(AllocaIP, Loc.IP) && "Dedicated IP allocas required");
1921 
1922   if (!updateToLocation(Loc))
1923     return Loc.IP;
1924 
1925   auto FiniCBWrapper = [&](InsertPointTy IP) {
1926     if (IP.getBlock()->end() != IP.getPoint())
1927       return FiniCB(IP);
1928     // This must be done otherwise any nested constructs using FinalizeOMPRegion
1929     // will fail because that function requires the Finalization Basic Block to
1930     // have a terminator, which is already removed by EmitOMPRegionBody.
1931     // IP is currently at cancelation block.
1932     // We need to backtrack to the condition block to fetch
1933     // the exit block and create a branch from cancelation
1934     // to exit block.
1935     IRBuilder<>::InsertPointGuard IPG(Builder);
1936     Builder.restoreIP(IP);
1937     auto *CaseBB = IP.getBlock()->getSinglePredecessor();
1938     auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
1939     auto *ExitBB = CondBB->getTerminator()->getSuccessor(1);
1940     Instruction *I = Builder.CreateBr(ExitBB);
1941     IP = InsertPointTy(I->getParent(), I->getIterator());
1942     return FiniCB(IP);
1943   };
1944 
1945   FinalizationStack.push_back({FiniCBWrapper, OMPD_sections, IsCancellable});
1946 
1947   // Each section is emitted as a switch case
1948   // Each finalization callback is handled from clang.EmitOMPSectionDirective()
1949   // -> OMP.createSection() which generates the IR for each section
1950   // Iterate through all sections and emit a switch construct:
1951   // switch (IV) {
1952   //   case 0:
1953   //     <SectionStmt[0]>;
1954   //     break;
1955   // ...
1956   //   case <NumSection> - 1:
1957   //     <SectionStmt[<NumSection> - 1]>;
1958   //     break;
1959   // }
1960   // ...
1961   // section_loop.after:
1962   // <FiniCB>;
1963   auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, Value *IndVar) {
1964     Builder.restoreIP(CodeGenIP);
1965     BasicBlock *Continue =
1966         splitBBWithSuffix(Builder, /*CreateBranch=*/false, ".sections.after");
1967     Function *CurFn = Continue->getParent();
1968     SwitchInst *SwitchStmt = Builder.CreateSwitch(IndVar, Continue);
1969 
1970     unsigned CaseNumber = 0;
1971     for (auto SectionCB : SectionCBs) {
1972       BasicBlock *CaseBB = BasicBlock::Create(
1973           M.getContext(), "omp_section_loop.body.case", CurFn, Continue);
1974       SwitchStmt->addCase(Builder.getInt32(CaseNumber), CaseBB);
1975       Builder.SetInsertPoint(CaseBB);
1976       BranchInst *CaseEndBr = Builder.CreateBr(Continue);
1977       SectionCB(InsertPointTy(),
1978                 {CaseEndBr->getParent(), CaseEndBr->getIterator()});
1979       CaseNumber++;
1980     }
1981     // remove the existing terminator from body BB since there can be no
1982     // terminators after switch/case
1983   };
1984   // Loop body ends here
1985   // LowerBound, UpperBound, and STride for createCanonicalLoop
1986   Type *I32Ty = Type::getInt32Ty(M.getContext());
1987   Value *LB = ConstantInt::get(I32Ty, 0);
1988   Value *UB = ConstantInt::get(I32Ty, SectionCBs.size());
1989   Value *ST = ConstantInt::get(I32Ty, 1);
1990   llvm::CanonicalLoopInfo *LoopInfo = createCanonicalLoop(
1991       Loc, LoopBodyGenCB, LB, UB, ST, true, false, AllocaIP, "section_loop");
1992   InsertPointTy AfterIP =
1993       applyStaticWorkshareLoop(Loc.DL, LoopInfo, AllocaIP, !IsNowait);
1994 
1995   // Apply the finalization callback in LoopAfterBB
1996   auto FiniInfo = FinalizationStack.pop_back_val();
1997   assert(FiniInfo.DK == OMPD_sections &&
1998          "Unexpected finalization stack state!");
1999   if (FinalizeCallbackTy &CB = FiniInfo.FiniCB) {
2000     Builder.restoreIP(AfterIP);
2001     BasicBlock *FiniBB =
2002         splitBBWithSuffix(Builder, /*CreateBranch=*/true, "sections.fini");
2003     CB(Builder.saveIP());
2004     AfterIP = {FiniBB, FiniBB->begin()};
2005   }
2006 
2007   return AfterIP;
2008 }
2009 
2010 OpenMPIRBuilder::InsertPointTy
2011 OpenMPIRBuilder::createSection(const LocationDescription &Loc,
2012                                BodyGenCallbackTy BodyGenCB,
2013                                FinalizeCallbackTy FiniCB) {
2014   if (!updateToLocation(Loc))
2015     return Loc.IP;
2016 
2017   auto FiniCBWrapper = [&](InsertPointTy IP) {
2018     if (IP.getBlock()->end() != IP.getPoint())
2019       return FiniCB(IP);
2020     // This must be done otherwise any nested constructs using FinalizeOMPRegion
2021     // will fail because that function requires the Finalization Basic Block to
2022     // have a terminator, which is already removed by EmitOMPRegionBody.
2023     // IP is currently at cancelation block.
2024     // We need to backtrack to the condition block to fetch
2025     // the exit block and create a branch from cancelation
2026     // to exit block.
2027     IRBuilder<>::InsertPointGuard IPG(Builder);
2028     Builder.restoreIP(IP);
2029     auto *CaseBB = Loc.IP.getBlock();
2030     auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
2031     auto *ExitBB = CondBB->getTerminator()->getSuccessor(1);
2032     Instruction *I = Builder.CreateBr(ExitBB);
2033     IP = InsertPointTy(I->getParent(), I->getIterator());
2034     return FiniCB(IP);
2035   };
2036 
2037   Directive OMPD = Directive::OMPD_sections;
2038   // Since we are using Finalization Callback here, HasFinalize
2039   // and IsCancellable have to be true
2040   return EmitOMPInlinedRegion(OMPD, nullptr, nullptr, BodyGenCB, FiniCBWrapper,
2041                               /*Conditional*/ false, /*hasFinalize*/ true,
2042                               /*IsCancellable*/ true);
2043 }
2044 
2045 /// Create a function with a unique name and a "void (i8*, i8*)" signature in
2046 /// the given module and return it.
2047 Function *getFreshReductionFunc(Module &M) {
2048   Type *VoidTy = Type::getVoidTy(M.getContext());
2049   Type *Int8PtrTy = PointerType::getUnqual(M.getContext());
2050   auto *FuncTy =
2051       FunctionType::get(VoidTy, {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ false);
2052   return Function::Create(FuncTy, GlobalVariable::InternalLinkage,
2053                           M.getDataLayout().getDefaultGlobalsAddressSpace(),
2054                           ".omp.reduction.func", &M);
2055 }
2056 
2057 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions(
2058     const LocationDescription &Loc, InsertPointTy AllocaIP,
2059     ArrayRef<ReductionInfo> ReductionInfos, bool IsNoWait) {
2060   for (const ReductionInfo &RI : ReductionInfos) {
2061     (void)RI;
2062     assert(RI.Variable && "expected non-null variable");
2063     assert(RI.PrivateVariable && "expected non-null private variable");
2064     assert(RI.ReductionGen && "expected non-null reduction generator callback");
2065     assert(RI.Variable->getType() == RI.PrivateVariable->getType() &&
2066            "expected variables and their private equivalents to have the same "
2067            "type");
2068     assert(RI.Variable->getType()->isPointerTy() &&
2069            "expected variables to be pointers");
2070   }
2071 
2072   if (!updateToLocation(Loc))
2073     return InsertPointTy();
2074 
2075   BasicBlock *InsertBlock = Loc.IP.getBlock();
2076   BasicBlock *ContinuationBlock =
2077       InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
2078   InsertBlock->getTerminator()->eraseFromParent();
2079 
2080   // Create and populate array of type-erased pointers to private reduction
2081   // values.
2082   unsigned NumReductions = ReductionInfos.size();
2083   Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), NumReductions);
2084   Builder.restoreIP(AllocaIP);
2085   Value *RedArray = Builder.CreateAlloca(RedArrayTy, nullptr, "red.array");
2086 
2087   Builder.SetInsertPoint(InsertBlock, InsertBlock->end());
2088 
2089   for (auto En : enumerate(ReductionInfos)) {
2090     unsigned Index = En.index();
2091     const ReductionInfo &RI = En.value();
2092     Value *RedArrayElemPtr = Builder.CreateConstInBoundsGEP2_64(
2093         RedArrayTy, RedArray, 0, Index, "red.array.elem." + Twine(Index));
2094     Builder.CreateStore(RI.PrivateVariable, RedArrayElemPtr);
2095   }
2096 
2097   // Emit a call to the runtime function that orchestrates the reduction.
2098   // Declare the reduction function in the process.
2099   Function *Func = Builder.GetInsertBlock()->getParent();
2100   Module *Module = Func->getParent();
2101   uint32_t SrcLocStrSize;
2102   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2103   bool CanGenerateAtomic =
2104       llvm::all_of(ReductionInfos, [](const ReductionInfo &RI) {
2105         return RI.AtomicReductionGen;
2106       });
2107   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize,
2108                                   CanGenerateAtomic
2109                                       ? IdentFlag::OMP_IDENT_FLAG_ATOMIC_REDUCE
2110                                       : IdentFlag(0));
2111   Value *ThreadId = getOrCreateThreadID(Ident);
2112   Constant *NumVariables = Builder.getInt32(NumReductions);
2113   const DataLayout &DL = Module->getDataLayout();
2114   unsigned RedArrayByteSize = DL.getTypeStoreSize(RedArrayTy);
2115   Constant *RedArraySize = Builder.getInt64(RedArrayByteSize);
2116   Function *ReductionFunc = getFreshReductionFunc(*Module);
2117   Value *Lock = getOMPCriticalRegionLock(".reduction");
2118   Function *ReduceFunc = getOrCreateRuntimeFunctionPtr(
2119       IsNoWait ? RuntimeFunction::OMPRTL___kmpc_reduce_nowait
2120                : RuntimeFunction::OMPRTL___kmpc_reduce);
2121   CallInst *ReduceCall =
2122       Builder.CreateCall(ReduceFunc,
2123                          {Ident, ThreadId, NumVariables, RedArraySize, RedArray,
2124                           ReductionFunc, Lock},
2125                          "reduce");
2126 
2127   // Create final reduction entry blocks for the atomic and non-atomic case.
2128   // Emit IR that dispatches control flow to one of the blocks based on the
2129   // reduction supporting the atomic mode.
2130   BasicBlock *NonAtomicRedBlock =
2131       BasicBlock::Create(Module->getContext(), "reduce.switch.nonatomic", Func);
2132   BasicBlock *AtomicRedBlock =
2133       BasicBlock::Create(Module->getContext(), "reduce.switch.atomic", Func);
2134   SwitchInst *Switch =
2135       Builder.CreateSwitch(ReduceCall, ContinuationBlock, /* NumCases */ 2);
2136   Switch->addCase(Builder.getInt32(1), NonAtomicRedBlock);
2137   Switch->addCase(Builder.getInt32(2), AtomicRedBlock);
2138 
2139   // Populate the non-atomic reduction using the elementwise reduction function.
2140   // This loads the elements from the global and private variables and reduces
2141   // them before storing back the result to the global variable.
2142   Builder.SetInsertPoint(NonAtomicRedBlock);
2143   for (auto En : enumerate(ReductionInfos)) {
2144     const ReductionInfo &RI = En.value();
2145     Type *ValueType = RI.ElementType;
2146     Value *RedValue = Builder.CreateLoad(ValueType, RI.Variable,
2147                                          "red.value." + Twine(En.index()));
2148     Value *PrivateRedValue =
2149         Builder.CreateLoad(ValueType, RI.PrivateVariable,
2150                            "red.private.value." + Twine(En.index()));
2151     Value *Reduced;
2152     Builder.restoreIP(
2153         RI.ReductionGen(Builder.saveIP(), RedValue, PrivateRedValue, Reduced));
2154     if (!Builder.GetInsertBlock())
2155       return InsertPointTy();
2156     Builder.CreateStore(Reduced, RI.Variable);
2157   }
2158   Function *EndReduceFunc = getOrCreateRuntimeFunctionPtr(
2159       IsNoWait ? RuntimeFunction::OMPRTL___kmpc_end_reduce_nowait
2160                : RuntimeFunction::OMPRTL___kmpc_end_reduce);
2161   Builder.CreateCall(EndReduceFunc, {Ident, ThreadId, Lock});
2162   Builder.CreateBr(ContinuationBlock);
2163 
2164   // Populate the atomic reduction using the atomic elementwise reduction
2165   // function. There are no loads/stores here because they will be happening
2166   // inside the atomic elementwise reduction.
2167   Builder.SetInsertPoint(AtomicRedBlock);
2168   if (CanGenerateAtomic) {
2169     for (const ReductionInfo &RI : ReductionInfos) {
2170       Builder.restoreIP(RI.AtomicReductionGen(Builder.saveIP(), RI.ElementType,
2171                                               RI.Variable, RI.PrivateVariable));
2172       if (!Builder.GetInsertBlock())
2173         return InsertPointTy();
2174     }
2175     Builder.CreateBr(ContinuationBlock);
2176   } else {
2177     Builder.CreateUnreachable();
2178   }
2179 
2180   // Populate the outlined reduction function using the elementwise reduction
2181   // function. Partial values are extracted from the type-erased array of
2182   // pointers to private variables.
2183   BasicBlock *ReductionFuncBlock =
2184       BasicBlock::Create(Module->getContext(), "", ReductionFunc);
2185   Builder.SetInsertPoint(ReductionFuncBlock);
2186   Value *LHSArrayPtr = ReductionFunc->getArg(0);
2187   Value *RHSArrayPtr = ReductionFunc->getArg(1);
2188 
2189   for (auto En : enumerate(ReductionInfos)) {
2190     const ReductionInfo &RI = En.value();
2191     Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
2192         RedArrayTy, LHSArrayPtr, 0, En.index());
2193     Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
2194     Value *LHSPtr = Builder.CreateBitCast(LHSI8Ptr, RI.Variable->getType());
2195     Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
2196     Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
2197         RedArrayTy, RHSArrayPtr, 0, En.index());
2198     Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
2199     Value *RHSPtr =
2200         Builder.CreateBitCast(RHSI8Ptr, RI.PrivateVariable->getType());
2201     Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
2202     Value *Reduced;
2203     Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced));
2204     if (!Builder.GetInsertBlock())
2205       return InsertPointTy();
2206     Builder.CreateStore(Reduced, LHSPtr);
2207   }
2208   Builder.CreateRetVoid();
2209 
2210   Builder.SetInsertPoint(ContinuationBlock);
2211   return Builder.saveIP();
2212 }
2213 
2214 OpenMPIRBuilder::InsertPointTy
2215 OpenMPIRBuilder::createMaster(const LocationDescription &Loc,
2216                               BodyGenCallbackTy BodyGenCB,
2217                               FinalizeCallbackTy FiniCB) {
2218 
2219   if (!updateToLocation(Loc))
2220     return Loc.IP;
2221 
2222   Directive OMPD = Directive::OMPD_master;
2223   uint32_t SrcLocStrSize;
2224   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2225   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2226   Value *ThreadId = getOrCreateThreadID(Ident);
2227   Value *Args[] = {Ident, ThreadId};
2228 
2229   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_master);
2230   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
2231 
2232   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_master);
2233   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
2234 
2235   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
2236                               /*Conditional*/ true, /*hasFinalize*/ true);
2237 }
2238 
2239 OpenMPIRBuilder::InsertPointTy
2240 OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
2241                               BodyGenCallbackTy BodyGenCB,
2242                               FinalizeCallbackTy FiniCB, Value *Filter) {
2243   if (!updateToLocation(Loc))
2244     return Loc.IP;
2245 
2246   Directive OMPD = Directive::OMPD_masked;
2247   uint32_t SrcLocStrSize;
2248   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2249   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2250   Value *ThreadId = getOrCreateThreadID(Ident);
2251   Value *Args[] = {Ident, ThreadId, Filter};
2252   Value *ArgsEnd[] = {Ident, ThreadId};
2253 
2254   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_masked);
2255   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
2256 
2257   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_masked);
2258   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, ArgsEnd);
2259 
2260   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
2261                               /*Conditional*/ true, /*hasFinalize*/ true);
2262 }
2263 
2264 CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
2265     DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
2266     BasicBlock *PostInsertBefore, const Twine &Name) {
2267   Module *M = F->getParent();
2268   LLVMContext &Ctx = M->getContext();
2269   Type *IndVarTy = TripCount->getType();
2270 
2271   // Create the basic block structure.
2272   BasicBlock *Preheader =
2273       BasicBlock::Create(Ctx, "omp_" + Name + ".preheader", F, PreInsertBefore);
2274   BasicBlock *Header =
2275       BasicBlock::Create(Ctx, "omp_" + Name + ".header", F, PreInsertBefore);
2276   BasicBlock *Cond =
2277       BasicBlock::Create(Ctx, "omp_" + Name + ".cond", F, PreInsertBefore);
2278   BasicBlock *Body =
2279       BasicBlock::Create(Ctx, "omp_" + Name + ".body", F, PreInsertBefore);
2280   BasicBlock *Latch =
2281       BasicBlock::Create(Ctx, "omp_" + Name + ".inc", F, PostInsertBefore);
2282   BasicBlock *Exit =
2283       BasicBlock::Create(Ctx, "omp_" + Name + ".exit", F, PostInsertBefore);
2284   BasicBlock *After =
2285       BasicBlock::Create(Ctx, "omp_" + Name + ".after", F, PostInsertBefore);
2286 
2287   // Use specified DebugLoc for new instructions.
2288   Builder.SetCurrentDebugLocation(DL);
2289 
2290   Builder.SetInsertPoint(Preheader);
2291   Builder.CreateBr(Header);
2292 
2293   Builder.SetInsertPoint(Header);
2294   PHINode *IndVarPHI = Builder.CreatePHI(IndVarTy, 2, "omp_" + Name + ".iv");
2295   IndVarPHI->addIncoming(ConstantInt::get(IndVarTy, 0), Preheader);
2296   Builder.CreateBr(Cond);
2297 
2298   Builder.SetInsertPoint(Cond);
2299   Value *Cmp =
2300       Builder.CreateICmpULT(IndVarPHI, TripCount, "omp_" + Name + ".cmp");
2301   Builder.CreateCondBr(Cmp, Body, Exit);
2302 
2303   Builder.SetInsertPoint(Body);
2304   Builder.CreateBr(Latch);
2305 
2306   Builder.SetInsertPoint(Latch);
2307   Value *Next = Builder.CreateAdd(IndVarPHI, ConstantInt::get(IndVarTy, 1),
2308                                   "omp_" + Name + ".next", /*HasNUW=*/true);
2309   Builder.CreateBr(Header);
2310   IndVarPHI->addIncoming(Next, Latch);
2311 
2312   Builder.SetInsertPoint(Exit);
2313   Builder.CreateBr(After);
2314 
2315   // Remember and return the canonical control flow.
2316   LoopInfos.emplace_front();
2317   CanonicalLoopInfo *CL = &LoopInfos.front();
2318 
2319   CL->Header = Header;
2320   CL->Cond = Cond;
2321   CL->Latch = Latch;
2322   CL->Exit = Exit;
2323 
2324 #ifndef NDEBUG
2325   CL->assertOK();
2326 #endif
2327   return CL;
2328 }
2329 
2330 CanonicalLoopInfo *
2331 OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
2332                                      LoopBodyGenCallbackTy BodyGenCB,
2333                                      Value *TripCount, const Twine &Name) {
2334   BasicBlock *BB = Loc.IP.getBlock();
2335   BasicBlock *NextBB = BB->getNextNode();
2336 
2337   CanonicalLoopInfo *CL = createLoopSkeleton(Loc.DL, TripCount, BB->getParent(),
2338                                              NextBB, NextBB, Name);
2339   BasicBlock *After = CL->getAfter();
2340 
2341   // If location is not set, don't connect the loop.
2342   if (updateToLocation(Loc)) {
2343     // Split the loop at the insertion point: Branch to the preheader and move
2344     // every following instruction to after the loop (the After BB). Also, the
2345     // new successor is the loop's after block.
2346     spliceBB(Builder, After, /*CreateBranch=*/false);
2347     Builder.CreateBr(CL->getPreheader());
2348   }
2349 
2350   // Emit the body content. We do it after connecting the loop to the CFG to
2351   // avoid that the callback encounters degenerate BBs.
2352   BodyGenCB(CL->getBodyIP(), CL->getIndVar());
2353 
2354 #ifndef NDEBUG
2355   CL->assertOK();
2356 #endif
2357   return CL;
2358 }
2359 
2360 CanonicalLoopInfo *OpenMPIRBuilder::createCanonicalLoop(
2361     const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
2362     Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
2363     InsertPointTy ComputeIP, const Twine &Name) {
2364 
2365   // Consider the following difficulties (assuming 8-bit signed integers):
2366   //  * Adding \p Step to the loop counter which passes \p Stop may overflow:
2367   //      DO I = 1, 100, 50
2368   ///  * A \p Step of INT_MIN cannot not be normalized to a positive direction:
2369   //      DO I = 100, 0, -128
2370 
2371   // Start, Stop and Step must be of the same integer type.
2372   auto *IndVarTy = cast<IntegerType>(Start->getType());
2373   assert(IndVarTy == Stop->getType() && "Stop type mismatch");
2374   assert(IndVarTy == Step->getType() && "Step type mismatch");
2375 
2376   LocationDescription ComputeLoc =
2377       ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
2378   updateToLocation(ComputeLoc);
2379 
2380   ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
2381   ConstantInt *One = ConstantInt::get(IndVarTy, 1);
2382 
2383   // Like Step, but always positive.
2384   Value *Incr = Step;
2385 
2386   // Distance between Start and Stop; always positive.
2387   Value *Span;
2388 
2389   // Condition whether there are no iterations are executed at all, e.g. because
2390   // UB < LB.
2391   Value *ZeroCmp;
2392 
2393   if (IsSigned) {
2394     // Ensure that increment is positive. If not, negate and invert LB and UB.
2395     Value *IsNeg = Builder.CreateICmpSLT(Step, Zero);
2396     Incr = Builder.CreateSelect(IsNeg, Builder.CreateNeg(Step), Step);
2397     Value *LB = Builder.CreateSelect(IsNeg, Stop, Start);
2398     Value *UB = Builder.CreateSelect(IsNeg, Start, Stop);
2399     Span = Builder.CreateSub(UB, LB, "", false, true);
2400     ZeroCmp = Builder.CreateICmp(
2401         InclusiveStop ? CmpInst::ICMP_SLT : CmpInst::ICMP_SLE, UB, LB);
2402   } else {
2403     Span = Builder.CreateSub(Stop, Start, "", true);
2404     ZeroCmp = Builder.CreateICmp(
2405         InclusiveStop ? CmpInst::ICMP_ULT : CmpInst::ICMP_ULE, Stop, Start);
2406   }
2407 
2408   Value *CountIfLooping;
2409   if (InclusiveStop) {
2410     CountIfLooping = Builder.CreateAdd(Builder.CreateUDiv(Span, Incr), One);
2411   } else {
2412     // Avoid incrementing past stop since it could overflow.
2413     Value *CountIfTwo = Builder.CreateAdd(
2414         Builder.CreateUDiv(Builder.CreateSub(Span, One), Incr), One);
2415     Value *OneCmp = Builder.CreateICmp(CmpInst::ICMP_ULE, Span, Incr);
2416     CountIfLooping = Builder.CreateSelect(OneCmp, One, CountIfTwo);
2417   }
2418   Value *TripCount = Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
2419                                           "omp_" + Name + ".tripcount");
2420 
2421   auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
2422     Builder.restoreIP(CodeGenIP);
2423     Value *Span = Builder.CreateMul(IV, Step);
2424     Value *IndVar = Builder.CreateAdd(Span, Start);
2425     BodyGenCB(Builder.saveIP(), IndVar);
2426   };
2427   LocationDescription LoopLoc = ComputeIP.isSet() ? Loc.IP : Builder.saveIP();
2428   return createCanonicalLoop(LoopLoc, BodyGen, TripCount, Name);
2429 }
2430 
2431 // Returns an LLVM function to call for initializing loop bounds using OpenMP
2432 // static scheduling depending on `type`. Only i32 and i64 are supported by the
2433 // runtime. Always interpret integers as unsigned similarly to
2434 // CanonicalLoopInfo.
2435 static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
2436                                                   OpenMPIRBuilder &OMPBuilder) {
2437   unsigned Bitwidth = Ty->getIntegerBitWidth();
2438   if (Bitwidth == 32)
2439     return OMPBuilder.getOrCreateRuntimeFunction(
2440         M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_4u);
2441   if (Bitwidth == 64)
2442     return OMPBuilder.getOrCreateRuntimeFunction(
2443         M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_8u);
2444   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
2445 }
2446 
// Turns \p CLI into a statically scheduled worksharing loop.
//
// - Bound storage: allocas for last-iter flag, lower/upper bound and stride
//   are created at \p AllocaIP.
// - Preheader: stores the full iteration space [0, tripcount - 1] (with
//   stride 1) into those allocas. It then calls
//   __kmpc_for_static_init_{4u,8u} (chosen by the induction variable's bit
//   width) with schedule UnorderedStatic, increment 1 and chunk 0.
// - Trip count: reloads the bounds written by the runtime and replaces the
//   loop's trip count with (upper - lower + 1). Presumably these are the
//   calling thread's share of the iterations; confirm against the runtime.
// - Body: offsets every use of the induction variable by the loaded lower
//   bound.
// - Exit: calls __kmpc_for_static_fini, followed by a barrier if
//   \p NeedsBarrier.
//
// \p CLI is invalidated; the returned insertion point is the loop's after
// block. \p AllocaIP must not coincide with the preheader insertion point.
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::applyStaticWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
                                          InsertPointTy AllocaIP,
                                          bool NeedsBarrier) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
         "Require dedicated allocate IP");

  // Set up the source location value for OpenMP runtime.
  Builder.restoreIP(CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  // Declare useful OpenMP runtime functions. The "init" variant is selected by
  // the induction variable's width (only i32 and i64 are supported).
  Value *IV = CLI->getIndVar();
  Type *IVTy = IV->getType();
  FunctionCallee StaticInit = getKmpcForStaticInitForType(IVTy, M, *this);
  FunctionCallee StaticFini =
      getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.restoreIP(AllocaIP);
  Type *I32Type = Type::getInt32Ty(M.getContext());
  Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
  Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
  Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
  Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");

  // At the end of the preheader, prepare for calling the "init" function by
  // storing the current loop bounds into the allocated space. A canonical loop
  // always iterates from 0 to trip-count with step 1. Note that "init" expects
  // and produces an inclusive upper bound.
  Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
  Constant *Zero = ConstantInt::get(IVTy, 0);
  Constant *One = ConstantInt::get(IVTy, 1);
  Builder.CreateStore(Zero, PLowerBound);
  Value *UpperBound = Builder.CreateSub(CLI->getTripCount(), One);
  Builder.CreateStore(UpperBound, PUpperBound);
  Builder.CreateStore(One, PStride);

  Value *ThreadNum = getOrCreateThreadID(SrcLoc);

  Constant *SchedulingType = ConstantInt::get(
      I32Type, static_cast<int>(OMPScheduleType::UnorderedStatic));

  // Call the "init" function and update the trip count of the loop with the
  // value it produced. Argument order: loc, global_tid, schedtype, plastiter,
  // plower, pupper, pstride, incr (= 1), chunk (= 0).
  Builder.CreateCall(StaticInit,
                     {SrcLoc, ThreadNum, SchedulingType, PLastIter, PLowerBound,
                      PUpperBound, PStride, One, Zero});
  Value *LowerBound = Builder.CreateLoad(IVTy, PLowerBound);
  Value *InclusiveUpperBound = Builder.CreateLoad(IVTy, PUpperBound);
  Value *TripCountMinusOne = Builder.CreateSub(InclusiveUpperBound, LowerBound);
  Value *TripCount = Builder.CreateAdd(TripCountMinusOne, One);
  CLI->setTripCount(TripCount);

  // Update all uses of the induction variable except the one in the condition
  // block that compares it with the actual upper bound, and the increment in
  // the latch block. Each remaining use sees (IV + LowerBound), computed at the
  // start of the body.

  CLI->mapIndVar([&](Instruction *OldIV) -> Value * {
    Builder.SetInsertPoint(CLI->getBody(),
                           CLI->getBody()->getFirstInsertionPt());
    Builder.SetCurrentDebugLocation(DL);
    return Builder.CreateAdd(OldIV, LowerBound);
  });

  // In the "exit" block, call the "fini" function (before the terminator).
  Builder.SetInsertPoint(CLI->getExit(),
                         CLI->getExit()->getTerminator()->getIterator());
  Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});

  // Add the barrier if requested.
  if (NeedsBarrier)
    createBarrier(LocationDescription(Builder.saveIP(), DL),
                  omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
                  /* CheckCancelFlag */ false);

  // Grab the after-IP before invalidating; CLI must not be queried afterwards.
  InsertPointTy AfterIP = CLI->getAfterIP();
  CLI->invalidate();

  return AfterIP;
}
2533 
// Turns \p CLI into a statically scheduled, chunked worksharing loop.
//
// - Arithmetic width: bound computations use an internal i32 or i64 type,
//   whichever is the smallest one at least as wide as the induction variable.
// - Init call: __kmpc_for_static_init is called with schedule
//   UnorderedStaticChunked and chunk size \p ChunkSize (zero-extended or
//   truncated). It writes the first chunk's [lb, ub] bounds and the stride
//   between chunks.
// - Dispatch loop: an outer loop is created that walks chunk starts from the
//   first chunk's lower bound up to the original trip count, advancing by the
//   stride.
// - Chunk loop: the original loop becomes the inner loop. Its trip count is
//   min(chunk range, tripcount - chunk start), expressed as a select, and
//   every body use of the induction variable is offset by the dispatch
//   counter.
// - Exit: __kmpc_for_static_fini is called in the dispatch loop's exit block,
//   followed by a barrier if \p NeedsBarrier.
//
// Returns an insertion point at the start of the dispatch loop's after block.
// NOTE(review): unlike applyStaticWorkshareLoop, there is no assertion that
// AllocaIP differs from the preheader insertion point.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    bool NeedsBarrier, Value *ChunkSize) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(ChunkSize && "Chunk size is required");

  LLVMContext &Ctx = CLI->getFunction()->getContext();
  Value *IV = CLI->getIndVar();
  Value *OrigTripCount = CLI->getTripCount();
  Type *IVTy = IV->getType();
  assert(IVTy->getIntegerBitWidth() <= 64 &&
         "Max supported tripcount bitwidth is 64 bits");
  // The runtime only provides 32- and 64-bit entry points; narrower induction
  // variables are widened to i32 for the bound computations.
  Type *InternalIVTy = IVTy->getIntegerBitWidth() <= 32 ? Type::getInt32Ty(Ctx)
                                                        : Type::getInt64Ty(Ctx);
  Type *I32Type = Type::getInt32Ty(M.getContext());
  Constant *Zero = ConstantInt::get(InternalIVTy, 0);
  Constant *One = ConstantInt::get(InternalIVTy, 1);

  // Declare useful OpenMP runtime functions.
  FunctionCallee StaticInit =
      getKmpcForStaticInitForType(InternalIVTy, M, *this);
  FunctionCallee StaticFini =
      getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.restoreIP(AllocaIP);
  Builder.SetCurrentDebugLocation(DL);
  Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
  Value *PLowerBound =
      Builder.CreateAlloca(InternalIVTy, nullptr, "p.lowerbound");
  Value *PUpperBound =
      Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
  Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");

  // Set up the source location value for the OpenMP runtime.
  Builder.restoreIP(CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  // TODO: Detect overflow in ubsan or max-out with current tripcount.
  Value *CastedChunkSize =
      Builder.CreateZExtOrTrunc(ChunkSize, InternalIVTy, "chunksize");
  Value *CastedTripCount =
      Builder.CreateZExt(OrigTripCount, InternalIVTy, "tripcount");

  // Initial bounds cover the whole iteration space [0, tripcount - 1] with
  // stride 1; "init" works with an inclusive upper bound.
  Constant *SchedulingType = ConstantInt::get(
      I32Type, static_cast<int>(OMPScheduleType::UnorderedStaticChunked));
  Builder.CreateStore(Zero, PLowerBound);
  Value *OrigUpperBound = Builder.CreateSub(CastedTripCount, One);
  Builder.CreateStore(OrigUpperBound, PUpperBound);
  Builder.CreateStore(One, PStride);

  // Call the "init" function. It overwrites the bounds with those of the first
  // chunk and the stride with the distance between consecutive chunks.
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadNum = getOrCreateThreadID(SrcLoc);
  Builder.CreateCall(StaticInit,
                     {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
                      /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
                      /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
                      /*pstride=*/PStride, /*incr=*/One,
                      /*chunk=*/CastedChunkSize});

  // Load values written by the "init" function.
  Value *FirstChunkStart =
      Builder.CreateLoad(InternalIVTy, PLowerBound, "omp_firstchunk.lb");
  Value *FirstChunkStop =
      Builder.CreateLoad(InternalIVTy, PUpperBound, "omp_firstchunk.ub");
  Value *FirstChunkEnd = Builder.CreateAdd(FirstChunkStop, One);
  // Number of iterations in a full chunk (exclusive end - start).
  Value *ChunkRange =
      Builder.CreateSub(FirstChunkEnd, FirstChunkStart, "omp_chunk.range");
  Value *NextChunkStride =
      Builder.CreateLoad(InternalIVTy, PStride, "omp_dispatch.stride");

  // Create outer "dispatch" loop for enumerating the chunks. Its body callback
  // only records the counter (the current chunk's start); the body is wired up
  // to the original loop below.
  BasicBlock *DispatchEnter = splitBB(Builder, true);
  Value *DispatchCounter;
  CanonicalLoopInfo *DispatchCLI = createCanonicalLoop(
      {Builder.saveIP(), DL},
      [&](InsertPointTy BodyIP, Value *Counter) { DispatchCounter = Counter; },
      FirstChunkStart, CastedTripCount, NextChunkStride,
      /*IsSigned=*/false, /*InclusiveStop=*/false, /*ComputeIP=*/{},
      "dispatch");

  // Remember the BasicBlocks of the dispatch loop we need, then invalidate to
  // not have to preserve the canonical invariant.
  BasicBlock *DispatchBody = DispatchCLI->getBody();
  BasicBlock *DispatchLatch = DispatchCLI->getLatch();
  BasicBlock *DispatchExit = DispatchCLI->getExit();
  BasicBlock *DispatchAfter = DispatchCLI->getAfter();
  DispatchCLI->invalidate();

  // Rewire the original loop to become the chunk loop inside the dispatch loop.
  redirectTo(DispatchAfter, CLI->getAfter(), DL);
  redirectTo(CLI->getExit(), DispatchLatch, DL);
  redirectTo(DispatchBody, DispatchEnter, DL);

  // Prepare the prolog of the chunk loop.
  Builder.restoreIP(CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  // Compute the number of iterations of the chunk loop: a full chunk, except
  // for the last chunk which only runs up to the original trip count.
  // NOTE(review): the SetInsertPoint below appears to re-target the same
  // position as the restoreIP above; confirm which debug location is in effect
  // for the instructions that follow.
  Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
  Value *ChunkEnd = Builder.CreateAdd(DispatchCounter, ChunkRange);
  Value *IsLastChunk =
      Builder.CreateICmpUGE(ChunkEnd, CastedTripCount, "omp_chunk.is_last");
  Value *CountUntilOrigTripCount =
      Builder.CreateSub(CastedTripCount, DispatchCounter);
  Value *ChunkTripCount = Builder.CreateSelect(
      IsLastChunk, CountUntilOrigTripCount, ChunkRange, "omp_chunk.tripcount");
  Value *BackcastedChunkTC =
      Builder.CreateTrunc(ChunkTripCount, IVTy, "omp_chunk.tripcount.trunc");
  CLI->setTripCount(BackcastedChunkTC);

  // Update all uses of the induction variable except the one in the condition
  // block that compares it with the actual upper bound, and the increment in
  // the latch block. Remaining uses see (chunk-local IV + chunk start).
  Value *BackcastedDispatchCounter =
      Builder.CreateTrunc(DispatchCounter, IVTy, "omp_dispatch.iv.trunc");
  CLI->mapIndVar([&](Instruction *) -> Value * {
    Builder.restoreIP(CLI->getBodyIP());
    return Builder.CreateAdd(IV, BackcastedDispatchCounter);
  });

  // In the "exit" block, call the "fini" function.
  Builder.SetInsertPoint(DispatchExit, DispatchExit->getFirstInsertionPt());
  Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});

  // Add the barrier if requested.
  if (NeedsBarrier)
    createBarrier(LocationDescription(Builder.saveIP(), DL), OMPD_for,
                  /*ForceSimpleCall=*/false, /*CheckCancelFlag=*/false);

#ifndef NDEBUG
  // Even though we currently do not support applying additional methods to it,
  // the chunk loop should remain a canonical loop.
  CLI->assertOK();
#endif

  return {DispatchAfter, DispatchAfter->getFirstInsertionPt()};
}
2676 
2677 // Returns an LLVM function to call for executing an OpenMP static worksharing
2678 // for loop depending on `type`. Only i32 and i64 are supported by the runtime.
2679 // Always interpret integers as unsigned similarly to CanonicalLoopInfo.
2680 static FunctionCallee
2681 getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
2682                             WorksharingLoopType LoopType) {
2683   unsigned Bitwidth = Ty->getIntegerBitWidth();
2684   Module &M = OMPBuilder->M;
2685   switch (LoopType) {
2686   case WorksharingLoopType::ForStaticLoop:
2687     if (Bitwidth == 32)
2688       return OMPBuilder->getOrCreateRuntimeFunction(
2689           M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
2690     if (Bitwidth == 64)
2691       return OMPBuilder->getOrCreateRuntimeFunction(
2692           M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
2693     break;
2694   case WorksharingLoopType::DistributeStaticLoop:
2695     if (Bitwidth == 32)
2696       return OMPBuilder->getOrCreateRuntimeFunction(
2697           M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
2698     if (Bitwidth == 64)
2699       return OMPBuilder->getOrCreateRuntimeFunction(
2700           M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
2701     break;
2702   case WorksharingLoopType::DistributeForStaticLoop:
2703     if (Bitwidth == 32)
2704       return OMPBuilder->getOrCreateRuntimeFunction(
2705           M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
2706     if (Bitwidth == 64)
2707       return OMPBuilder->getOrCreateRuntimeFunction(
2708           M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
2709     break;
2710   }
2711   if (Bitwidth != 32 && Bitwidth != 64) {
2712     llvm_unreachable("Unknown OpenMP loop iterator bitwidth");
2713   }
2714   llvm_unreachable("Unknown type of OpenMP worksharing loop");
2715 }
2716 
2717 // Inserts a call to proper OpenMP Device RTL function which handles
2718 // loop worksharing.
2719 static void createTargetLoopWorkshareCall(
2720     OpenMPIRBuilder *OMPBuilder, WorksharingLoopType LoopType,
2721     BasicBlock *InsertBlock, Value *Ident, Value *LoopBodyArg,
2722     Type *ParallelTaskPtr, Value *TripCount, Function &LoopBodyFn) {
2723   Type *TripCountTy = TripCount->getType();
2724   Module &M = OMPBuilder->M;
2725   IRBuilder<> &Builder = OMPBuilder->Builder;
2726   FunctionCallee RTLFn =
2727       getKmpcForStaticLoopForType(TripCountTy, OMPBuilder, LoopType);
2728   SmallVector<Value *, 8> RealArgs;
2729   RealArgs.push_back(Ident);
2730   RealArgs.push_back(Builder.CreateBitCast(&LoopBodyFn, ParallelTaskPtr));
2731   RealArgs.push_back(LoopBodyArg);
2732   RealArgs.push_back(TripCount);
2733   if (LoopType == WorksharingLoopType::DistributeStaticLoop) {
2734     RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
2735     Builder.CreateCall(RTLFn, RealArgs);
2736     return;
2737   }
2738   FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
2739       M, omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
2740   Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
2741   Value *NumThreads = Builder.CreateCall(RTLNumThreads, {});
2742 
2743   RealArgs.push_back(
2744       Builder.CreateZExtOrTrunc(NumThreads, TripCountTy, "num.threads.cast"));
2745   RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
2746   if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
2747     RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
2748   }
2749 
2750   Builder.CreateCall(RTLFn, RealArgs);
2751 }
2752 
// Post-outlining callback for device worksharing loops: replaces the loop in
// \p CLI, whose body has already been outlined into \p OutlinedFn, with a
// single call into the device runtime. Steps:
//   1. move the loop body's argument setup into the preheader;
//   2. branch directly from the preheader to the loop exit and delete the
//      now-dead loop blocks;
//   3. remove the call to \p OutlinedFn, keeping its second argument (if any)
//      as the loop body argument;
//   4. emit the runtime call via createTargetLoopWorkshareCall;
//   5. erase the instructions in \p ToBeDeleted and invalidate \p CLI.
static void
workshareLoopTargetCallback(OpenMPIRBuilder *OMPIRBuilder,
                            CanonicalLoopInfo *CLI, Value *Ident,
                            Function &OutlinedFn, Type *ParallelTaskPtr,
                            const SmallVector<Instruction *, 4> &ToBeDeleted,
                            WorksharingLoopType LoopType) {
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  BasicBlock *Preheader = CLI->getPreheader();
  // Query the trip count while the loop is still intact; its blocks are
  // deleted below.
  Value *TripCount = CLI->getTripCount();

  // After loop body outlining, the loop body contains only the setup of the
  // loop body argument structure and the call to the outlined loop body
  // function. First, move everything except the body's terminator into the
  // loop preheader (in front of the preheader's terminator).
  Preheader->splice(std::prev(Preheader->end()), CLI->getBody(),
                    CLI->getBody()->begin(), std::prev(CLI->getBody()->end()));

  // The next step is to remove the whole loop; we do not need it anymore.
  // That is why the preheader's terminator is replaced by an unconditional
  // branch to the loop exit block.
  // NOTE(review): the builder remains positioned at the end of the preheader,
  // i.e. after this new branch.
  Builder.restoreIP({Preheader, Preheader->end()});
  Preheader->getTerminator()->eraseFromParent();
  Builder.CreateBr(CLI->getExit());

  // Delete dead loop blocks (the region from the header up to the exit).
  OpenMPIRBuilder::OutlineInfo CleanUpInfo;
  SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
  SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
  CleanUpInfo.EntryBB = CLI->getHeader();
  CleanUpInfo.ExitBB = CLI->getExit();
  CleanUpInfo.collectBlocks(RegionBlockSet, BlocksToBeRemoved);
  DeleteDeadBlocks(BlocksToBeRemoved);

  // Find the instruction which corresponds to loop body argument structure
  // and remove the call to loop body function instruction.
  Value *LoopBodyArg;
  User *OutlinedFnUser = OutlinedFn.getUniqueUndroppableUser();
  assert(OutlinedFnUser &&
         "Expected unique undroppable user of outlined function");
  CallInst *OutlinedFnCallInstruction = dyn_cast<CallInst>(OutlinedFnUser);
  assert(OutlinedFnCallInstruction && "Expected outlined function call");
  assert((OutlinedFnCallInstruction->getParent() == Preheader) &&
         "Expected outlined function call to be located in loop preheader");
  // Check in case no argument structure has been passed; a null pointer is
  // forwarded to the runtime in that case.
  if (OutlinedFnCallInstruction->arg_size() > 1)
    LoopBodyArg = OutlinedFnCallInstruction->getArgOperand(1);
  else
    LoopBodyArg = Constant::getNullValue(Builder.getPtrTy());
  OutlinedFnCallInstruction->eraseFromParent();

  createTargetLoopWorkshareCall(OMPIRBuilder, LoopType, Preheader, Ident,
                                LoopBodyArg, ParallelTaskPtr, TripCount,
                                OutlinedFn);

  for (auto &ToBeDeletedItem : ToBeDeleted)
    ToBeDeletedItem->eraseFromParent();
  CLI->invalidate();
}
2811 
2812 OpenMPIRBuilder::InsertPointTy
2813 OpenMPIRBuilder::applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI,
2814                                           InsertPointTy AllocaIP,
2815                                           WorksharingLoopType LoopType) {
2816   uint32_t SrcLocStrSize;
2817   Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
2818   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2819 
2820   OutlineInfo OI;
2821   OI.OuterAllocaBB = CLI->getPreheader();
2822   Function *OuterFn = CLI->getPreheader()->getParent();
2823 
2824   // Instructions which need to be deleted at the end of code generation
2825   SmallVector<Instruction *, 4> ToBeDeleted;
2826 
2827   OI.OuterAllocaBB = AllocaIP.getBlock();
2828 
2829   // Mark the body loop as region which needs to be extracted
2830   OI.EntryBB = CLI->getBody();
2831   OI.ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(),
2832                                                "omp.prelatch", true);
2833 
2834   // Prepare loop body for extraction
2835   Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()});
2836 
2837   // Insert new loop counter variable which will be used only in loop
2838   // body.
2839   AllocaInst *NewLoopCnt = Builder.CreateAlloca(CLI->getIndVarType(), 0, "");
2840   Instruction *NewLoopCntLoad =
2841       Builder.CreateLoad(CLI->getIndVarType(), NewLoopCnt);
2842   // New loop counter instructions are redundant in the loop preheader when
2843   // code generation for workshare loop is finshed. That's why mark them as
2844   // ready for deletion.
2845   ToBeDeleted.push_back(NewLoopCntLoad);
2846   ToBeDeleted.push_back(NewLoopCnt);
2847 
2848   // Analyse loop body region. Find all input variables which are used inside
2849   // loop body region.
2850   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
2851   SmallVector<BasicBlock *, 32> Blocks;
2852   OI.collectBlocks(ParallelRegionBlockSet, Blocks);
2853   SmallVector<BasicBlock *, 32> BlocksT(ParallelRegionBlockSet.begin(),
2854                                         ParallelRegionBlockSet.end());
2855 
2856   CodeExtractorAnalysisCache CEAC(*OuterFn);
2857   CodeExtractor Extractor(Blocks,
2858                           /* DominatorTree */ nullptr,
2859                           /* AggregateArgs */ true,
2860                           /* BlockFrequencyInfo */ nullptr,
2861                           /* BranchProbabilityInfo */ nullptr,
2862                           /* AssumptionCache */ nullptr,
2863                           /* AllowVarArgs */ true,
2864                           /* AllowAlloca */ true,
2865                           /* AllocationBlock */ CLI->getPreheader(),
2866                           /* Suffix */ ".omp_wsloop",
2867                           /* AggrArgsIn0AddrSpace */ true);
2868 
2869   BasicBlock *CommonExit = nullptr;
2870   SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
2871 
2872   // Find allocas outside the loop body region which are used inside loop
2873   // body
2874   Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
2875 
2876   // We need to model loop body region as the function f(cnt, loop_arg).
2877   // That's why we replace loop induction variable by the new counter
2878   // which will be one of loop body function argument
2879   SmallVector<User *> Users(CLI->getIndVar()->user_begin(),
2880                             CLI->getIndVar()->user_end());
2881   for (auto Use : Users) {
2882     if (Instruction *Inst = dyn_cast<Instruction>(Use)) {
2883       if (ParallelRegionBlockSet.count(Inst->getParent())) {
2884         Inst->replaceUsesOfWith(CLI->getIndVar(), NewLoopCntLoad);
2885       }
2886     }
2887   }
2888   // Make sure that loop counter variable is not merged into loop body
2889   // function argument structure and it is passed as separate variable
2890   OI.ExcludeArgsFromAggregate.push_back(NewLoopCntLoad);
2891 
2892   // PostOutline CB is invoked when loop body function is outlined and
2893   // loop body is replaced by call to outlined function. We need to add
2894   // call to OpenMP device rtl inside loop preheader. OpenMP device rtl
2895   // function will handle loop control logic.
2896   //
2897   OI.PostOutlineCB = [=, ToBeDeletedVec =
2898                              std::move(ToBeDeleted)](Function &OutlinedFn) {
2899     workshareLoopTargetCallback(this, CLI, Ident, OutlinedFn, ParallelTaskPtr,
2900                                 ToBeDeletedVec, LoopType);
2901   };
2902   addOutlineInfo(std::move(OI));
2903   return CLI->getAfterIP();
2904 }
2905 
2906 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoop(
2907     DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
2908     bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
2909     bool HasSimdModifier, bool HasMonotonicModifier,
2910     bool HasNonmonotonicModifier, bool HasOrderedClause,
2911     WorksharingLoopType LoopType) {
2912   if (Config.isTargetDevice())
2913     return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType);
2914   OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
2915       SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
2916       HasNonmonotonicModifier, HasOrderedClause);
2917 
2918   bool IsOrdered = (EffectiveScheduleType & OMPScheduleType::ModifierOrdered) ==
2919                    OMPScheduleType::ModifierOrdered;
2920   switch (EffectiveScheduleType & ~OMPScheduleType::ModifierMask) {
2921   case OMPScheduleType::BaseStatic:
2922     assert(!ChunkSize && "No chunk size with static-chunked schedule");
2923     if (IsOrdered)
2924       return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
2925                                        NeedsBarrier, ChunkSize);
2926     // FIXME: Monotonicity ignored?
2927     return applyStaticWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier);
2928 
2929   case OMPScheduleType::BaseStaticChunked:
2930     if (IsOrdered)
2931       return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
2932                                        NeedsBarrier, ChunkSize);
2933     // FIXME: Monotonicity ignored?
2934     return applyStaticChunkedWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier,
2935                                            ChunkSize);
2936 
2937   case OMPScheduleType::BaseRuntime:
2938   case OMPScheduleType::BaseAuto:
2939   case OMPScheduleType::BaseGreedy:
2940   case OMPScheduleType::BaseBalanced:
2941   case OMPScheduleType::BaseSteal:
2942   case OMPScheduleType::BaseGuidedSimd:
2943   case OMPScheduleType::BaseRuntimeSimd:
2944     assert(!ChunkSize &&
2945            "schedule type does not support user-defined chunk sizes");
2946     [[fallthrough]];
2947   case OMPScheduleType::BaseDynamicChunked:
2948   case OMPScheduleType::BaseGuidedChunked:
2949   case OMPScheduleType::BaseGuidedIterativeChunked:
2950   case OMPScheduleType::BaseGuidedAnalyticalChunked:
2951   case OMPScheduleType::BaseStaticBalancedChunked:
2952     return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
2953                                      NeedsBarrier, ChunkSize);
2954 
2955   default:
2956     llvm_unreachable("Unknown/unimplemented schedule kind");
2957   }
2958 }
2959 
2960 /// Returns an LLVM function to call for initializing loop bounds using OpenMP
2961 /// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
2962 /// the runtime. Always interpret integers as unsigned similarly to
2963 /// CanonicalLoopInfo.
2964 static FunctionCallee
2965 getKmpcForDynamicInitForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
2966   unsigned Bitwidth = Ty->getIntegerBitWidth();
2967   if (Bitwidth == 32)
2968     return OMPBuilder.getOrCreateRuntimeFunction(
2969         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_4u);
2970   if (Bitwidth == 64)
2971     return OMPBuilder.getOrCreateRuntimeFunction(
2972         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_8u);
2973   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
2974 }
2975 
2976 /// Returns an LLVM function to call for updating the next loop using OpenMP
2977 /// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
2978 /// the runtime. Always interpret integers as unsigned similarly to
2979 /// CanonicalLoopInfo.
2980 static FunctionCallee
2981 getKmpcForDynamicNextForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
2982   unsigned Bitwidth = Ty->getIntegerBitWidth();
2983   if (Bitwidth == 32)
2984     return OMPBuilder.getOrCreateRuntimeFunction(
2985         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_4u);
2986   if (Bitwidth == 64)
2987     return OMPBuilder.getOrCreateRuntimeFunction(
2988         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_8u);
2989   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
2990 }
2991 
2992 /// Returns an LLVM function to call for finalizing the dynamic loop using
2993 /// depending on `type`. Only i32 and i64 are supported by the runtime. Always
2994 /// interpret integers as unsigned similarly to CanonicalLoopInfo.
2995 static FunctionCallee
2996 getKmpcForDynamicFiniForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
2997   unsigned Bitwidth = Ty->getIntegerBitWidth();
2998   if (Bitwidth == 32)
2999     return OMPBuilder.getOrCreateRuntimeFunction(
3000         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_4u);
3001   if (Bitwidth == 64)
3002     return OMPBuilder.getOrCreateRuntimeFunction(
3003         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_8u);
3004   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
3005 }
3006 
3007 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyDynamicWorkshareLoop(
3008     DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
3009     OMPScheduleType SchedType, bool NeedsBarrier, Value *Chunk) {
3010   assert(CLI->isValid() && "Requires a valid canonical loop");
3011   assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
3012          "Require dedicated allocate IP");
3013   assert(isValidWorkshareLoopScheduleType(SchedType) &&
3014          "Require valid schedule type");
3015 
3016   bool Ordered = (SchedType & OMPScheduleType::ModifierOrdered) ==
3017                  OMPScheduleType::ModifierOrdered;
3018 
3019   // Set up the source location value for OpenMP runtime.
3020   Builder.SetCurrentDebugLocation(DL);
3021 
3022   uint32_t SrcLocStrSize;
3023   Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
3024   Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3025 
3026   // Declare useful OpenMP runtime functions.
3027   Value *IV = CLI->getIndVar();
3028   Type *IVTy = IV->getType();
3029   FunctionCallee DynamicInit = getKmpcForDynamicInitForType(IVTy, M, *this);
3030   FunctionCallee DynamicNext = getKmpcForDynamicNextForType(IVTy, M, *this);
3031 
3032   // Allocate space for computed loop bounds as expected by the "init" function.
3033   Builder.restoreIP(AllocaIP);
3034   Type *I32Type = Type::getInt32Ty(M.getContext());
3035   Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
3036   Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
3037   Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
3038   Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
3039 
3040   // At the end of the preheader, prepare for calling the "init" function by
3041   // storing the current loop bounds into the allocated space. A canonical loop
3042   // always iterates from 0 to trip-count with step 1. Note that "init" expects
3043   // and produces an inclusive upper bound.
3044   BasicBlock *PreHeader = CLI->getPreheader();
3045   Builder.SetInsertPoint(PreHeader->getTerminator());
3046   Constant *One = ConstantInt::get(IVTy, 1);
3047   Builder.CreateStore(One, PLowerBound);
3048   Value *UpperBound = CLI->getTripCount();
3049   Builder.CreateStore(UpperBound, PUpperBound);
3050   Builder.CreateStore(One, PStride);
3051 
3052   BasicBlock *Header = CLI->getHeader();
3053   BasicBlock *Exit = CLI->getExit();
3054   BasicBlock *Cond = CLI->getCond();
3055   BasicBlock *Latch = CLI->getLatch();
3056   InsertPointTy AfterIP = CLI->getAfterIP();
3057 
3058   // The CLI will be "broken" in the code below, as the loop is no longer
3059   // a valid canonical loop.
3060 
3061   if (!Chunk)
3062     Chunk = One;
3063 
3064   Value *ThreadNum = getOrCreateThreadID(SrcLoc);
3065 
3066   Constant *SchedulingType =
3067       ConstantInt::get(I32Type, static_cast<int>(SchedType));
3068 
3069   // Call the "init" function.
3070   Builder.CreateCall(DynamicInit,
3071                      {SrcLoc, ThreadNum, SchedulingType, /* LowerBound */ One,
3072                       UpperBound, /* step */ One, Chunk});
3073 
3074   // An outer loop around the existing one.
3075   BasicBlock *OuterCond = BasicBlock::Create(
3076       PreHeader->getContext(), Twine(PreHeader->getName()) + ".outer.cond",
3077       PreHeader->getParent());
3078   // This needs to be 32-bit always, so can't use the IVTy Zero above.
3079   Builder.SetInsertPoint(OuterCond, OuterCond->getFirstInsertionPt());
3080   Value *Res =
3081       Builder.CreateCall(DynamicNext, {SrcLoc, ThreadNum, PLastIter,
3082                                        PLowerBound, PUpperBound, PStride});
3083   Constant *Zero32 = ConstantInt::get(I32Type, 0);
3084   Value *MoreWork = Builder.CreateCmp(CmpInst::ICMP_NE, Res, Zero32);
3085   Value *LowerBound =
3086       Builder.CreateSub(Builder.CreateLoad(IVTy, PLowerBound), One, "lb");
3087   Builder.CreateCondBr(MoreWork, Header, Exit);
3088 
3089   // Change PHI-node in loop header to use outer cond rather than preheader,
3090   // and set IV to the LowerBound.
3091   Instruction *Phi = &Header->front();
3092   auto *PI = cast<PHINode>(Phi);
3093   PI->setIncomingBlock(0, OuterCond);
3094   PI->setIncomingValue(0, LowerBound);
3095 
3096   // Then set the pre-header to jump to the OuterCond
3097   Instruction *Term = PreHeader->getTerminator();
3098   auto *Br = cast<BranchInst>(Term);
3099   Br->setSuccessor(0, OuterCond);
3100 
3101   // Modify the inner condition:
3102   // * Use the UpperBound returned from the DynamicNext call.
3103   // * jump to the loop outer loop when done with one of the inner loops.
3104   Builder.SetInsertPoint(Cond, Cond->getFirstInsertionPt());
3105   UpperBound = Builder.CreateLoad(IVTy, PUpperBound, "ub");
3106   Instruction *Comp = &*Builder.GetInsertPoint();
3107   auto *CI = cast<CmpInst>(Comp);
3108   CI->setOperand(1, UpperBound);
3109   // Redirect the inner exit to branch to outer condition.
3110   Instruction *Branch = &Cond->back();
3111   auto *BI = cast<BranchInst>(Branch);
3112   assert(BI->getSuccessor(1) == Exit);
3113   BI->setSuccessor(1, OuterCond);
3114 
3115   // Call the "fini" function if "ordered" is present in wsloop directive.
3116   if (Ordered) {
3117     Builder.SetInsertPoint(&Latch->back());
3118     FunctionCallee DynamicFini = getKmpcForDynamicFiniForType(IVTy, M, *this);
3119     Builder.CreateCall(DynamicFini, {SrcLoc, ThreadNum});
3120   }
3121 
3122   // Add the barrier if requested.
3123   if (NeedsBarrier) {
3124     Builder.SetInsertPoint(&Exit->back());
3125     createBarrier(LocationDescription(Builder.saveIP(), DL),
3126                   omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
3127                   /* CheckCancelFlag */ false);
3128   }
3129 
3130   CLI->invalidate();
3131   return AfterIP;
3132 }
3133 
3134 /// Redirect all edges that branch to \p OldTarget to \p NewTarget. That is,
3135 /// after this \p OldTarget will be orphaned.
3136 static void redirectAllPredecessorsTo(BasicBlock *OldTarget,
3137                                       BasicBlock *NewTarget, DebugLoc DL) {
3138   for (BasicBlock *Pred : make_early_inc_range(predecessors(OldTarget)))
3139     redirectTo(Pred, NewTarget, DL);
3140 }
3141 
3142 /// Determine which blocks in \p BBs are reachable from outside and remove the
3143 /// ones that are not reachable from the function.
3144 static void removeUnusedBlocksFromParent(ArrayRef<BasicBlock *> BBs) {
3145   SmallPtrSet<BasicBlock *, 6> BBsToErase{BBs.begin(), BBs.end()};
3146   auto HasRemainingUses = [&BBsToErase](BasicBlock *BB) {
3147     for (Use &U : BB->uses()) {
3148       auto *UseInst = dyn_cast<Instruction>(U.getUser());
3149       if (!UseInst)
3150         continue;
3151       if (BBsToErase.count(UseInst->getParent()))
3152         continue;
3153       return true;
3154     }
3155     return false;
3156   };
3157 
3158   while (true) {
3159     bool Changed = false;
3160     for (BasicBlock *BB : make_early_inc_range(BBsToErase)) {
3161       if (HasRemainingUses(BB)) {
3162         BBsToErase.erase(BB);
3163         Changed = true;
3164       }
3165     }
3166     if (!Changed)
3167       break;
3168   }
3169 
3170   SmallVector<BasicBlock *, 7> BBVec(BBsToErase.begin(), BBsToErase.end());
3171   DeleteDeadBlocks(BBVec);
3172 }
3173 
3174 CanonicalLoopInfo *
3175 OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
3176                                InsertPointTy ComputeIP) {
3177   assert(Loops.size() >= 1 && "At least one loop required");
3178   size_t NumLoops = Loops.size();
3179 
3180   // Nothing to do if there is already just one loop.
3181   if (NumLoops == 1)
3182     return Loops.front();
3183 
3184   CanonicalLoopInfo *Outermost = Loops.front();
3185   CanonicalLoopInfo *Innermost = Loops.back();
3186   BasicBlock *OrigPreheader = Outermost->getPreheader();
3187   BasicBlock *OrigAfter = Outermost->getAfter();
3188   Function *F = OrigPreheader->getParent();
3189 
3190   // Loop control blocks that may become orphaned later.
3191   SmallVector<BasicBlock *, 12> OldControlBBs;
3192   OldControlBBs.reserve(6 * Loops.size());
3193   for (CanonicalLoopInfo *Loop : Loops)
3194     Loop->collectControlBlocks(OldControlBBs);
3195 
3196   // Setup the IRBuilder for inserting the trip count computation.
3197   Builder.SetCurrentDebugLocation(DL);
3198   if (ComputeIP.isSet())
3199     Builder.restoreIP(ComputeIP);
3200   else
3201     Builder.restoreIP(Outermost->getPreheaderIP());
3202 
3203   // Derive the collapsed' loop trip count.
3204   // TODO: Find common/largest indvar type.
3205   Value *CollapsedTripCount = nullptr;
3206   for (CanonicalLoopInfo *L : Loops) {
3207     assert(L->isValid() &&
3208            "All loops to collapse must be valid canonical loops");
3209     Value *OrigTripCount = L->getTripCount();
3210     if (!CollapsedTripCount) {
3211       CollapsedTripCount = OrigTripCount;
3212       continue;
3213     }
3214 
3215     // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
3216     CollapsedTripCount = Builder.CreateMul(CollapsedTripCount, OrigTripCount,
3217                                            {}, /*HasNUW=*/true);
3218   }
3219 
3220   // Create the collapsed loop control flow.
3221   CanonicalLoopInfo *Result =
3222       createLoopSkeleton(DL, CollapsedTripCount, F,
3223                          OrigPreheader->getNextNode(), OrigAfter, "collapsed");
3224 
3225   // Build the collapsed loop body code.
3226   // Start with deriving the input loop induction variables from the collapsed
3227   // one, using a divmod scheme. To preserve the original loops' order, the
3228   // innermost loop use the least significant bits.
3229   Builder.restoreIP(Result->getBodyIP());
3230 
3231   Value *Leftover = Result->getIndVar();
3232   SmallVector<Value *> NewIndVars;
3233   NewIndVars.resize(NumLoops);
3234   for (int i = NumLoops - 1; i >= 1; --i) {
3235     Value *OrigTripCount = Loops[i]->getTripCount();
3236 
3237     Value *NewIndVar = Builder.CreateURem(Leftover, OrigTripCount);
3238     NewIndVars[i] = NewIndVar;
3239 
3240     Leftover = Builder.CreateUDiv(Leftover, OrigTripCount);
3241   }
3242   // Outermost loop gets all the remaining bits.
3243   NewIndVars[0] = Leftover;
3244 
3245   // Construct the loop body control flow.
3246   // We progressively construct the branch structure following in direction of
3247   // the control flow, from the leading in-between code, the loop nest body, the
3248   // trailing in-between code, and rejoining the collapsed loop's latch.
3249   // ContinueBlock and ContinuePred keep track of the source(s) of next edge. If
3250   // the ContinueBlock is set, continue with that block. If ContinuePred, use
3251   // its predecessors as sources.
3252   BasicBlock *ContinueBlock = Result->getBody();
3253   BasicBlock *ContinuePred = nullptr;
3254   auto ContinueWith = [&ContinueBlock, &ContinuePred, DL](BasicBlock *Dest,
3255                                                           BasicBlock *NextSrc) {
3256     if (ContinueBlock)
3257       redirectTo(ContinueBlock, Dest, DL);
3258     else
3259       redirectAllPredecessorsTo(ContinuePred, Dest, DL);
3260 
3261     ContinueBlock = nullptr;
3262     ContinuePred = NextSrc;
3263   };
3264 
3265   // The code before the nested loop of each level.
3266   // Because we are sinking it into the nest, it will be executed more often
3267   // that the original loop. More sophisticated schemes could keep track of what
3268   // the in-between code is and instantiate it only once per thread.
3269   for (size_t i = 0; i < NumLoops - 1; ++i)
3270     ContinueWith(Loops[i]->getBody(), Loops[i + 1]->getHeader());
3271 
3272   // Connect the loop nest body.
3273   ContinueWith(Innermost->getBody(), Innermost->getLatch());
3274 
3275   // The code after the nested loop at each level.
3276   for (size_t i = NumLoops - 1; i > 0; --i)
3277     ContinueWith(Loops[i]->getAfter(), Loops[i - 1]->getLatch());
3278 
3279   // Connect the finished loop to the collapsed loop latch.
3280   ContinueWith(Result->getLatch(), nullptr);
3281 
3282   // Replace the input loops with the new collapsed loop.
3283   redirectTo(Outermost->getPreheader(), Result->getPreheader(), DL);
3284   redirectTo(Result->getAfter(), Outermost->getAfter(), DL);
3285 
3286   // Replace the input loop indvars with the derived ones.
3287   for (size_t i = 0; i < NumLoops; ++i)
3288     Loops[i]->getIndVar()->replaceAllUsesWith(NewIndVars[i]);
3289 
3290   // Remove unused parts of the input loops.
3291   removeUnusedBlocksFromParent(OldControlBBs);
3292 
3293   for (CanonicalLoopInfo *L : Loops)
3294     L->invalidate();
3295 
3296 #ifndef NDEBUG
3297   Result->assertOK();
3298 #endif
3299   return Result;
3300 }
3301 
3302 std::vector<CanonicalLoopInfo *>
3303 OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
3304                            ArrayRef<Value *> TileSizes) {
3305   assert(TileSizes.size() == Loops.size() &&
3306          "Must pass as many tile sizes as there are loops");
3307   int NumLoops = Loops.size();
3308   assert(NumLoops >= 1 && "At least one loop to tile required");
3309 
3310   CanonicalLoopInfo *OutermostLoop = Loops.front();
3311   CanonicalLoopInfo *InnermostLoop = Loops.back();
3312   Function *F = OutermostLoop->getBody()->getParent();
3313   BasicBlock *InnerEnter = InnermostLoop->getBody();
3314   BasicBlock *InnerLatch = InnermostLoop->getLatch();
3315 
3316   // Loop control blocks that may become orphaned later.
3317   SmallVector<BasicBlock *, 12> OldControlBBs;
3318   OldControlBBs.reserve(6 * Loops.size());
3319   for (CanonicalLoopInfo *Loop : Loops)
3320     Loop->collectControlBlocks(OldControlBBs);
3321 
3322   // Collect original trip counts and induction variable to be accessible by
3323   // index. Also, the structure of the original loops is not preserved during
3324   // the construction of the tiled loops, so do it before we scavenge the BBs of
3325   // any original CanonicalLoopInfo.
3326   SmallVector<Value *, 4> OrigTripCounts, OrigIndVars;
3327   for (CanonicalLoopInfo *L : Loops) {
3328     assert(L->isValid() && "All input loops must be valid canonical loops");
3329     OrigTripCounts.push_back(L->getTripCount());
3330     OrigIndVars.push_back(L->getIndVar());
3331   }
3332 
3333   // Collect the code between loop headers. These may contain SSA definitions
3334   // that are used in the loop nest body. To be usable with in the innermost
3335   // body, these BasicBlocks will be sunk into the loop nest body. That is,
3336   // these instructions may be executed more often than before the tiling.
3337   // TODO: It would be sufficient to only sink them into body of the
3338   // corresponding tile loop.
3339   SmallVector<std::pair<BasicBlock *, BasicBlock *>, 4> InbetweenCode;
3340   for (int i = 0; i < NumLoops - 1; ++i) {
3341     CanonicalLoopInfo *Surrounding = Loops[i];
3342     CanonicalLoopInfo *Nested = Loops[i + 1];
3343 
3344     BasicBlock *EnterBB = Surrounding->getBody();
3345     BasicBlock *ExitBB = Nested->getHeader();
3346     InbetweenCode.emplace_back(EnterBB, ExitBB);
3347   }
3348 
3349   // Compute the trip counts of the floor loops.
3350   Builder.SetCurrentDebugLocation(DL);
3351   Builder.restoreIP(OutermostLoop->getPreheaderIP());
3352   SmallVector<Value *, 4> FloorCount, FloorRems;
3353   for (int i = 0; i < NumLoops; ++i) {
3354     Value *TileSize = TileSizes[i];
3355     Value *OrigTripCount = OrigTripCounts[i];
3356     Type *IVType = OrigTripCount->getType();
3357 
3358     Value *FloorTripCount = Builder.CreateUDiv(OrigTripCount, TileSize);
3359     Value *FloorTripRem = Builder.CreateURem(OrigTripCount, TileSize);
3360 
3361     // 0 if tripcount divides the tilesize, 1 otherwise.
3362     // 1 means we need an additional iteration for a partial tile.
3363     //
3364     // Unfortunately we cannot just use the roundup-formula
3365     //   (tripcount + tilesize - 1)/tilesize
3366     // because the summation might overflow. We do not want introduce undefined
3367     // behavior when the untiled loop nest did not.
3368     Value *FloorTripOverflow =
3369         Builder.CreateICmpNE(FloorTripRem, ConstantInt::get(IVType, 0));
3370 
3371     FloorTripOverflow = Builder.CreateZExt(FloorTripOverflow, IVType);
3372     FloorTripCount =
3373         Builder.CreateAdd(FloorTripCount, FloorTripOverflow,
3374                           "omp_floor" + Twine(i) + ".tripcount", true);
3375 
3376     // Remember some values for later use.
3377     FloorCount.push_back(FloorTripCount);
3378     FloorRems.push_back(FloorTripRem);
3379   }
3380 
3381   // Generate the new loop nest, from the outermost to the innermost.
3382   std::vector<CanonicalLoopInfo *> Result;
3383   Result.reserve(NumLoops * 2);
3384 
3385   // The basic block of the surrounding loop that enters the nest generated
3386   // loop.
3387   BasicBlock *Enter = OutermostLoop->getPreheader();
3388 
3389   // The basic block of the surrounding loop where the inner code should
3390   // continue.
3391   BasicBlock *Continue = OutermostLoop->getAfter();
3392 
3393   // Where the next loop basic block should be inserted.
3394   BasicBlock *OutroInsertBefore = InnermostLoop->getExit();
3395 
3396   auto EmbeddNewLoop =
3397       [this, DL, F, InnerEnter, &Enter, &Continue, &OutroInsertBefore](
3398           Value *TripCount, const Twine &Name) -> CanonicalLoopInfo * {
3399     CanonicalLoopInfo *EmbeddedLoop = createLoopSkeleton(
3400         DL, TripCount, F, InnerEnter, OutroInsertBefore, Name);
3401     redirectTo(Enter, EmbeddedLoop->getPreheader(), DL);
3402     redirectTo(EmbeddedLoop->getAfter(), Continue, DL);
3403 
3404     // Setup the position where the next embedded loop connects to this loop.
3405     Enter = EmbeddedLoop->getBody();
3406     Continue = EmbeddedLoop->getLatch();
3407     OutroInsertBefore = EmbeddedLoop->getLatch();
3408     return EmbeddedLoop;
3409   };
3410 
3411   auto EmbeddNewLoops = [&Result, &EmbeddNewLoop](ArrayRef<Value *> TripCounts,
3412                                                   const Twine &NameBase) {
3413     for (auto P : enumerate(TripCounts)) {
3414       CanonicalLoopInfo *EmbeddedLoop =
3415           EmbeddNewLoop(P.value(), NameBase + Twine(P.index()));
3416       Result.push_back(EmbeddedLoop);
3417     }
3418   };
3419 
3420   EmbeddNewLoops(FloorCount, "floor");
3421 
3422   // Within the innermost floor loop, emit the code that computes the tile
3423   // sizes.
3424   Builder.SetInsertPoint(Enter->getTerminator());
3425   SmallVector<Value *, 4> TileCounts;
3426   for (int i = 0; i < NumLoops; ++i) {
3427     CanonicalLoopInfo *FloorLoop = Result[i];
3428     Value *TileSize = TileSizes[i];
3429 
3430     Value *FloorIsEpilogue =
3431         Builder.CreateICmpEQ(FloorLoop->getIndVar(), FloorCount[i]);
3432     Value *TileTripCount =
3433         Builder.CreateSelect(FloorIsEpilogue, FloorRems[i], TileSize);
3434 
3435     TileCounts.push_back(TileTripCount);
3436   }
3437 
3438   // Create the tile loops.
3439   EmbeddNewLoops(TileCounts, "tile");
3440 
3441   // Insert the inbetween code into the body.
3442   BasicBlock *BodyEnter = Enter;
3443   BasicBlock *BodyEntered = nullptr;
3444   for (std::pair<BasicBlock *, BasicBlock *> P : InbetweenCode) {
3445     BasicBlock *EnterBB = P.first;
3446     BasicBlock *ExitBB = P.second;
3447 
3448     if (BodyEnter)
3449       redirectTo(BodyEnter, EnterBB, DL);
3450     else
3451       redirectAllPredecessorsTo(BodyEntered, EnterBB, DL);
3452 
3453     BodyEnter = nullptr;
3454     BodyEntered = ExitBB;
3455   }
3456 
3457   // Append the original loop nest body into the generated loop nest body.
3458   if (BodyEnter)
3459     redirectTo(BodyEnter, InnerEnter, DL);
3460   else
3461     redirectAllPredecessorsTo(BodyEntered, InnerEnter, DL);
3462   redirectAllPredecessorsTo(InnerLatch, Continue, DL);
3463 
3464   // Replace the original induction variable with an induction variable computed
3465   // from the tile and floor induction variables.
3466   Builder.restoreIP(Result.back()->getBodyIP());
3467   for (int i = 0; i < NumLoops; ++i) {
3468     CanonicalLoopInfo *FloorLoop = Result[i];
3469     CanonicalLoopInfo *TileLoop = Result[NumLoops + i];
3470     Value *OrigIndVar = OrigIndVars[i];
3471     Value *Size = TileSizes[i];
3472 
3473     Value *Scale =
3474         Builder.CreateMul(Size, FloorLoop->getIndVar(), {}, /*HasNUW=*/true);
3475     Value *Shift =
3476         Builder.CreateAdd(Scale, TileLoop->getIndVar(), {}, /*HasNUW=*/true);
3477     OrigIndVar->replaceAllUsesWith(Shift);
3478   }
3479 
3480   // Remove unused parts of the original loops.
3481   removeUnusedBlocksFromParent(OldControlBBs);
3482 
3483   for (CanonicalLoopInfo *L : Loops)
3484     L->invalidate();
3485 
3486 #ifndef NDEBUG
3487   for (CanonicalLoopInfo *GenL : Result)
3488     GenL->assertOK();
3489 #endif
3490   return Result;
3491 }
3492 
3493 /// Attach metadata \p Properties to the basic block described by \p BB. If the
3494 /// basic block already has metadata, the basic block properties are appended.
3495 static void addBasicBlockMetadata(BasicBlock *BB,
3496                                   ArrayRef<Metadata *> Properties) {
3497   // Nothing to do if no property to attach.
3498   if (Properties.empty())
3499     return;
3500 
3501   LLVMContext &Ctx = BB->getContext();
3502   SmallVector<Metadata *> NewProperties;
3503   NewProperties.push_back(nullptr);
3504 
3505   // If the basic block already has metadata, prepend it to the new metadata.
3506   MDNode *Existing = BB->getTerminator()->getMetadata(LLVMContext::MD_loop);
3507   if (Existing)
3508     append_range(NewProperties, drop_begin(Existing->operands(), 1));
3509 
3510   append_range(NewProperties, Properties);
3511   MDNode *BasicBlockID = MDNode::getDistinct(Ctx, NewProperties);
3512   BasicBlockID->replaceOperandWith(0, BasicBlockID);
3513 
3514   BB->getTerminator()->setMetadata(LLVMContext::MD_loop, BasicBlockID);
3515 }
3516 
3517 /// Attach loop metadata \p Properties to the loop described by \p Loop. If the
3518 /// loop already has metadata, the loop properties are appended.
3519 static void addLoopMetadata(CanonicalLoopInfo *Loop,
3520                             ArrayRef<Metadata *> Properties) {
3521   assert(Loop->isValid() && "Expecting a valid CanonicalLoopInfo");
3522 
3523   // Attach metadata to the loop's latch
3524   BasicBlock *Latch = Loop->getLatch();
3525   assert(Latch && "A valid CanonicalLoopInfo must have a unique latch");
3526   addBasicBlockMetadata(Latch, Properties);
3527 }
3528 
3529 /// Attach llvm.access.group metadata to the memref instructions of \p Block
3530 static void addSimdMetadata(BasicBlock *Block, MDNode *AccessGroup,
3531                             LoopInfo &LI) {
3532   for (Instruction &I : *Block) {
3533     if (I.mayReadOrWriteMemory()) {
3534       // TODO: This instruction may already have access group from
3535       // other pragmas e.g. #pragma clang loop vectorize.  Append
3536       // so that the existing metadata is not overwritten.
3537       I.setMetadata(LLVMContext::MD_access_group, AccessGroup);
3538     }
3539   }
3540 }
3541 
3542 void OpenMPIRBuilder::unrollLoopFull(DebugLoc, CanonicalLoopInfo *Loop) {
3543   LLVMContext &Ctx = Builder.getContext();
3544   addLoopMetadata(
3545       Loop, {MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
3546              MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.full"))});
3547 }
3548 
3549 void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) {
3550   LLVMContext &Ctx = Builder.getContext();
3551   addLoopMetadata(
3552       Loop, {
3553                 MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
3554             });
3555 }
3556 
3557 void OpenMPIRBuilder::createIfVersion(CanonicalLoopInfo *CanonicalLoop,
3558                                       Value *IfCond, ValueToValueMapTy &VMap,
3559                                       const Twine &NamePrefix) {
3560   Function *F = CanonicalLoop->getFunction();
3561 
3562   // Define where if branch should be inserted
3563   Instruction *SplitBefore;
3564   if (Instruction::classof(IfCond)) {
3565     SplitBefore = dyn_cast<Instruction>(IfCond);
3566   } else {
3567     SplitBefore = CanonicalLoop->getPreheader()->getTerminator();
3568   }
3569 
3570   // TODO: We should not rely on pass manager. Currently we use pass manager
3571   // only for getting llvm::Loop which corresponds to given CanonicalLoopInfo
3572   // object. We should have a method  which returns all blocks between
3573   // CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter()
3574   FunctionAnalysisManager FAM;
3575   FAM.registerPass([]() { return DominatorTreeAnalysis(); });
3576   FAM.registerPass([]() { return LoopAnalysis(); });
3577   FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
3578 
3579   // Get the loop which needs to be cloned
3580   LoopAnalysis LIA;
3581   LoopInfo &&LI = LIA.run(*F, FAM);
3582   Loop *L = LI.getLoopFor(CanonicalLoop->getHeader());
3583 
3584   // Create additional blocks for the if statement
3585   BasicBlock *Head = SplitBefore->getParent();
3586   Instruction *HeadOldTerm = Head->getTerminator();
3587   llvm::LLVMContext &C = Head->getContext();
3588   llvm::BasicBlock *ThenBlock = llvm::BasicBlock::Create(
3589       C, NamePrefix + ".if.then", Head->getParent(), Head->getNextNode());
3590   llvm::BasicBlock *ElseBlock = llvm::BasicBlock::Create(
3591       C, NamePrefix + ".if.else", Head->getParent(), CanonicalLoop->getExit());
3592 
3593   // Create if condition branch.
3594   Builder.SetInsertPoint(HeadOldTerm);
3595   Instruction *BrInstr =
3596       Builder.CreateCondBr(IfCond, ThenBlock, /*ifFalse*/ ElseBlock);
3597   InsertPointTy IP{BrInstr->getParent(), ++BrInstr->getIterator()};
3598   // Then block contains branch to omp loop which needs to be vectorized
3599   spliceBB(IP, ThenBlock, false);
3600   ThenBlock->replaceSuccessorsPhiUsesWith(Head, ThenBlock);
3601 
3602   Builder.SetInsertPoint(ElseBlock);
3603 
3604   // Clone loop for the else branch
3605   SmallVector<BasicBlock *, 8> NewBlocks;
3606 
3607   VMap[CanonicalLoop->getPreheader()] = ElseBlock;
3608   for (BasicBlock *Block : L->getBlocks()) {
3609     BasicBlock *NewBB = CloneBasicBlock(Block, VMap, "", F);
3610     NewBB->moveBefore(CanonicalLoop->getExit());
3611     VMap[Block] = NewBB;
3612     NewBlocks.push_back(NewBB);
3613   }
3614   remapInstructionsInBlocks(NewBlocks, VMap);
3615   Builder.CreateBr(NewBlocks.front());
3616 }
3617 
3618 unsigned
3619 OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
3620                                            const StringMap<bool> &Features) {
3621   if (TargetTriple.isX86()) {
3622     if (Features.lookup("avx512f"))
3623       return 512;
3624     else if (Features.lookup("avx"))
3625       return 256;
3626     return 128;
3627   }
3628   if (TargetTriple.isPPC())
3629     return 128;
3630   if (TargetTriple.isWasm())
3631     return 128;
3632   return 0;
3633 }
3634 
3635 void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
3636                                 MapVector<Value *, Value *> AlignedVars,
3637                                 Value *IfCond, OrderKind Order,
3638                                 ConstantInt *Simdlen, ConstantInt *Safelen) {
3639   LLVMContext &Ctx = Builder.getContext();
3640 
3641   Function *F = CanonicalLoop->getFunction();
3642 
3643   // TODO: We should not rely on pass manager. Currently we use pass manager
3644   // only for getting llvm::Loop which corresponds to given CanonicalLoopInfo
3645   // object. We should have a method  which returns all blocks between
3646   // CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter()
3647   FunctionAnalysisManager FAM;
3648   FAM.registerPass([]() { return DominatorTreeAnalysis(); });
3649   FAM.registerPass([]() { return LoopAnalysis(); });
3650   FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
3651 
3652   LoopAnalysis LIA;
3653   LoopInfo &&LI = LIA.run(*F, FAM);
3654 
3655   Loop *L = LI.getLoopFor(CanonicalLoop->getHeader());
3656   if (AlignedVars.size()) {
3657     InsertPointTy IP = Builder.saveIP();
3658     Builder.SetInsertPoint(CanonicalLoop->getPreheader()->getTerminator());
3659     for (auto &AlignedItem : AlignedVars) {
3660       Value *AlignedPtr = AlignedItem.first;
3661       Value *Alignment = AlignedItem.second;
3662       Builder.CreateAlignmentAssumption(F->getParent()->getDataLayout(),
3663                                         AlignedPtr, Alignment);
3664     }
3665     Builder.restoreIP(IP);
3666   }
3667 
3668   if (IfCond) {
3669     ValueToValueMapTy VMap;
3670     createIfVersion(CanonicalLoop, IfCond, VMap, "simd");
3671     // Add metadata to the cloned loop which disables vectorization
3672     Value *MappedLatch = VMap.lookup(CanonicalLoop->getLatch());
3673     assert(MappedLatch &&
3674            "Cannot find value which corresponds to original loop latch");
3675     assert(isa<BasicBlock>(MappedLatch) &&
3676            "Cannot cast mapped latch block value to BasicBlock");
3677     BasicBlock *NewLatchBlock = dyn_cast<BasicBlock>(MappedLatch);
3678     ConstantAsMetadata *BoolConst =
3679         ConstantAsMetadata::get(ConstantInt::getFalse(Type::getInt1Ty(Ctx)));
3680     addBasicBlockMetadata(
3681         NewLatchBlock,
3682         {MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.enable"),
3683                            BoolConst})});
3684   }
3685 
3686   SmallSet<BasicBlock *, 8> Reachable;
3687 
3688   // Get the basic blocks from the loop in which memref instructions
3689   // can be found.
3690   // TODO: Generalize getting all blocks inside a CanonicalizeLoopInfo,
3691   // preferably without running any passes.
3692   for (BasicBlock *Block : L->getBlocks()) {
3693     if (Block == CanonicalLoop->getCond() ||
3694         Block == CanonicalLoop->getHeader())
3695       continue;
3696     Reachable.insert(Block);
3697   }
3698 
3699   SmallVector<Metadata *> LoopMDList;
3700 
3701   // In presence of finite 'safelen', it may be unsafe to mark all
3702   // the memory instructions parallel, because loop-carried
3703   // dependences of 'safelen' iterations are possible.
3704   // If clause order(concurrent) is specified then the memory instructions
3705   // are marked parallel even if 'safelen' is finite.
3706   if ((Safelen == nullptr) || (Order == OrderKind::OMP_ORDER_concurrent)) {
3707     // Add access group metadata to memory-access instructions.
3708     MDNode *AccessGroup = MDNode::getDistinct(Ctx, {});
3709     for (BasicBlock *BB : Reachable)
3710       addSimdMetadata(BB, AccessGroup, LI);
3711     // TODO:  If the loop has existing parallel access metadata, have
3712     // to combine two lists.
3713     LoopMDList.push_back(MDNode::get(
3714         Ctx, {MDString::get(Ctx, "llvm.loop.parallel_accesses"), AccessGroup}));
3715   }
3716 
3717   // Use the above access group metadata to create loop level
3718   // metadata, which should be distinct for each loop.
3719   ConstantAsMetadata *BoolConst =
3720       ConstantAsMetadata::get(ConstantInt::getTrue(Type::getInt1Ty(Ctx)));
3721   LoopMDList.push_back(MDNode::get(
3722       Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.enable"), BoolConst}));
3723 
3724   if (Simdlen || Safelen) {
3725     // If both simdlen and safelen clauses are specified, the value of the
3726     // simdlen parameter must be less than or equal to the value of the safelen
3727     // parameter. Therefore, use safelen only in the absence of simdlen.
3728     ConstantInt *VectorizeWidth = Simdlen == nullptr ? Safelen : Simdlen;
3729     LoopMDList.push_back(
3730         MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.width"),
3731                           ConstantAsMetadata::get(VectorizeWidth)}));
3732   }
3733 
3734   addLoopMetadata(CanonicalLoop, LoopMDList);
3735 }
3736 
3737 /// Create the TargetMachine object to query the backend for optimization
3738 /// preferences.
3739 ///
3740 /// Ideally, this would be passed from the front-end to the OpenMPBuilder, but
3741 /// e.g. Clang does not pass it to its CodeGen layer and creates it only when
3742 /// needed for the LLVM pass pipline. We use some default options to avoid
3743 /// having to pass too many settings from the frontend that probably do not
3744 /// matter.
3745 ///
3746 /// Currently, TargetMachine is only used sometimes by the unrollLoopPartial
3747 /// method. If we are going to use TargetMachine for more purposes, especially
3748 /// those that are sensitive to TargetOptions, RelocModel and CodeModel, it
3749 /// might become be worth requiring front-ends to pass on their TargetMachine,
3750 /// or at least cache it between methods. Note that while fontends such as Clang
3751 /// have just a single main TargetMachine per translation unit, "target-cpu" and
3752 /// "target-features" that determine the TargetMachine are per-function and can
3753 /// be overrided using __attribute__((target("OPTIONS"))).
3754 static std::unique_ptr<TargetMachine>
3755 createTargetMachine(Function *F, CodeGenOptLevel OptLevel) {
3756   Module *M = F->getParent();
3757 
3758   StringRef CPU = F->getFnAttribute("target-cpu").getValueAsString();
3759   StringRef Features = F->getFnAttribute("target-features").getValueAsString();
3760   const std::string &Triple = M->getTargetTriple();
3761 
3762   std::string Error;
3763   const llvm::Target *TheTarget = TargetRegistry::lookupTarget(Triple, Error);
3764   if (!TheTarget)
3765     return {};
3766 
3767   llvm::TargetOptions Options;
3768   return std::unique_ptr<TargetMachine>(TheTarget->createTargetMachine(
3769       Triple, CPU, Features, Options, /*RelocModel=*/std::nullopt,
3770       /*CodeModel=*/std::nullopt, OptLevel));
3771 }
3772 
/// Heuristically determine the best-performant unroll factor for \p CLI. This
/// depends on the target processor. We are re-using the same heuristics as the
/// LoopUnrollPass.
///
/// Returns the factor suggested by computeUnrollCount. Returns 1 ("do not
/// unroll") when the loop is not considered unrollable, contains convergent
/// operations, or the suggested count is 0.
static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
  Function *F = CLI->getFunction();

  // Assume the user requests the most aggressive unrolling, even if the rest of
  // the code is optimized using a lower setting.
  CodeGenOptLevel OptLevel = CodeGenOptLevel::Aggressive;
  // Null if no target is registered for the module triple. In that case TIRA
  // below stays default-constructed.
  std::unique_ptr<TargetMachine> TM = createTargetMachine(F, OptLevel);

  // Local analysis manager. It only provides the analyses (and their
  // dependencies) that the explicit .run() calls below need.
  FunctionAnalysisManager FAM;
  FAM.registerPass([]() { return TargetLibraryAnalysis(); });
  FAM.registerPass([]() { return AssumptionAnalysis(); });
  FAM.registerPass([]() { return DominatorTreeAnalysis(); });
  FAM.registerPass([]() { return LoopAnalysis(); });
  FAM.registerPass([]() { return ScalarEvolutionAnalysis(); });
  FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
  TargetIRAnalysis TIRA;
  if (TM)
    TIRA = TargetIRAnalysis(
        [&](const Function &F) { return TM->getTargetTransformInfo(F); });
  FAM.registerPass([&]() { return TIRA; });

  // Each analysis is run directly; the results bind to locals by rvalue
  // reference (lifetime-extended temporaries).
  // NOTE(review): SEA.run presumably obtains its own LoopInfo/DominatorTree
  // through FAM. SE's internal loop objects would then differ from the local
  // LI that L comes from. Confirm this mismatch is harmless for the queries
  // below.
  TargetIRAnalysis::Result &&TTI = TIRA.run(*F, FAM);
  ScalarEvolutionAnalysis SEA;
  ScalarEvolution &&SE = SEA.run(*F, FAM);
  DominatorTreeAnalysis DTA;
  DominatorTree &&DT = DTA.run(*F, FAM);
  LoopAnalysis LIA;
  LoopInfo &&LI = LIA.run(*F, FAM);
  AssumptionAnalysis ACT;
  AssumptionCache &&AC = ACT.run(*F, FAM);
  OptimizationRemarkEmitter ORE{F};

  Loop *L = LI.getLoopFor(CLI->getHeader());
  assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop");

  // Start from the target's preferences. Partial and runtime unrolling are
  // allowed; no user thresholds or counts are imposed.
  TargetTransformInfo::UnrollingPreferences UP =
      gatherUnrollingPreferences(L, SE, TTI,
                                 /*BlockFrequencyInfo=*/nullptr,
                                 /*ProfileSummaryInfo=*/nullptr, ORE, static_cast<int>(OptLevel),
                                 /*UserThreshold=*/std::nullopt,
                                 /*UserCount=*/std::nullopt,
                                 /*UserAllowPartial=*/true,
                                 /*UserAllowRuntime=*/true,
                                 /*UserUpperBound=*/std::nullopt,
                                 /*UserFullUnrollMaxCount=*/std::nullopt);

  UP.Force = true;

  // Account for additional optimizations taking place before the LoopUnrollPass
  // would unroll the loop.
  UP.Threshold *= UnrollThresholdFactor;
  UP.PartialThreshold *= UnrollThresholdFactor;

  // Use normal unroll factors even if the rest of the code is optimized for
  // size.
  UP.OptSizeThreshold = UP.Threshold;
  UP.PartialOptSizeThreshold = UP.PartialThreshold;

  LLVM_DEBUG(dbgs() << "Unroll heuristic thresholds:\n"
                    << "  Threshold=" << UP.Threshold << "\n"
                    << "  PartialThreshold=" << UP.PartialThreshold << "\n"
                    << "  OptSizeThreshold=" << UP.OptSizeThreshold << "\n"
                    << "  PartialOptSizeThreshold="
                    << UP.PartialOptSizeThreshold << "\n");

  // Disable peeling.
  TargetTransformInfo::PeelingPreferences PP =
      gatherPeelingPreferences(L, SE, TTI,
                               /*UserAllowPeeling=*/false,
                               /*UserAllowProfileBasedPeeling=*/false,
                               /*UnrollingSpecficValues=*/false);

  // Values in EphValues are excluded from the loop-size estimate below.
  SmallPtrSet<const Value *, 32> EphValues;
  CodeMetrics::collectEphemeralValues(L, &AC, EphValues);

  // Assume that reads and writes to stack variables can be eliminated by
  // Mem2Reg, SROA or LICM. That is, don't count them towards the loop body's
  // size. Only loads/stores whose (cast-stripped) pointer is an alloca in the
  // function's entry block are treated this way.
  for (BasicBlock *BB : L->blocks()) {
    for (Instruction &I : *BB) {
      Value *Ptr;
      if (auto *Load = dyn_cast<LoadInst>(&I)) {
        Ptr = Load->getPointerOperand();
      } else if (auto *Store = dyn_cast<StoreInst>(&I)) {
        Ptr = Store->getPointerOperand();
      } else
        continue;

      Ptr = Ptr->stripPointerCasts();

      if (auto *Alloca = dyn_cast<AllocaInst>(Ptr)) {
        if (Alloca->getParent() == &F->getEntryBlock())
          EphValues.insert(&I);
      }
    }
  }

  UnrollCostEstimator UCE(L, TTI, EphValues, UP.BEInsns);

  // Loop is not unrollable if the loop contains certain instructions.
  if (!UCE.canUnroll() || UCE.Convergent) {
    LLVM_DEBUG(dbgs() << "Loop not considered unrollable\n");
    return 1;
  }

  LLVM_DEBUG(dbgs() << "Estimated loop size is " << UCE.getRolledLoopSize()
                    << "\n");

  // TODO: Determine trip count of \p CLI if constant, computeUnrollCount might
  // be able to use it.
  int TripCount = 0;
  int MaxTripCount = 0;
  bool MaxOrZero = false;
  unsigned TripMultiple = 0;

  // computeUnrollCount writes its decision into UP.Count.
  bool UseUpperBound = false;
  computeUnrollCount(L, TTI, DT, &LI, &AC, SE, EphValues, &ORE, TripCount,
                     MaxTripCount, MaxOrZero, TripMultiple, UCE, UP, PP,
                     UseUpperBound);
  unsigned Factor = UP.Count;
  LLVM_DEBUG(dbgs() << "Suggesting unroll factor of " << Factor << "\n");

  // This function returns 1 to signal to not unroll a loop.
  if (Factor == 0)
    return 1;
  return Factor;
}
3903 
3904 void OpenMPIRBuilder::unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop,
3905                                         int32_t Factor,
3906                                         CanonicalLoopInfo **UnrolledCLI) {
3907   assert(Factor >= 0 && "Unroll factor must not be negative");
3908 
3909   Function *F = Loop->getFunction();
3910   LLVMContext &Ctx = F->getContext();
3911 
3912   // If the unrolled loop is not used for another loop-associated directive, it
3913   // is sufficient to add metadata for the LoopUnrollPass.
3914   if (!UnrolledCLI) {
3915     SmallVector<Metadata *, 2> LoopMetadata;
3916     LoopMetadata.push_back(
3917         MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")));
3918 
3919     if (Factor >= 1) {
3920       ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
3921           ConstantInt::get(Type::getInt32Ty(Ctx), APInt(32, Factor)));
3922       LoopMetadata.push_back(MDNode::get(
3923           Ctx, {MDString::get(Ctx, "llvm.loop.unroll.count"), FactorConst}));
3924     }
3925 
3926     addLoopMetadata(Loop, LoopMetadata);
3927     return;
3928   }
3929 
3930   // Heuristically determine the unroll factor.
3931   if (Factor == 0)
3932     Factor = computeHeuristicUnrollFactor(Loop);
3933 
3934   // No change required with unroll factor 1.
3935   if (Factor == 1) {
3936     *UnrolledCLI = Loop;
3937     return;
3938   }
3939 
3940   assert(Factor >= 2 &&
3941          "unrolling only makes sense with a factor of 2 or larger");
3942 
3943   Type *IndVarTy = Loop->getIndVarType();
3944 
3945   // Apply partial unrolling by tiling the loop by the unroll-factor, then fully
3946   // unroll the inner loop.
3947   Value *FactorVal =
3948       ConstantInt::get(IndVarTy, APInt(IndVarTy->getIntegerBitWidth(), Factor,
3949                                        /*isSigned=*/false));
3950   std::vector<CanonicalLoopInfo *> LoopNest =
3951       tileLoops(DL, {Loop}, {FactorVal});
3952   assert(LoopNest.size() == 2 && "Expect 2 loops after tiling");
3953   *UnrolledCLI = LoopNest[0];
3954   CanonicalLoopInfo *InnerLoop = LoopNest[1];
3955 
3956   // LoopUnrollPass can only fully unroll loops with constant trip count.
3957   // Unroll by the unroll factor with a fallback epilog for the remainder
3958   // iterations if necessary.
3959   ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
3960       ConstantInt::get(Type::getInt32Ty(Ctx), APInt(32, Factor)));
3961   addLoopMetadata(
3962       InnerLoop,
3963       {MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
3964        MDNode::get(
3965            Ctx, {MDString::get(Ctx, "llvm.loop.unroll.count"), FactorConst})});
3966 
3967 #ifndef NDEBUG
3968   (*UnrolledCLI)->assertOK();
3969 #endif
3970 }
3971 
3972 OpenMPIRBuilder::InsertPointTy
3973 OpenMPIRBuilder::createCopyPrivate(const LocationDescription &Loc,
3974                                    llvm::Value *BufSize, llvm::Value *CpyBuf,
3975                                    llvm::Value *CpyFn, llvm::Value *DidIt) {
3976   if (!updateToLocation(Loc))
3977     return Loc.IP;
3978 
3979   uint32_t SrcLocStrSize;
3980   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3981   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3982   Value *ThreadId = getOrCreateThreadID(Ident);
3983 
3984   llvm::Value *DidItLD = Builder.CreateLoad(Builder.getInt32Ty(), DidIt);
3985 
3986   Value *Args[] = {Ident, ThreadId, BufSize, CpyBuf, CpyFn, DidItLD};
3987 
3988   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_copyprivate);
3989   Builder.CreateCall(Fn, Args);
3990 
3991   return Builder.saveIP();
3992 }
3993 
3994 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSingle(
3995     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
3996     FinalizeCallbackTy FiniCB, bool IsNowait, llvm::Value *DidIt) {
3997 
3998   if (!updateToLocation(Loc))
3999     return Loc.IP;
4000 
4001   // If needed (i.e. not null), initialize `DidIt` with 0
4002   if (DidIt) {
4003     Builder.CreateStore(Builder.getInt32(0), DidIt);
4004   }
4005 
4006   Directive OMPD = Directive::OMPD_single;
4007   uint32_t SrcLocStrSize;
4008   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4009   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4010   Value *ThreadId = getOrCreateThreadID(Ident);
4011   Value *Args[] = {Ident, ThreadId};
4012 
4013   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_single);
4014   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
4015 
4016   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_single);
4017   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
4018 
4019   // generates the following:
4020   // if (__kmpc_single()) {
4021   //		.... single region ...
4022   // 		__kmpc_end_single
4023   // }
4024   // __kmpc_barrier
4025 
4026   EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
4027                        /*Conditional*/ true,
4028                        /*hasFinalize*/ true);
4029   if (!IsNowait)
4030     createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
4031                   omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
4032                   /* CheckCancelFlag */ false);
4033   return Builder.saveIP();
4034 }
4035 
4036 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCritical(
4037     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
4038     FinalizeCallbackTy FiniCB, StringRef CriticalName, Value *HintInst) {
4039 
4040   if (!updateToLocation(Loc))
4041     return Loc.IP;
4042 
4043   Directive OMPD = Directive::OMPD_critical;
4044   uint32_t SrcLocStrSize;
4045   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4046   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4047   Value *ThreadId = getOrCreateThreadID(Ident);
4048   Value *LockVar = getOMPCriticalRegionLock(CriticalName);
4049   Value *Args[] = {Ident, ThreadId, LockVar};
4050 
4051   SmallVector<llvm::Value *, 4> EnterArgs(std::begin(Args), std::end(Args));
4052   Function *RTFn = nullptr;
4053   if (HintInst) {
4054     // Add Hint to entry Args and create call
4055     EnterArgs.push_back(HintInst);
4056     RTFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_critical_with_hint);
4057   } else {
4058     RTFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_critical);
4059   }
4060   Instruction *EntryCall = Builder.CreateCall(RTFn, EnterArgs);
4061 
4062   Function *ExitRTLFn =
4063       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_critical);
4064   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
4065 
4066   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
4067                               /*Conditional*/ false, /*hasFinalize*/ true);
4068 }
4069 
4070 OpenMPIRBuilder::InsertPointTy
4071 OpenMPIRBuilder::createOrderedDepend(const LocationDescription &Loc,
4072                                      InsertPointTy AllocaIP, unsigned NumLoops,
4073                                      ArrayRef<llvm::Value *> StoreValues,
4074                                      const Twine &Name, bool IsDependSource) {
4075   assert(
4076       llvm::all_of(StoreValues,
4077                    [](Value *SV) { return SV->getType()->isIntegerTy(64); }) &&
4078       "OpenMP runtime requires depend vec with i64 type");
4079 
4080   if (!updateToLocation(Loc))
4081     return Loc.IP;
4082 
4083   // Allocate space for vector and generate alloc instruction.
4084   auto *ArrI64Ty = ArrayType::get(Int64, NumLoops);
4085   Builder.restoreIP(AllocaIP);
4086   AllocaInst *ArgsBase = Builder.CreateAlloca(ArrI64Ty, nullptr, Name);
4087   ArgsBase->setAlignment(Align(8));
4088   Builder.restoreIP(Loc.IP);
4089 
4090   // Store the index value with offset in depend vector.
4091   for (unsigned I = 0; I < NumLoops; ++I) {
4092     Value *DependAddrGEPIter = Builder.CreateInBoundsGEP(
4093         ArrI64Ty, ArgsBase, {Builder.getInt64(0), Builder.getInt64(I)});
4094     StoreInst *STInst = Builder.CreateStore(StoreValues[I], DependAddrGEPIter);
4095     STInst->setAlignment(Align(8));
4096   }
4097 
4098   Value *DependBaseAddrGEP = Builder.CreateInBoundsGEP(
4099       ArrI64Ty, ArgsBase, {Builder.getInt64(0), Builder.getInt64(0)});
4100 
4101   uint32_t SrcLocStrSize;
4102   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4103   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4104   Value *ThreadId = getOrCreateThreadID(Ident);
4105   Value *Args[] = {Ident, ThreadId, DependBaseAddrGEP};
4106 
4107   Function *RTLFn = nullptr;
4108   if (IsDependSource)
4109     RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_doacross_post);
4110   else
4111     RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_doacross_wait);
4112   Builder.CreateCall(RTLFn, Args);
4113 
4114   return Builder.saveIP();
4115 }
4116 
4117 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createOrderedThreadsSimd(
4118     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
4119     FinalizeCallbackTy FiniCB, bool IsThreads) {
4120   if (!updateToLocation(Loc))
4121     return Loc.IP;
4122 
4123   Directive OMPD = Directive::OMPD_ordered;
4124   Instruction *EntryCall = nullptr;
4125   Instruction *ExitCall = nullptr;
4126 
4127   if (IsThreads) {
4128     uint32_t SrcLocStrSize;
4129     Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4130     Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4131     Value *ThreadId = getOrCreateThreadID(Ident);
4132     Value *Args[] = {Ident, ThreadId};
4133 
4134     Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_ordered);
4135     EntryCall = Builder.CreateCall(EntryRTLFn, Args);
4136 
4137     Function *ExitRTLFn =
4138         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_ordered);
4139     ExitCall = Builder.CreateCall(ExitRTLFn, Args);
4140   }
4141 
4142   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
4143                               /*Conditional*/ false, /*hasFinalize*/ true);
4144 }
4145 
/// Emit an inlined region for directive \p OMPD at the builder's current
/// insertion point.
///
/// The current block is split into three blocks:
///   entry -> "omp_region.finalize" -> "omp_region.end"
/// If \p Conditional is set and \p EntryCall is non-null, the body is guarded
/// on \p EntryCall (see emitCommonDirectiveEntry). The body is then generated
/// with \p BodyGenCB, receiving an empty alloca insertion point. Next,
/// emitCommonDirectiveExit runs the finalization (when \p HasFinalize) and
/// moves \p ExitCall into the finalization block. Finally the finalization
/// and exit blocks are merged into their predecessors where possible.
///
/// When \p HasFinalize is set, {FiniCB, OMPD, IsCancellable} is pushed onto
/// FinalizationStack for the duration of the region.
///
/// Returns an insertion point at the end of the block that contains the
/// region's exit.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::EmitOMPInlinedRegion(
    Directive OMPD, Instruction *EntryCall, Instruction *ExitCall,
    BodyGenCallbackTy BodyGenCB, FinalizeCallbackTy FiniCB, bool Conditional,
    bool HasFinalize, bool IsCancellable) {

  if (HasFinalize)
    FinalizationStack.push_back({FiniCB, OMPD, IsCancellable});

  // Create inlined region's entry and body blocks, in preparation
  // for conditional creation.
  // If the current block does not end in a branch (or has no terminator yet),
  // a temporary `unreachable` is appended to serve as the split point; it is
  // erased again at the end of this function.
  BasicBlock *EntryBB = Builder.GetInsertBlock();
  Instruction *SplitPos = EntryBB->getTerminator();
  if (!isa_and_nonnull<BranchInst>(SplitPos))
    SplitPos = new UnreachableInst(Builder.getContext(), EntryBB);
  // SplitPos (and anything after it) moves into ExitBB. A second split then
  // puts a separate finalization block between EntryBB and ExitBB.
  BasicBlock *ExitBB = EntryBB->splitBasicBlock(SplitPos, "omp_region.end");
  BasicBlock *FiniBB =
      EntryBB->splitBasicBlock(EntryBB->getTerminator(), "omp_region.finalize");

  Builder.SetInsertPoint(EntryBB->getTerminator());
  emitCommonDirectiveEntry(OMPD, EntryCall, ExitBB, Conditional);

  // Generate the body at wherever emitCommonDirectiveEntry left the builder.
  BodyGenCB(/* AllocaIP */ InsertPointTy(),
            /* CodeGenIP */ Builder.saveIP());

  // emit exit call and do any needed finalization.
  auto FinIP = InsertPointTy(FiniBB, FiniBB->getFirstInsertionPt());
  assert(FiniBB->getTerminator()->getNumSuccessors() == 1 &&
         FiniBB->getTerminator()->getSuccessor(0) == ExitBB &&
         "Unexpected control flow graph state!!");
  emitCommonDirectiveExit(OMPD, FinIP, ExitCall, HasFinalize);
  assert(FiniBB->getUniquePredecessor()->getUniqueSuccessor() == FiniBB &&
         "Unexpected Control Flow State!");
  MergeBlockIntoPredecessor(FiniBB);

  // Try to fold ExitBB into its predecessor. Either way, continue emitting at
  // the end of the block that now holds SplitPos. Remove SplitPos if it was
  // the temporary `unreachable` created above.
  assert(SplitPos->getParent() == ExitBB &&
         "Unexpected Insertion point location!");
  auto merged = MergeBlockIntoPredecessor(ExitBB);
  BasicBlock *ExitPredBB = SplitPos->getParent();
  auto InsertBB = merged ? ExitPredBB : ExitBB;
  if (!isa_and_nonnull<BranchInst>(SplitPos))
    SplitPos->eraseFromParent();
  Builder.SetInsertPoint(InsertBB);

  return Builder.saveIP();
}
4194 
4195 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveEntry(
4196     Directive OMPD, Value *EntryCall, BasicBlock *ExitBB, bool Conditional) {
4197   // if nothing to do, Return current insertion point.
4198   if (!Conditional || !EntryCall)
4199     return Builder.saveIP();
4200 
4201   BasicBlock *EntryBB = Builder.GetInsertBlock();
4202   Value *CallBool = Builder.CreateIsNotNull(EntryCall);
4203   auto *ThenBB = BasicBlock::Create(M.getContext(), "omp_region.body");
4204   auto *UI = new UnreachableInst(Builder.getContext(), ThenBB);
4205 
4206   // Emit thenBB and set the Builder's insertion point there for
4207   // body generation next. Place the block after the current block.
4208   Function *CurFn = EntryBB->getParent();
4209   CurFn->insert(std::next(EntryBB->getIterator()), ThenBB);
4210 
4211   // Move Entry branch to end of ThenBB, and replace with conditional
4212   // branch (If-stmt)
4213   Instruction *EntryBBTI = EntryBB->getTerminator();
4214   Builder.CreateCondBr(CallBool, ThenBB, ExitBB);
4215   EntryBBTI->removeFromParent();
4216   Builder.SetInsertPoint(UI);
4217   Builder.Insert(EntryBBTI);
4218   UI->eraseFromParent();
4219   Builder.SetInsertPoint(ThenBB->getTerminator());
4220 
4221   // return an insertion point to ExitBB.
4222   return IRBuilder<>::InsertPoint(ExitBB, ExitBB->getFirstInsertionPt());
4223 }
4224 
4225 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveExit(
4226     omp::Directive OMPD, InsertPointTy FinIP, Instruction *ExitCall,
4227     bool HasFinalize) {
4228 
4229   Builder.restoreIP(FinIP);
4230 
4231   // If there is finalization to do, emit it before the exit call
4232   if (HasFinalize) {
4233     assert(!FinalizationStack.empty() &&
4234            "Unexpected finalization stack state!");
4235 
4236     FinalizationInfo Fi = FinalizationStack.pop_back_val();
4237     assert(Fi.DK == OMPD && "Unexpected Directive for Finalization call!");
4238 
4239     Fi.FiniCB(FinIP);
4240 
4241     BasicBlock *FiniBB = FinIP.getBlock();
4242     Instruction *FiniBBTI = FiniBB->getTerminator();
4243 
4244     // set Builder IP for call creation
4245     Builder.SetInsertPoint(FiniBBTI);
4246   }
4247 
4248   if (!ExitCall)
4249     return Builder.saveIP();
4250 
4251   // place the Exitcall as last instruction before Finalization block terminator
4252   ExitCall->removeFromParent();
4253   Builder.Insert(ExitCall);
4254 
4255   return IRBuilder<>::InsertPoint(ExitCall->getParent(),
4256                                   ExitCall->getIterator());
4257 }
4258 
4259 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCopyinClauseBlocks(
4260     InsertPointTy IP, Value *MasterAddr, Value *PrivateAddr,
4261     llvm::IntegerType *IntPtrTy, bool BranchtoEnd) {
4262   if (!IP.isSet())
4263     return IP;
4264 
4265   IRBuilder<>::InsertPointGuard IPG(Builder);
4266 
4267   // creates the following CFG structure
4268   //	   OMP_Entry : (MasterAddr != PrivateAddr)?
4269   //       F     T
4270   //       |      \
4271   //       |     copin.not.master
4272   //       |      /
4273   //       v     /
4274   //   copyin.not.master.end
4275   //		     |
4276   //         v
4277   //   OMP.Entry.Next
4278 
4279   BasicBlock *OMP_Entry = IP.getBlock();
4280   Function *CurFn = OMP_Entry->getParent();
4281   BasicBlock *CopyBegin =
4282       BasicBlock::Create(M.getContext(), "copyin.not.master", CurFn);
4283   BasicBlock *CopyEnd = nullptr;
4284 
4285   // If entry block is terminated, split to preserve the branch to following
4286   // basic block (i.e. OMP.Entry.Next), otherwise, leave everything as is.
4287   if (isa_and_nonnull<BranchInst>(OMP_Entry->getTerminator())) {
4288     CopyEnd = OMP_Entry->splitBasicBlock(OMP_Entry->getTerminator(),
4289                                          "copyin.not.master.end");
4290     OMP_Entry->getTerminator()->eraseFromParent();
4291   } else {
4292     CopyEnd =
4293         BasicBlock::Create(M.getContext(), "copyin.not.master.end", CurFn);
4294   }
4295 
4296   Builder.SetInsertPoint(OMP_Entry);
4297   Value *MasterPtr = Builder.CreatePtrToInt(MasterAddr, IntPtrTy);
4298   Value *PrivatePtr = Builder.CreatePtrToInt(PrivateAddr, IntPtrTy);
4299   Value *cmp = Builder.CreateICmpNE(MasterPtr, PrivatePtr);
4300   Builder.CreateCondBr(cmp, CopyBegin, CopyEnd);
4301 
4302   Builder.SetInsertPoint(CopyBegin);
4303   if (BranchtoEnd)
4304     Builder.SetInsertPoint(Builder.CreateBr(CopyEnd));
4305 
4306   return Builder.saveIP();
4307 }
4308 
4309 CallInst *OpenMPIRBuilder::createOMPAlloc(const LocationDescription &Loc,
4310                                           Value *Size, Value *Allocator,
4311                                           std::string Name) {
4312   IRBuilder<>::InsertPointGuard IPG(Builder);
4313   Builder.restoreIP(Loc.IP);
4314 
4315   uint32_t SrcLocStrSize;
4316   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4317   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4318   Value *ThreadId = getOrCreateThreadID(Ident);
4319   Value *Args[] = {ThreadId, Size, Allocator};
4320 
4321   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_alloc);
4322 
4323   return Builder.CreateCall(Fn, Args, Name);
4324 }
4325 
4326 CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
4327                                          Value *Addr, Value *Allocator,
4328                                          std::string Name) {
4329   IRBuilder<>::InsertPointGuard IPG(Builder);
4330   Builder.restoreIP(Loc.IP);
4331 
4332   uint32_t SrcLocStrSize;
4333   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4334   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4335   Value *ThreadId = getOrCreateThreadID(Ident);
4336   Value *Args[] = {ThreadId, Addr, Allocator};
4337   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_free);
4338   return Builder.CreateCall(Fn, Args, Name);
4339 }
4340 
4341 CallInst *OpenMPIRBuilder::createOMPInteropInit(
4342     const LocationDescription &Loc, Value *InteropVar,
4343     omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
4344     Value *DependenceAddress, bool HaveNowaitClause) {
4345   IRBuilder<>::InsertPointGuard IPG(Builder);
4346   Builder.restoreIP(Loc.IP);
4347 
4348   uint32_t SrcLocStrSize;
4349   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4350   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4351   Value *ThreadId = getOrCreateThreadID(Ident);
4352   if (Device == nullptr)
4353     Device = ConstantInt::get(Int32, -1);
4354   Constant *InteropTypeVal = ConstantInt::get(Int32, (int)InteropType);
4355   if (NumDependences == nullptr) {
4356     NumDependences = ConstantInt::get(Int32, 0);
4357     PointerType *PointerTypeVar = PointerType::getUnqual(M.getContext());
4358     DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
4359   }
4360   Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
4361   Value *Args[] = {
4362       Ident,  ThreadId,       InteropVar,        InteropTypeVal,
4363       Device, NumDependences, DependenceAddress, HaveNowaitClauseVal};
4364 
4365   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_init);
4366 
4367   return Builder.CreateCall(Fn, Args);
4368 }
4369 
4370 CallInst *OpenMPIRBuilder::createOMPInteropDestroy(
4371     const LocationDescription &Loc, Value *InteropVar, Value *Device,
4372     Value *NumDependences, Value *DependenceAddress, bool HaveNowaitClause) {
4373   IRBuilder<>::InsertPointGuard IPG(Builder);
4374   Builder.restoreIP(Loc.IP);
4375 
4376   uint32_t SrcLocStrSize;
4377   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4378   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4379   Value *ThreadId = getOrCreateThreadID(Ident);
4380   if (Device == nullptr)
4381     Device = ConstantInt::get(Int32, -1);
4382   if (NumDependences == nullptr) {
4383     NumDependences = ConstantInt::get(Int32, 0);
4384     PointerType *PointerTypeVar = PointerType::getUnqual(M.getContext());
4385     DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
4386   }
4387   Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
4388   Value *Args[] = {
4389       Ident,          ThreadId,          InteropVar,         Device,
4390       NumDependences, DependenceAddress, HaveNowaitClauseVal};
4391 
4392   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_destroy);
4393 
4394   return Builder.CreateCall(Fn, Args);
4395 }
4396 
4397 CallInst *OpenMPIRBuilder::createOMPInteropUse(const LocationDescription &Loc,
4398                                                Value *InteropVar, Value *Device,
4399                                                Value *NumDependences,
4400                                                Value *DependenceAddress,
4401                                                bool HaveNowaitClause) {
4402   IRBuilder<>::InsertPointGuard IPG(Builder);
4403   Builder.restoreIP(Loc.IP);
4404   uint32_t SrcLocStrSize;
4405   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4406   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4407   Value *ThreadId = getOrCreateThreadID(Ident);
4408   if (Device == nullptr)
4409     Device = ConstantInt::get(Int32, -1);
4410   if (NumDependences == nullptr) {
4411     NumDependences = ConstantInt::get(Int32, 0);
4412     PointerType *PointerTypeVar = PointerType::getUnqual(M.getContext());
4413     DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
4414   }
4415   Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
4416   Value *Args[] = {
4417       Ident,          ThreadId,          InteropVar,         Device,
4418       NumDependences, DependenceAddress, HaveNowaitClauseVal};
4419 
4420   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_use);
4421 
4422   return Builder.CreateCall(Fn, Args);
4423 }
4424 
4425 CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
4426     const LocationDescription &Loc, llvm::Value *Pointer,
4427     llvm::ConstantInt *Size, const llvm::Twine &Name) {
4428   IRBuilder<>::InsertPointGuard IPG(Builder);
4429   Builder.restoreIP(Loc.IP);
4430 
4431   uint32_t SrcLocStrSize;
4432   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4433   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4434   Value *ThreadId = getOrCreateThreadID(Ident);
4435   Constant *ThreadPrivateCache =
4436       getOrCreateInternalVariable(Int8PtrPtr, Name.str());
4437   llvm::Value *Args[] = {Ident, ThreadId, Pointer, Size, ThreadPrivateCache};
4438 
4439   Function *Fn =
4440       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_threadprivate_cached);
4441 
4442   return Builder.CreateCall(Fn, Args);
4443 }
4444 
4445 OpenMPIRBuilder::InsertPointTy
4446 OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
4447                                   int32_t MinThreadsVal, int32_t MaxThreadsVal,
4448                                   int32_t MinTeamsVal, int32_t MaxTeamsVal) {
4449   if (!updateToLocation(Loc))
4450     return Loc.IP;
4451 
4452   uint32_t SrcLocStrSize;
4453   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4454   Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4455   Constant *IsSPMDVal = ConstantInt::getSigned(
4456       Int8, IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
4457   Constant *UseGenericStateMachineVal = ConstantInt::getSigned(Int8, !IsSPMD);
4458   Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Int8, true);
4459   Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Int16, 0);
4460 
4461   Function *Kernel = Builder.GetInsertBlock()->getParent();
4462 
4463   // Manifest the launch configuration in the metadata matching the kernel
4464   // environment.
4465   if (MinTeamsVal > 1 || MaxTeamsVal > 0)
4466     writeTeamsForKernel(T, *Kernel, MinTeamsVal, MaxTeamsVal);
4467 
4468   // For max values, < 0 means unset, == 0 means set but unknown.
4469   if (MaxThreadsVal < 0)
4470     MaxThreadsVal = std::max(
4471         int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), MinThreadsVal);
4472 
4473   if (MaxThreadsVal > 0)
4474     writeThreadBoundsForKernel(T, *Kernel, MinThreadsVal, MaxThreadsVal);
4475 
4476   Constant *MinThreads = ConstantInt::getSigned(Int32, MinThreadsVal);
4477   Constant *MaxThreads = ConstantInt::getSigned(Int32, MaxThreadsVal);
4478   Constant *MinTeams = ConstantInt::getSigned(Int32, MinTeamsVal);
4479   Constant *MaxTeams = ConstantInt::getSigned(Int32, MaxTeamsVal);
4480   Constant *ReductionDataSize = ConstantInt::getSigned(Int32, 0);
4481   Constant *ReductionBufferLength = ConstantInt::getSigned(Int32, 0);
4482 
4483   // We need to strip the debug prefix to get the correct kernel name.
4484   StringRef KernelName = Kernel->getName();
4485   const std::string DebugPrefix = "_debug__";
4486   if (KernelName.ends_with(DebugPrefix))
4487     KernelName = KernelName.drop_back(DebugPrefix.length());
4488 
4489   Function *Fn = getOrCreateRuntimeFunctionPtr(
4490       omp::RuntimeFunction::OMPRTL___kmpc_target_init);
4491   const DataLayout &DL = Fn->getParent()->getDataLayout();
4492 
4493   Twine DynamicEnvironmentName = KernelName + "_dynamic_environment";
4494   Constant *DynamicEnvironmentInitializer =
4495       ConstantStruct::get(DynamicEnvironment, {DebugIndentionLevelVal});
4496   GlobalVariable *DynamicEnvironmentGV = new GlobalVariable(
4497       M, DynamicEnvironment, /*IsConstant=*/false, GlobalValue::WeakODRLinkage,
4498       DynamicEnvironmentInitializer, DynamicEnvironmentName,
4499       /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
4500       DL.getDefaultGlobalsAddressSpace());
4501   DynamicEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);
4502 
4503   Constant *DynamicEnvironment =
4504       DynamicEnvironmentGV->getType() == DynamicEnvironmentPtr
4505           ? DynamicEnvironmentGV
4506           : ConstantExpr::getAddrSpaceCast(DynamicEnvironmentGV,
4507                                            DynamicEnvironmentPtr);
4508 
4509   Constant *ConfigurationEnvironmentInitializer = ConstantStruct::get(
4510       ConfigurationEnvironment, {
4511                                     UseGenericStateMachineVal,
4512                                     MayUseNestedParallelismVal,
4513                                     IsSPMDVal,
4514                                     MinThreads,
4515                                     MaxThreads,
4516                                     MinTeams,
4517                                     MaxTeams,
4518                                     ReductionDataSize,
4519                                     ReductionBufferLength,
4520                                 });
4521   Constant *KernelEnvironmentInitializer = ConstantStruct::get(
4522       KernelEnvironment, {
4523                              ConfigurationEnvironmentInitializer,
4524                              Ident,
4525                              DynamicEnvironment,
4526                          });
4527   Twine KernelEnvironmentName = KernelName + "_kernel_environment";
4528   GlobalVariable *KernelEnvironmentGV = new GlobalVariable(
4529       M, KernelEnvironment, /*IsConstant=*/true, GlobalValue::WeakODRLinkage,
4530       KernelEnvironmentInitializer, KernelEnvironmentName,
4531       /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
4532       DL.getDefaultGlobalsAddressSpace());
4533   KernelEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);
4534 
4535   Constant *KernelEnvironment =
4536       KernelEnvironmentGV->getType() == KernelEnvironmentPtr
4537           ? KernelEnvironmentGV
4538           : ConstantExpr::getAddrSpaceCast(KernelEnvironmentGV,
4539                                            KernelEnvironmentPtr);
4540   Value *KernelLaunchEnvironment = Kernel->getArg(0);
4541   CallInst *ThreadKind =
4542       Builder.CreateCall(Fn, {KernelEnvironment, KernelLaunchEnvironment});
4543 
4544   Value *ExecUserCode = Builder.CreateICmpEQ(
4545       ThreadKind, ConstantInt::get(ThreadKind->getType(), -1),
4546       "exec_user_code");
4547 
4548   // ThreadKind = __kmpc_target_init(...)
4549   // if (ThreadKind == -1)
4550   //   user_code
4551   // else
4552   //   return;
4553 
4554   auto *UI = Builder.CreateUnreachable();
4555   BasicBlock *CheckBB = UI->getParent();
4556   BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(UI, "user_code.entry");
4557 
4558   BasicBlock *WorkerExitBB = BasicBlock::Create(
4559       CheckBB->getContext(), "worker.exit", CheckBB->getParent());
4560   Builder.SetInsertPoint(WorkerExitBB);
4561   Builder.CreateRetVoid();
4562 
4563   auto *CheckBBTI = CheckBB->getTerminator();
4564   Builder.SetInsertPoint(CheckBBTI);
4565   Builder.CreateCondBr(ExecUserCode, UI->getParent(), WorkerExitBB);
4566 
4567   CheckBBTI->eraseFromParent();
4568   UI->eraseFromParent();
4569 
4570   // Continue in the "user_code" block, see diagram above and in
4571   // openmp/libomptarget/deviceRTLs/common/include/target.h .
4572   return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt());
4573 }
4574 
4575 void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc,
4576                                          int32_t TeamsReductionDataSize,
4577                                          int32_t TeamsReductionBufferLength) {
4578   if (!updateToLocation(Loc))
4579     return;
4580 
4581   Function *Fn = getOrCreateRuntimeFunctionPtr(
4582       omp::RuntimeFunction::OMPRTL___kmpc_target_deinit);
4583 
4584   Builder.CreateCall(Fn, {});
4585 
4586   if (!TeamsReductionBufferLength || !TeamsReductionDataSize)
4587     return;
4588 
4589   Function *Kernel = Builder.GetInsertBlock()->getParent();
4590   // We need to strip the debug prefix to get the correct kernel name.
4591   StringRef KernelName = Kernel->getName();
4592   const std::string DebugPrefix = "_debug__";
4593   if (KernelName.ends_with(DebugPrefix))
4594     KernelName = KernelName.drop_back(DebugPrefix.length());
4595   auto *KernelEnvironmentGV =
4596       M.getNamedGlobal((KernelName + "_kernel_environment").str());
4597   assert(KernelEnvironmentGV && "Expected kernel environment global\n");
4598   auto *KernelEnvironmentInitializer = KernelEnvironmentGV->getInitializer();
4599   auto *NewInitializer = ConstantFoldInsertValueInstruction(
4600       KernelEnvironmentInitializer,
4601       ConstantInt::get(Int32, TeamsReductionDataSize), {0, 7});
4602   NewInitializer = ConstantFoldInsertValueInstruction(
4603       NewInitializer, ConstantInt::get(Int32, TeamsReductionBufferLength),
4604       {0, 8});
4605   KernelEnvironmentGV->setInitializer(NewInitializer);
4606 }
4607 
4608 static MDNode *getNVPTXMDNode(Function &Kernel, StringRef Name) {
4609   Module &M = *Kernel.getParent();
4610   NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
4611   for (auto *Op : MD->operands()) {
4612     if (Op->getNumOperands() != 3)
4613       continue;
4614     auto *KernelOp = dyn_cast<ConstantAsMetadata>(Op->getOperand(0));
4615     if (!KernelOp || KernelOp->getValue() != &Kernel)
4616       continue;
4617     auto *Prop = dyn_cast<MDString>(Op->getOperand(1));
4618     if (!Prop || Prop->getString() != Name)
4619       continue;
4620     return Op;
4621   }
4622   return nullptr;
4623 }
4624 
4625 static void updateNVPTXMetadata(Function &Kernel, StringRef Name, int32_t Value,
4626                                 bool Min) {
4627   // Update the "maxntidx" metadata for NVIDIA, or add it.
4628   MDNode *ExistingOp = getNVPTXMDNode(Kernel, Name);
4629   if (ExistingOp) {
4630     auto *OldVal = cast<ConstantAsMetadata>(ExistingOp->getOperand(2));
4631     int32_t OldLimit = cast<ConstantInt>(OldVal->getValue())->getZExtValue();
4632     ExistingOp->replaceOperandWith(
4633         2, ConstantAsMetadata::get(ConstantInt::get(
4634                OldVal->getValue()->getType(),
4635                Min ? std::min(OldLimit, Value) : std::max(OldLimit, Value))));
4636   } else {
4637     LLVMContext &Ctx = Kernel.getContext();
4638     Metadata *MDVals[] = {ConstantAsMetadata::get(&Kernel),
4639                           MDString::get(Ctx, Name),
4640                           ConstantAsMetadata::get(
4641                               ConstantInt::get(Type::getInt32Ty(Ctx), Value))};
4642     // Append metadata to nvvm.annotations
4643     Module &M = *Kernel.getParent();
4644     NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
4645     MD->addOperand(MDNode::get(Ctx, MDVals));
4646   }
4647 }
4648 
4649 std::pair<int32_t, int32_t>
4650 OpenMPIRBuilder::readThreadBoundsForKernel(const Triple &T, Function &Kernel) {
4651   int32_t ThreadLimit =
4652       Kernel.getFnAttributeAsParsedInteger("omp_target_thread_limit");
4653 
4654   if (T.isAMDGPU()) {
4655     const auto &Attr = Kernel.getFnAttribute("amdgpu-flat-work-group-size");
4656     if (!Attr.isValid() || !Attr.isStringAttribute())
4657       return {0, ThreadLimit};
4658     auto [LBStr, UBStr] = Attr.getValueAsString().split(',');
4659     int32_t LB, UB;
4660     if (!llvm::to_integer(UBStr, UB, 10))
4661       return {0, ThreadLimit};
4662     UB = ThreadLimit ? std::min(ThreadLimit, UB) : UB;
4663     if (!llvm::to_integer(LBStr, LB, 10))
4664       return {0, UB};
4665     return {LB, UB};
4666   }
4667 
4668   if (MDNode *ExistingOp = getNVPTXMDNode(Kernel, "maxntidx")) {
4669     auto *OldVal = cast<ConstantAsMetadata>(ExistingOp->getOperand(2));
4670     int32_t UB = cast<ConstantInt>(OldVal->getValue())->getZExtValue();
4671     return {0, ThreadLimit ? std::min(ThreadLimit, UB) : UB};
4672   }
4673   return {0, ThreadLimit};
4674 }
4675 
4676 void OpenMPIRBuilder::writeThreadBoundsForKernel(const Triple &T,
4677                                                  Function &Kernel, int32_t LB,
4678                                                  int32_t UB) {
4679   Kernel.addFnAttr("omp_target_thread_limit", std::to_string(UB));
4680 
4681   if (T.isAMDGPU()) {
4682     Kernel.addFnAttr("amdgpu-flat-work-group-size",
4683                      llvm::utostr(LB) + "," + llvm::utostr(UB));
4684     return;
4685   }
4686 
4687   updateNVPTXMetadata(Kernel, "maxntidx", UB, true);
4688 }
4689 
4690 std::pair<int32_t, int32_t>
4691 OpenMPIRBuilder::readTeamBoundsForKernel(const Triple &, Function &Kernel) {
4692   // TODO: Read from backend annotations if available.
4693   return {0, Kernel.getFnAttributeAsParsedInteger("omp_target_num_teams")};
4694 }
4695 
4696 void OpenMPIRBuilder::writeTeamsForKernel(const Triple &T, Function &Kernel,
4697                                           int32_t LB, int32_t UB) {
4698   if (T.isNVPTX()) {
4699     if (UB > 0)
4700       updateNVPTXMetadata(Kernel, "maxclusterrank", UB, true);
4701     updateNVPTXMetadata(Kernel, "minctasm", LB, false);
4702   }
4703   Kernel.addFnAttr("omp_target_num_teams", std::to_string(LB));
4704 }
4705 
4706 void OpenMPIRBuilder::setOutlinedTargetRegionFunctionAttributes(
4707     Function *OutlinedFn) {
4708   if (Config.isTargetDevice()) {
4709     OutlinedFn->setLinkage(GlobalValue::WeakODRLinkage);
4710     // TODO: Determine if DSO local can be set to true.
4711     OutlinedFn->setDSOLocal(false);
4712     OutlinedFn->setVisibility(GlobalValue::ProtectedVisibility);
4713     if (T.isAMDGCN())
4714       OutlinedFn->setCallingConv(CallingConv::AMDGPU_KERNEL);
4715   }
4716 }
4717 
4718 Constant *OpenMPIRBuilder::createOutlinedFunctionID(Function *OutlinedFn,
4719                                                     StringRef EntryFnIDName) {
4720   if (Config.isTargetDevice()) {
4721     assert(OutlinedFn && "The outlined function must exist if embedded");
4722     return OutlinedFn;
4723   }
4724 
4725   return new GlobalVariable(
4726       M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
4727       Constant::getNullValue(Builder.getInt8Ty()), EntryFnIDName);
4728 }
4729 
4730 Constant *OpenMPIRBuilder::createTargetRegionEntryAddr(Function *OutlinedFn,
4731                                                        StringRef EntryFnName) {
4732   if (OutlinedFn)
4733     return OutlinedFn;
4734 
4735   assert(!M.getGlobalVariable(EntryFnName, true) &&
4736          "Named kernel already exists?");
4737   return new GlobalVariable(
4738       M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::InternalLinkage,
4739       Constant::getNullValue(Builder.getInt8Ty()), EntryFnName);
4740 }
4741 
4742 void OpenMPIRBuilder::emitTargetRegionFunction(
4743     TargetRegionEntryInfo &EntryInfo,
4744     FunctionGenCallback &GenerateFunctionCallback, bool IsOffloadEntry,
4745     Function *&OutlinedFn, Constant *&OutlinedFnID) {
4746 
4747   SmallString<64> EntryFnName;
4748   OffloadInfoManager.getTargetRegionEntryFnName(EntryFnName, EntryInfo);
4749 
4750   OutlinedFn = Config.isTargetDevice() || !Config.openMPOffloadMandatory()
4751                    ? GenerateFunctionCallback(EntryFnName)
4752                    : nullptr;
4753 
4754   // If this target outline function is not an offload entry, we don't need to
4755   // register it. This may be in the case of a false if clause, or if there are
4756   // no OpenMP targets.
4757   if (!IsOffloadEntry)
4758     return;
4759 
4760   std::string EntryFnIDName =
4761       Config.isTargetDevice()
4762           ? std::string(EntryFnName)
4763           : createPlatformSpecificName({EntryFnName, "region_id"});
4764 
4765   OutlinedFnID = registerTargetRegionFunction(EntryInfo, OutlinedFn,
4766                                               EntryFnName, EntryFnIDName);
4767 }
4768 
4769 Constant *OpenMPIRBuilder::registerTargetRegionFunction(
4770     TargetRegionEntryInfo &EntryInfo, Function *OutlinedFn,
4771     StringRef EntryFnName, StringRef EntryFnIDName) {
4772   if (OutlinedFn)
4773     setOutlinedTargetRegionFunctionAttributes(OutlinedFn);
4774   auto OutlinedFnID = createOutlinedFunctionID(OutlinedFn, EntryFnIDName);
4775   auto EntryAddr = createTargetRegionEntryAddr(OutlinedFn, EntryFnName);
4776   OffloadInfoManager.registerTargetRegionEntryInfo(
4777       EntryInfo, EntryAddr, OutlinedFnID,
4778       OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion);
4779   return OutlinedFnID;
4780 }
4781 
4782 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
4783     const LocationDescription &Loc, InsertPointTy AllocaIP,
4784     InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
4785     TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
4786     omp::RuntimeFunction *MapperFunc,
4787     function_ref<InsertPointTy(InsertPointTy CodeGenIP, BodyGenTy BodyGenType)>
4788         BodyGenCB,
4789     function_ref<void(unsigned int, Value *)> DeviceAddrCB,
4790     function_ref<Value *(unsigned int)> CustomMapperCB, Value *SrcLocInfo) {
4791   if (!updateToLocation(Loc))
4792     return InsertPointTy();
4793 
4794   // Disable TargetData CodeGen on Device pass.
4795   if (Config.IsTargetDevice.value_or(false))
4796     return Builder.saveIP();
4797 
4798   Builder.restoreIP(CodeGenIP);
4799   bool IsStandAlone = !BodyGenCB;
4800   MapInfosTy *MapInfo;
4801   // Generate the code for the opening of the data environment. Capture all the
4802   // arguments of the runtime call by reference because they are used in the
4803   // closing of the region.
4804   auto BeginThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
4805     MapInfo = &GenMapInfoCB(Builder.saveIP());
4806     emitOffloadingArrays(AllocaIP, Builder.saveIP(), *MapInfo, Info,
4807                          /*IsNonContiguous=*/true, DeviceAddrCB,
4808                          CustomMapperCB);
4809 
4810     TargetDataRTArgs RTArgs;
4811     emitOffloadingArraysArgument(Builder, RTArgs, Info,
4812                                  !MapInfo->Names.empty());
4813 
4814     // Emit the number of elements in the offloading arrays.
4815     Value *PointerNum = Builder.getInt32(Info.NumberOfPtrs);
4816 
4817     // Source location for the ident struct
4818     if (!SrcLocInfo) {
4819       uint32_t SrcLocStrSize;
4820       Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4821       SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4822     }
4823 
4824     Value *OffloadingArgs[] = {SrcLocInfo,           DeviceID,
4825                                PointerNum,           RTArgs.BasePointersArray,
4826                                RTArgs.PointersArray, RTArgs.SizesArray,
4827                                RTArgs.MapTypesArray, RTArgs.MapNamesArray,
4828                                RTArgs.MappersArray};
4829 
4830     if (IsStandAlone) {
4831       assert(MapperFunc && "MapperFunc missing for standalone target data");
4832       Builder.CreateCall(getOrCreateRuntimeFunctionPtr(*MapperFunc),
4833                          OffloadingArgs);
4834     } else {
4835       Function *BeginMapperFunc = getOrCreateRuntimeFunctionPtr(
4836           omp::OMPRTL___tgt_target_data_begin_mapper);
4837 
4838       Builder.CreateCall(BeginMapperFunc, OffloadingArgs);
4839 
4840       for (auto DeviceMap : Info.DevicePtrInfoMap) {
4841         if (isa<AllocaInst>(DeviceMap.second.second)) {
4842           auto *LI =
4843               Builder.CreateLoad(Builder.getPtrTy(), DeviceMap.second.first);
4844           Builder.CreateStore(LI, DeviceMap.second.second);
4845         }
4846       }
4847 
4848       // If device pointer privatization is required, emit the body of the
4849       // region here. It will have to be duplicated: with and without
4850       // privatization.
4851       Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::Priv));
4852     }
4853   };
4854 
4855   // If we need device pointer privatization, we need to emit the body of the
4856   // region with no privatization in the 'else' branch of the conditional.
4857   // Otherwise, we don't have to do anything.
4858   auto BeginElseGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
4859     Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::DupNoPriv));
4860   };
4861 
4862   // Generate code for the closing of the data region.
4863   auto EndThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
4864     TargetDataRTArgs RTArgs;
4865     emitOffloadingArraysArgument(Builder, RTArgs, Info, !MapInfo->Names.empty(),
4866                                  /*ForEndCall=*/true);
4867 
4868     // Emit the number of elements in the offloading arrays.
4869     Value *PointerNum = Builder.getInt32(Info.NumberOfPtrs);
4870 
4871     // Source location for the ident struct
4872     if (!SrcLocInfo) {
4873       uint32_t SrcLocStrSize;
4874       Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4875       SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4876     }
4877 
4878     Value *OffloadingArgs[] = {SrcLocInfo,           DeviceID,
4879                                PointerNum,           RTArgs.BasePointersArray,
4880                                RTArgs.PointersArray, RTArgs.SizesArray,
4881                                RTArgs.MapTypesArray, RTArgs.MapNamesArray,
4882                                RTArgs.MappersArray};
4883     Function *EndMapperFunc =
4884         getOrCreateRuntimeFunctionPtr(omp::OMPRTL___tgt_target_data_end_mapper);
4885 
4886     Builder.CreateCall(EndMapperFunc, OffloadingArgs);
4887   };
4888 
4889   // We don't have to do anything to close the region if the if clause evaluates
4890   // to false.
4891   auto EndElseGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
4892 
4893   if (BodyGenCB) {
4894     if (IfCond) {
4895       emitIfClause(IfCond, BeginThenGen, BeginElseGen, AllocaIP);
4896     } else {
4897       BeginThenGen(AllocaIP, Builder.saveIP());
4898     }
4899 
4900     // If we don't require privatization of device pointers, we emit the body in
4901     // between the runtime calls. This avoids duplicating the body code.
4902     Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv));
4903 
4904     if (IfCond) {
4905       emitIfClause(IfCond, EndThenGen, EndElseGen, AllocaIP);
4906     } else {
4907       EndThenGen(AllocaIP, Builder.saveIP());
4908     }
4909   } else {
4910     if (IfCond) {
4911       emitIfClause(IfCond, BeginThenGen, EndElseGen, AllocaIP);
4912     } else {
4913       BeginThenGen(AllocaIP, Builder.saveIP());
4914     }
4915   }
4916 
4917   return Builder.saveIP();
4918 }
4919 
4920 FunctionCallee
4921 OpenMPIRBuilder::createForStaticInitFunction(unsigned IVSize, bool IVSigned,
4922                                              bool IsGPUDistribute) {
4923   assert((IVSize == 32 || IVSize == 64) &&
4924          "IV size is not compatible with the omp runtime");
4925   RuntimeFunction Name;
4926   if (IsGPUDistribute)
4927     Name = IVSize == 32
4928                ? (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_4
4929                            : omp::OMPRTL___kmpc_distribute_static_init_4u)
4930                : (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_8
4931                            : omp::OMPRTL___kmpc_distribute_static_init_8u);
4932   else
4933     Name = IVSize == 32 ? (IVSigned ? omp::OMPRTL___kmpc_for_static_init_4
4934                                     : omp::OMPRTL___kmpc_for_static_init_4u)
4935                         : (IVSigned ? omp::OMPRTL___kmpc_for_static_init_8
4936                                     : omp::OMPRTL___kmpc_for_static_init_8u);
4937 
4938   return getOrCreateRuntimeFunction(M, Name);
4939 }
4940 
4941 FunctionCallee OpenMPIRBuilder::createDispatchInitFunction(unsigned IVSize,
4942                                                            bool IVSigned) {
4943   assert((IVSize == 32 || IVSize == 64) &&
4944          "IV size is not compatible with the omp runtime");
4945   RuntimeFunction Name = IVSize == 32
4946                              ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_4
4947                                          : omp::OMPRTL___kmpc_dispatch_init_4u)
4948                              : (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_8
4949                                          : omp::OMPRTL___kmpc_dispatch_init_8u);
4950 
4951   return getOrCreateRuntimeFunction(M, Name);
4952 }
4953 
4954 FunctionCallee OpenMPIRBuilder::createDispatchNextFunction(unsigned IVSize,
4955                                                            bool IVSigned) {
4956   assert((IVSize == 32 || IVSize == 64) &&
4957          "IV size is not compatible with the omp runtime");
4958   RuntimeFunction Name = IVSize == 32
4959                              ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_4
4960                                          : omp::OMPRTL___kmpc_dispatch_next_4u)
4961                              : (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_8
4962                                          : omp::OMPRTL___kmpc_dispatch_next_8u);
4963 
4964   return getOrCreateRuntimeFunction(M, Name);
4965 }
4966 
4967 FunctionCallee OpenMPIRBuilder::createDispatchFiniFunction(unsigned IVSize,
4968                                                            bool IVSigned) {
4969   assert((IVSize == 32 || IVSize == 64) &&
4970          "IV size is not compatible with the omp runtime");
4971   RuntimeFunction Name = IVSize == 32
4972                              ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_4
4973                                          : omp::OMPRTL___kmpc_dispatch_fini_4u)
4974                              : (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_8
4975                                          : omp::OMPRTL___kmpc_dispatch_fini_8u);
4976 
4977   return getOrCreateRuntimeFunction(M, Name);
4978 }
4979 
4980 static void replaceConstatExprUsesInFuncWithInstr(ConstantExpr *ConstExpr,
4981                                                   Function *Func) {
4982   for (User *User : make_early_inc_range(ConstExpr->users()))
4983     if (auto *Instr = dyn_cast<Instruction>(User))
4984       if (Instr->getFunction() == Func)
4985         Instr->replaceUsesOfWith(ConstExpr, ConstExpr->getAsInstruction(Instr));
4986 }
4987 
4988 static void replaceConstantValueUsesInFuncWithInstr(llvm::Value *Input,
4989                                                     Function *Func) {
4990   for (User *User : make_early_inc_range(Input->users()))
4991     if (auto *Const = dyn_cast<Constant>(User))
4992       if (auto *ConstExpr = dyn_cast<ConstantExpr>(Const))
4993         replaceConstatExprUsesInFuncWithInstr(ConstExpr, Func);
4994 }
4995 
4996 static Function *createOutlinedFunction(
4997     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName,
4998     SmallVectorImpl<Value *> &Inputs,
4999     OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
5000     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
5001   SmallVector<Type *> ParameterTypes;
5002   if (OMPBuilder.Config.isTargetDevice()) {
5003     // Add the "implicit" runtime argument we use to provide launch specific
5004     // information for target devices.
5005     auto *Int8PtrTy = PointerType::getUnqual(Builder.getContext());
5006     ParameterTypes.push_back(Int8PtrTy);
5007 
5008     // All parameters to target devices are passed as pointers
5009     // or i64. This assumes 64-bit address spaces/pointers.
5010     for (auto &Arg : Inputs)
5011       ParameterTypes.push_back(Arg->getType()->isPointerTy()
5012                                    ? Arg->getType()
5013                                    : Type::getInt64Ty(Builder.getContext()));
5014   } else {
5015     for (auto &Arg : Inputs)
5016       ParameterTypes.push_back(Arg->getType());
5017   }
5018 
5019   auto FuncType = FunctionType::get(Builder.getVoidTy(), ParameterTypes,
5020                                     /*isVarArg*/ false);
5021   auto Func = Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName,
5022                                Builder.GetInsertBlock()->getModule());
5023 
5024   // Save insert point.
5025   auto OldInsertPoint = Builder.saveIP();
5026 
5027   // Generate the region into the function.
5028   BasicBlock *EntryBB = BasicBlock::Create(Builder.getContext(), "entry", Func);
5029   Builder.SetInsertPoint(EntryBB);
5030 
5031   // Insert target init call in the device compilation pass.
5032   if (OMPBuilder.Config.isTargetDevice())
5033     Builder.restoreIP(OMPBuilder.createTargetInit(Builder, /*IsSPMD*/ false));
5034 
5035   BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
5036 
5037   // Insert target deinit call in the device compilation pass.
5038   Builder.restoreIP(CBFunc(Builder.saveIP(), Builder.saveIP()));
5039   if (OMPBuilder.Config.isTargetDevice())
5040     OMPBuilder.createTargetDeinit(Builder);
5041 
5042   // Insert return instruction.
5043   Builder.CreateRetVoid();
5044 
5045   // New Alloca IP at entry point of created device function.
5046   Builder.SetInsertPoint(EntryBB->getFirstNonPHI());
5047   auto AllocaIP = Builder.saveIP();
5048 
5049   Builder.SetInsertPoint(UserCodeEntryBB->getFirstNonPHIOrDbg());
5050 
5051   // Skip the artificial dyn_ptr on the device.
5052   const auto &ArgRange =
5053       OMPBuilder.Config.isTargetDevice()
5054           ? make_range(Func->arg_begin() + 1, Func->arg_end())
5055           : Func->args();
5056 
5057   // Rewrite uses of input valus to parameters.
5058   for (auto InArg : zip(Inputs, ArgRange)) {
5059     Value *Input = std::get<0>(InArg);
5060     Argument &Arg = std::get<1>(InArg);
5061     Value *InputCopy = nullptr;
5062 
5063     Builder.restoreIP(
5064         ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP()));
5065 
5066     // Things like GEP's can come in the form of Constants. Constants and
5067     // ConstantExpr's do not have access to the knowledge of what they're
5068     // contained in, so we must dig a little to find an instruction so we can
5069     // tell if they're used inside of the function we're outlining. We also
5070     // replace the original constant expression with a new instruction
5071     // equivalent; an instruction as it allows easy modification in the
5072     // following loop, as we can now know the constant (instruction) is owned by
5073     // our target function and replaceUsesOfWith can now be invoked on it
5074     // (cannot do this with constants it seems). A brand new one also allows us
5075     // to be cautious as it is perhaps possible the old expression was used
5076     // inside of the function but exists and is used externally (unlikely by the
5077     // nature of a Constant, but still).
5078     replaceConstantValueUsesInFuncWithInstr(Input, Func);
5079 
5080     // Collect all the instructions
5081     for (User *User : make_early_inc_range(Input->users()))
5082       if (auto *Instr = dyn_cast<Instruction>(User))
5083         if (Instr->getFunction() == Func)
5084           Instr->replaceUsesOfWith(Input, InputCopy);
5085   }
5086 
5087   // Restore insert point.
5088   Builder.restoreIP(OldInsertPoint);
5089 
5090   return Func;
5091 }
5092 
5093 static void emitTargetOutlinedFunction(
5094     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
5095     TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn,
5096     Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs,
5097     OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
5098     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
5099 
5100   OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
5101       [&OMPBuilder, &Builder, &Inputs, &CBFunc,
5102        &ArgAccessorFuncCB](StringRef EntryFnName) {
5103         return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs,
5104                                       CBFunc, ArgAccessorFuncCB);
5105       };
5106 
5107   OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction, true,
5108                                       OutlinedFn, OutlinedFnID);
5109 }
5110 
5111 static void emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
5112                            OpenMPIRBuilder::InsertPointTy AllocaIP,
5113                            Function *OutlinedFn, Constant *OutlinedFnID,
5114                            int32_t NumTeams, int32_t NumThreads,
5115                            SmallVectorImpl<Value *> &Args,
5116                            OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB) {
5117 
5118   OpenMPIRBuilder::TargetDataInfo Info(
5119       /*RequiresDevicePointerInfo=*/false,
5120       /*SeparateBeginEndCalls=*/true);
5121 
5122   OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
5123   OMPBuilder.emitOffloadingArrays(AllocaIP, Builder.saveIP(), MapInfo, Info,
5124                                   /*IsNonContiguous=*/true);
5125 
5126   OpenMPIRBuilder::TargetDataRTArgs RTArgs;
5127   OMPBuilder.emitOffloadingArraysArgument(Builder, RTArgs, Info,
5128                                           !MapInfo.Names.empty());
5129 
5130   //  emitKernelLaunch
5131   auto &&EmitTargetCallFallbackCB =
5132       [&](OpenMPIRBuilder::InsertPointTy IP) -> OpenMPIRBuilder::InsertPointTy {
5133     Builder.restoreIP(IP);
5134     Builder.CreateCall(OutlinedFn, Args);
5135     return Builder.saveIP();
5136   };
5137 
5138   unsigned NumTargetItems = MapInfo.BasePointers.size();
5139   // TODO: Use correct device ID
5140   Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
5141   Value *NumTeamsVal = Builder.getInt32(NumTeams);
5142   Value *NumThreadsVal = Builder.getInt32(NumThreads);
5143   uint32_t SrcLocStrSize;
5144   Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
5145   Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
5146                                              llvm::omp::IdentFlag(0), 0);
5147   // TODO: Use correct NumIterations
5148   Value *NumIterations = Builder.getInt64(0);
5149   // TODO: Use correct DynCGGroupMem
5150   Value *DynCGGroupMem = Builder.getInt32(0);
5151 
5152   bool HasNoWait = false;
5153 
5154   OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, NumIterations,
5155                                           NumTeamsVal, NumThreadsVal,
5156                                           DynCGGroupMem, HasNoWait);
5157 
5158   Builder.restoreIP(OMPBuilder.emitKernelLaunch(
5159       Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs,
5160       DeviceID, RTLoc, AllocaIP));
5161 }
5162 
5163 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
5164     const LocationDescription &Loc, InsertPointTy AllocaIP,
5165     InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, int32_t NumTeams,
5166     int32_t NumThreads, SmallVectorImpl<Value *> &Args,
5167     GenMapInfoCallbackTy GenMapInfoCB,
5168     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
5169     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB) {
5170   if (!updateToLocation(Loc))
5171     return InsertPointTy();
5172 
5173   Builder.restoreIP(CodeGenIP);
5174 
5175   Function *OutlinedFn;
5176   Constant *OutlinedFnID;
5177   emitTargetOutlinedFunction(*this, Builder, EntryInfo, OutlinedFn,
5178                              OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB);
5179   if (!Config.isTargetDevice())
5180     emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
5181                    NumThreads, Args, GenMapInfoCB);
5182 
5183   return Builder.saveIP();
5184 }
5185 
5186 std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
5187                                                    StringRef FirstSeparator,
5188                                                    StringRef Separator) {
5189   SmallString<128> Buffer;
5190   llvm::raw_svector_ostream OS(Buffer);
5191   StringRef Sep = FirstSeparator;
5192   for (StringRef Part : Parts) {
5193     OS << Sep << Part;
5194     Sep = Separator;
5195   }
5196   return OS.str().str();
5197 }
5198 
5199 std::string
5200 OpenMPIRBuilder::createPlatformSpecificName(ArrayRef<StringRef> Parts) const {
5201   return OpenMPIRBuilder::getNameWithSeparators(Parts, Config.firstSeparator(),
5202                                                 Config.separator());
5203 }
5204 
5205 GlobalVariable *
5206 OpenMPIRBuilder::getOrCreateInternalVariable(Type *Ty, const StringRef &Name,
5207                                              unsigned AddressSpace) {
5208   auto &Elem = *InternalVars.try_emplace(Name, nullptr).first;
5209   if (Elem.second) {
5210     assert(Elem.second->getValueType() == Ty &&
5211            "OMP internal variable has different type than requested");
5212   } else {
5213     // TODO: investigate the appropriate linkage type used for the global
5214     // variable for possibly changing that to internal or private, or maybe
5215     // create different versions of the function for different OMP internal
5216     // variables.
5217     auto Linkage = this->M.getTargetTriple().rfind("wasm32") == 0
5218                        ? GlobalValue::ExternalLinkage
5219                        : GlobalValue::CommonLinkage;
5220     auto *GV = new GlobalVariable(M, Ty, /*IsConstant=*/false, Linkage,
5221                                   Constant::getNullValue(Ty), Elem.first(),
5222                                   /*InsertBefore=*/nullptr,
5223                                   GlobalValue::NotThreadLocal, AddressSpace);
5224     const DataLayout &DL = M.getDataLayout();
5225     const llvm::Align TypeAlign = DL.getABITypeAlign(Ty);
5226     const llvm::Align PtrAlign = DL.getPointerABIAlignment(AddressSpace);
5227     GV->setAlignment(std::max(TypeAlign, PtrAlign));
5228     Elem.second = GV;
5229   }
5230 
5231   return Elem.second;
5232 }
5233 
5234 Value *OpenMPIRBuilder::getOMPCriticalRegionLock(StringRef CriticalName) {
5235   std::string Prefix = Twine("gomp_critical_user_", CriticalName).str();
5236   std::string Name = getNameWithSeparators({Prefix, "var"}, ".", ".");
5237   return getOrCreateInternalVariable(KmpCriticalNameTy, Name);
5238 }
5239 
5240 Value *OpenMPIRBuilder::getSizeInBytes(Value *BasePtr) {
5241   LLVMContext &Ctx = Builder.getContext();
5242   Value *Null =
5243       Constant::getNullValue(PointerType::getUnqual(BasePtr->getContext()));
5244   Value *SizeGep =
5245       Builder.CreateGEP(BasePtr->getType(), Null, Builder.getInt32(1));
5246   Value *SizePtrToInt = Builder.CreatePtrToInt(SizeGep, Type::getInt64Ty(Ctx));
5247   return SizePtrToInt;
5248 }
5249 
5250 GlobalVariable *
5251 OpenMPIRBuilder::createOffloadMaptypes(SmallVectorImpl<uint64_t> &Mappings,
5252                                        std::string VarName) {
5253   llvm::Constant *MaptypesArrayInit =
5254       llvm::ConstantDataArray::get(M.getContext(), Mappings);
5255   auto *MaptypesArrayGlobal = new llvm::GlobalVariable(
5256       M, MaptypesArrayInit->getType(),
5257       /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MaptypesArrayInit,
5258       VarName);
5259   MaptypesArrayGlobal->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
5260   return MaptypesArrayGlobal;
5261 }
5262 
5263 void OpenMPIRBuilder::createMapperAllocas(const LocationDescription &Loc,
5264                                           InsertPointTy AllocaIP,
5265                                           unsigned NumOperands,
5266                                           struct MapperAllocas &MapperAllocas) {
5267   if (!updateToLocation(Loc))
5268     return;
5269 
5270   auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
5271   auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
5272   Builder.restoreIP(AllocaIP);
5273   AllocaInst *ArgsBase = Builder.CreateAlloca(
5274       ArrI8PtrTy, /* ArraySize = */ nullptr, ".offload_baseptrs");
5275   AllocaInst *Args = Builder.CreateAlloca(ArrI8PtrTy, /* ArraySize = */ nullptr,
5276                                           ".offload_ptrs");
5277   AllocaInst *ArgSizes = Builder.CreateAlloca(
5278       ArrI64Ty, /* ArraySize = */ nullptr, ".offload_sizes");
5279   Builder.restoreIP(Loc.IP);
5280   MapperAllocas.ArgsBase = ArgsBase;
5281   MapperAllocas.Args = Args;
5282   MapperAllocas.ArgSizes = ArgSizes;
5283 }
5284 
5285 void OpenMPIRBuilder::emitMapperCall(const LocationDescription &Loc,
5286                                      Function *MapperFunc, Value *SrcLocInfo,
5287                                      Value *MaptypesArg, Value *MapnamesArg,
5288                                      struct MapperAllocas &MapperAllocas,
5289                                      int64_t DeviceID, unsigned NumOperands) {
5290   if (!updateToLocation(Loc))
5291     return;
5292 
5293   auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
5294   auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
5295   Value *ArgsBaseGEP =
5296       Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.ArgsBase,
5297                                 {Builder.getInt32(0), Builder.getInt32(0)});
5298   Value *ArgsGEP =
5299       Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.Args,
5300                                 {Builder.getInt32(0), Builder.getInt32(0)});
5301   Value *ArgSizesGEP =
5302       Builder.CreateInBoundsGEP(ArrI64Ty, MapperAllocas.ArgSizes,
5303                                 {Builder.getInt32(0), Builder.getInt32(0)});
5304   Value *NullPtr =
5305       Constant::getNullValue(PointerType::getUnqual(Int8Ptr->getContext()));
5306   Builder.CreateCall(MapperFunc,
5307                      {SrcLocInfo, Builder.getInt64(DeviceID),
5308                       Builder.getInt32(NumOperands), ArgsBaseGEP, ArgsGEP,
5309                       ArgSizesGEP, MaptypesArg, MapnamesArg, NullPtr});
5310 }
5311 
5312 void OpenMPIRBuilder::emitOffloadingArraysArgument(IRBuilderBase &Builder,
5313                                                    TargetDataRTArgs &RTArgs,
5314                                                    TargetDataInfo &Info,
5315                                                    bool EmitDebug,
5316                                                    bool ForEndCall) {
5317   assert((!ForEndCall || Info.separateBeginEndCalls()) &&
5318          "expected region end call to runtime only when end call is separate");
5319   auto UnqualPtrTy = PointerType::getUnqual(M.getContext());
5320   auto VoidPtrTy = UnqualPtrTy;
5321   auto VoidPtrPtrTy = UnqualPtrTy;
5322   auto Int64Ty = Type::getInt64Ty(M.getContext());
5323   auto Int64PtrTy = UnqualPtrTy;
5324 
5325   if (!Info.NumberOfPtrs) {
5326     RTArgs.BasePointersArray = ConstantPointerNull::get(VoidPtrPtrTy);
5327     RTArgs.PointersArray = ConstantPointerNull::get(VoidPtrPtrTy);
5328     RTArgs.SizesArray = ConstantPointerNull::get(Int64PtrTy);
5329     RTArgs.MapTypesArray = ConstantPointerNull::get(Int64PtrTy);
5330     RTArgs.MapNamesArray = ConstantPointerNull::get(VoidPtrPtrTy);
5331     RTArgs.MappersArray = ConstantPointerNull::get(VoidPtrPtrTy);
5332     return;
5333   }
5334 
5335   RTArgs.BasePointersArray = Builder.CreateConstInBoundsGEP2_32(
5336       ArrayType::get(VoidPtrTy, Info.NumberOfPtrs),
5337       Info.RTArgs.BasePointersArray,
5338       /*Idx0=*/0, /*Idx1=*/0);
5339   RTArgs.PointersArray = Builder.CreateConstInBoundsGEP2_32(
5340       ArrayType::get(VoidPtrTy, Info.NumberOfPtrs), Info.RTArgs.PointersArray,
5341       /*Idx0=*/0,
5342       /*Idx1=*/0);
5343   RTArgs.SizesArray = Builder.CreateConstInBoundsGEP2_32(
5344       ArrayType::get(Int64Ty, Info.NumberOfPtrs), Info.RTArgs.SizesArray,
5345       /*Idx0=*/0, /*Idx1=*/0);
5346   RTArgs.MapTypesArray = Builder.CreateConstInBoundsGEP2_32(
5347       ArrayType::get(Int64Ty, Info.NumberOfPtrs),
5348       ForEndCall && Info.RTArgs.MapTypesArrayEnd ? Info.RTArgs.MapTypesArrayEnd
5349                                                  : Info.RTArgs.MapTypesArray,
5350       /*Idx0=*/0,
5351       /*Idx1=*/0);
5352 
5353   // Only emit the mapper information arrays if debug information is
5354   // requested.
5355   if (!EmitDebug)
5356     RTArgs.MapNamesArray = ConstantPointerNull::get(VoidPtrPtrTy);
5357   else
5358     RTArgs.MapNamesArray = Builder.CreateConstInBoundsGEP2_32(
5359         ArrayType::get(VoidPtrTy, Info.NumberOfPtrs), Info.RTArgs.MapNamesArray,
5360         /*Idx0=*/0,
5361         /*Idx1=*/0);
5362   // If there is no user-defined mapper, set the mapper array to nullptr to
5363   // avoid an unnecessary data privatization
5364   if (!Info.HasMapper)
5365     RTArgs.MappersArray = ConstantPointerNull::get(VoidPtrPtrTy);
5366   else
5367     RTArgs.MappersArray =
5368         Builder.CreatePointerCast(Info.RTArgs.MappersArray, VoidPtrPtrTy);
5369 }
5370 
5371 void OpenMPIRBuilder::emitNonContiguousDescriptor(InsertPointTy AllocaIP,
5372                                                   InsertPointTy CodeGenIP,
5373                                                   MapInfosTy &CombinedInfo,
5374                                                   TargetDataInfo &Info) {
5375   MapInfosTy::StructNonContiguousInfo &NonContigInfo =
5376       CombinedInfo.NonContigInfo;
5377 
5378   // Build an array of struct descriptor_dim and then assign it to
5379   // offload_args.
5380   //
5381   // struct descriptor_dim {
5382   //  uint64_t offset;
5383   //  uint64_t count;
5384   //  uint64_t stride
5385   // };
5386   Type *Int64Ty = Builder.getInt64Ty();
5387   StructType *DimTy = StructType::create(
5388       M.getContext(), ArrayRef<Type *>({Int64Ty, Int64Ty, Int64Ty}),
5389       "struct.descriptor_dim");
5390 
5391   enum { OffsetFD = 0, CountFD, StrideFD };
5392   // We need two index variable here since the size of "Dims" is the same as
5393   // the size of Components, however, the size of offset, count, and stride is
5394   // equal to the size of base declaration that is non-contiguous.
5395   for (unsigned I = 0, L = 0, E = NonContigInfo.Dims.size(); I < E; ++I) {
5396     // Skip emitting ir if dimension size is 1 since it cannot be
5397     // non-contiguous.
5398     if (NonContigInfo.Dims[I] == 1)
5399       continue;
5400     Builder.restoreIP(AllocaIP);
5401     ArrayType *ArrayTy = ArrayType::get(DimTy, NonContigInfo.Dims[I]);
5402     AllocaInst *DimsAddr =
5403         Builder.CreateAlloca(ArrayTy, /* ArraySize = */ nullptr, "dims");
5404     Builder.restoreIP(CodeGenIP);
5405     for (unsigned II = 0, EE = NonContigInfo.Dims[I]; II < EE; ++II) {
5406       unsigned RevIdx = EE - II - 1;
5407       Value *DimsLVal = Builder.CreateInBoundsGEP(
5408           DimsAddr->getAllocatedType(), DimsAddr,
5409           {Builder.getInt64(0), Builder.getInt64(II)});
5410       // Offset
5411       Value *OffsetLVal = Builder.CreateStructGEP(DimTy, DimsLVal, OffsetFD);
5412       Builder.CreateAlignedStore(
5413           NonContigInfo.Offsets[L][RevIdx], OffsetLVal,
5414           M.getDataLayout().getPrefTypeAlign(OffsetLVal->getType()));
5415       // Count
5416       Value *CountLVal = Builder.CreateStructGEP(DimTy, DimsLVal, CountFD);
5417       Builder.CreateAlignedStore(
5418           NonContigInfo.Counts[L][RevIdx], CountLVal,
5419           M.getDataLayout().getPrefTypeAlign(CountLVal->getType()));
5420       // Stride
5421       Value *StrideLVal = Builder.CreateStructGEP(DimTy, DimsLVal, StrideFD);
5422       Builder.CreateAlignedStore(
5423           NonContigInfo.Strides[L][RevIdx], StrideLVal,
5424           M.getDataLayout().getPrefTypeAlign(CountLVal->getType()));
5425     }
5426     // args[I] = &dims
5427     Builder.restoreIP(CodeGenIP);
5428     Value *DAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
5429         DimsAddr, Builder.getPtrTy());
5430     Value *P = Builder.CreateConstInBoundsGEP2_32(
5431         ArrayType::get(Builder.getPtrTy(), Info.NumberOfPtrs),
5432         Info.RTArgs.PointersArray, 0, I);
5433     Builder.CreateAlignedStore(
5434         DAddr, P, M.getDataLayout().getPrefTypeAlign(Builder.getPtrTy()));
5435     ++L;
5436   }
5437 }
5438 
5439 void OpenMPIRBuilder::emitOffloadingArrays(
5440     InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
5441     TargetDataInfo &Info, bool IsNonContiguous,
5442     function_ref<void(unsigned int, Value *)> DeviceAddrCB,
5443     function_ref<Value *(unsigned int)> CustomMapperCB) {
5444 
5445   // Reset the array information.
5446   Info.clearArrayInfo();
5447   Info.NumberOfPtrs = CombinedInfo.BasePointers.size();
5448 
5449   if (Info.NumberOfPtrs == 0)
5450     return;
5451 
5452   Builder.restoreIP(AllocaIP);
5453   // Detect if we have any capture size requiring runtime evaluation of the
5454   // size so that a constant array could be eventually used.
5455   ArrayType *PointerArrayType =
5456       ArrayType::get(Builder.getPtrTy(), Info.NumberOfPtrs);
5457 
5458   Info.RTArgs.BasePointersArray = Builder.CreateAlloca(
5459       PointerArrayType, /* ArraySize = */ nullptr, ".offload_baseptrs");
5460 
5461   Info.RTArgs.PointersArray = Builder.CreateAlloca(
5462       PointerArrayType, /* ArraySize = */ nullptr, ".offload_ptrs");
5463   AllocaInst *MappersArray = Builder.CreateAlloca(
5464       PointerArrayType, /* ArraySize = */ nullptr, ".offload_mappers");
5465   Info.RTArgs.MappersArray = MappersArray;
5466 
5467   // If we don't have any VLA types or other types that require runtime
5468   // evaluation, we can use a constant array for the map sizes, otherwise we
5469   // need to fill up the arrays as we do for the pointers.
5470   Type *Int64Ty = Builder.getInt64Ty();
5471   SmallVector<Constant *> ConstSizes(CombinedInfo.Sizes.size(),
5472                                      ConstantInt::get(Int64Ty, 0));
5473   SmallBitVector RuntimeSizes(CombinedInfo.Sizes.size());
5474   for (unsigned I = 0, E = CombinedInfo.Sizes.size(); I < E; ++I) {
5475     if (auto *CI = dyn_cast<Constant>(CombinedInfo.Sizes[I])) {
5476       if (!isa<ConstantExpr>(CI) && !isa<GlobalValue>(CI)) {
5477         if (IsNonContiguous &&
5478             static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
5479                 CombinedInfo.Types[I] &
5480                 OpenMPOffloadMappingFlags::OMP_MAP_NON_CONTIG))
5481           ConstSizes[I] =
5482               ConstantInt::get(Int64Ty, CombinedInfo.NonContigInfo.Dims[I]);
5483         else
5484           ConstSizes[I] = CI;
5485         continue;
5486       }
5487     }
5488     RuntimeSizes.set(I);
5489   }
5490 
5491   if (RuntimeSizes.all()) {
5492     ArrayType *SizeArrayType = ArrayType::get(Int64Ty, Info.NumberOfPtrs);
5493     Info.RTArgs.SizesArray = Builder.CreateAlloca(
5494         SizeArrayType, /* ArraySize = */ nullptr, ".offload_sizes");
5495     Builder.restoreIP(CodeGenIP);
5496   } else {
5497     auto *SizesArrayInit = ConstantArray::get(
5498         ArrayType::get(Int64Ty, ConstSizes.size()), ConstSizes);
5499     std::string Name = createPlatformSpecificName({"offload_sizes"});
5500     auto *SizesArrayGbl =
5501         new GlobalVariable(M, SizesArrayInit->getType(), /*isConstant=*/true,
5502                            GlobalValue::PrivateLinkage, SizesArrayInit, Name);
5503     SizesArrayGbl->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
5504 
5505     if (!RuntimeSizes.any()) {
5506       Info.RTArgs.SizesArray = SizesArrayGbl;
5507     } else {
5508       unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(0);
5509       Align OffloadSizeAlign = M.getDataLayout().getABIIntegerTypeAlignment(64);
5510       ArrayType *SizeArrayType = ArrayType::get(Int64Ty, Info.NumberOfPtrs);
5511       AllocaInst *Buffer = Builder.CreateAlloca(
5512           SizeArrayType, /* ArraySize = */ nullptr, ".offload_sizes");
5513       Buffer->setAlignment(OffloadSizeAlign);
5514       Builder.restoreIP(CodeGenIP);
5515       Builder.CreateMemCpy(
5516           Buffer, M.getDataLayout().getPrefTypeAlign(Buffer->getType()),
5517           SizesArrayGbl, OffloadSizeAlign,
5518           Builder.getIntN(
5519               IndexSize,
5520               Buffer->getAllocationSize(M.getDataLayout())->getFixedValue()));
5521 
5522       Info.RTArgs.SizesArray = Buffer;
5523     }
5524     Builder.restoreIP(CodeGenIP);
5525   }
5526 
5527   // The map types are always constant so we don't need to generate code to
5528   // fill arrays. Instead, we create an array constant.
5529   SmallVector<uint64_t, 4> Mapping;
5530   for (auto mapFlag : CombinedInfo.Types)
5531     Mapping.push_back(
5532         static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
5533             mapFlag));
5534   std::string MaptypesName = createPlatformSpecificName({"offload_maptypes"});
5535   auto *MapTypesArrayGbl = createOffloadMaptypes(Mapping, MaptypesName);
5536   Info.RTArgs.MapTypesArray = MapTypesArrayGbl;
5537 
5538   // The information types are only built if provided.
5539   if (!CombinedInfo.Names.empty()) {
5540     std::string MapnamesName = createPlatformSpecificName({"offload_mapnames"});
5541     auto *MapNamesArrayGbl =
5542         createOffloadMapnames(CombinedInfo.Names, MapnamesName);
5543     Info.RTArgs.MapNamesArray = MapNamesArrayGbl;
5544   } else {
5545     Info.RTArgs.MapNamesArray =
5546         Constant::getNullValue(PointerType::getUnqual(Builder.getContext()));
5547   }
5548 
5549   // If there's a present map type modifier, it must not be applied to the end
5550   // of a region, so generate a separate map type array in that case.
5551   if (Info.separateBeginEndCalls()) {
5552     bool EndMapTypesDiffer = false;
5553     for (uint64_t &Type : Mapping) {
5554       if (Type & static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
5555                      OpenMPOffloadMappingFlags::OMP_MAP_PRESENT)) {
5556         Type &= ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
5557             OpenMPOffloadMappingFlags::OMP_MAP_PRESENT);
5558         EndMapTypesDiffer = true;
5559       }
5560     }
5561     if (EndMapTypesDiffer) {
5562       MapTypesArrayGbl = createOffloadMaptypes(Mapping, MaptypesName);
5563       Info.RTArgs.MapTypesArrayEnd = MapTypesArrayGbl;
5564     }
5565   }
5566 
5567   PointerType *PtrTy = Builder.getPtrTy();
5568   for (unsigned I = 0; I < Info.NumberOfPtrs; ++I) {
5569     Value *BPVal = CombinedInfo.BasePointers[I];
5570     Value *BP = Builder.CreateConstInBoundsGEP2_32(
5571         ArrayType::get(PtrTy, Info.NumberOfPtrs), Info.RTArgs.BasePointersArray,
5572         0, I);
5573     Builder.CreateAlignedStore(BPVal, BP,
5574                                M.getDataLayout().getPrefTypeAlign(PtrTy));
5575 
5576     if (Info.requiresDevicePointerInfo()) {
5577       if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Pointer) {
5578         CodeGenIP = Builder.saveIP();
5579         Builder.restoreIP(AllocaIP);
5580         Info.DevicePtrInfoMap[BPVal] = {BP, Builder.CreateAlloca(PtrTy)};
5581         Builder.restoreIP(CodeGenIP);
5582         if (DeviceAddrCB)
5583           DeviceAddrCB(I, Info.DevicePtrInfoMap[BPVal].second);
5584       } else if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Address) {
5585         Info.DevicePtrInfoMap[BPVal] = {BP, BP};
5586         if (DeviceAddrCB)
5587           DeviceAddrCB(I, BP);
5588       }
5589     }
5590 
5591     Value *PVal = CombinedInfo.Pointers[I];
5592     Value *P = Builder.CreateConstInBoundsGEP2_32(
5593         ArrayType::get(PtrTy, Info.NumberOfPtrs), Info.RTArgs.PointersArray, 0,
5594         I);
5595     // TODO: Check alignment correct.
5596     Builder.CreateAlignedStore(PVal, P,
5597                                M.getDataLayout().getPrefTypeAlign(PtrTy));
5598 
5599     if (RuntimeSizes.test(I)) {
5600       Value *S = Builder.CreateConstInBoundsGEP2_32(
5601           ArrayType::get(Int64Ty, Info.NumberOfPtrs), Info.RTArgs.SizesArray,
5602           /*Idx0=*/0,
5603           /*Idx1=*/I);
5604       Builder.CreateAlignedStore(Builder.CreateIntCast(CombinedInfo.Sizes[I],
5605                                                        Int64Ty,
5606                                                        /*isSigned=*/true),
5607                                  S, M.getDataLayout().getPrefTypeAlign(PtrTy));
5608     }
5609     // Fill up the mapper array.
5610     unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(0);
5611     Value *MFunc = ConstantPointerNull::get(PtrTy);
5612     if (CustomMapperCB)
5613       if (Value *CustomMFunc = CustomMapperCB(I))
5614         MFunc = Builder.CreatePointerCast(CustomMFunc, PtrTy);
5615     Value *MAddr = Builder.CreateInBoundsGEP(
5616         MappersArray->getAllocatedType(), MappersArray,
5617         {Builder.getIntN(IndexSize, 0), Builder.getIntN(IndexSize, I)});
5618     Builder.CreateAlignedStore(
5619         MFunc, MAddr, M.getDataLayout().getPrefTypeAlign(MAddr->getType()));
5620   }
5621 
5622   if (!IsNonContiguous || CombinedInfo.NonContigInfo.Offsets.empty() ||
5623       Info.NumberOfPtrs == 0)
5624     return;
5625   emitNonContiguousDescriptor(AllocaIP, CodeGenIP, CombinedInfo, Info);
5626 }
5627 
5628 void OpenMPIRBuilder::emitBranch(BasicBlock *Target) {
5629   BasicBlock *CurBB = Builder.GetInsertBlock();
5630 
5631   if (!CurBB || CurBB->getTerminator()) {
5632     // If there is no insert point or the previous block is already
5633     // terminated, don't touch it.
5634   } else {
5635     // Otherwise, create a fall-through branch.
5636     Builder.CreateBr(Target);
5637   }
5638 
5639   Builder.ClearInsertionPoint();
5640 }
5641 
5642 void OpenMPIRBuilder::emitBlock(BasicBlock *BB, Function *CurFn,
5643                                 bool IsFinished) {
5644   BasicBlock *CurBB = Builder.GetInsertBlock();
5645 
5646   // Fall out of the current block (if necessary).
5647   emitBranch(BB);
5648 
5649   if (IsFinished && BB->use_empty()) {
5650     BB->eraseFromParent();
5651     return;
5652   }
5653 
5654   // Place the block after the current block, if possible, or else at
5655   // the end of the function.
5656   if (CurBB && CurBB->getParent())
5657     CurFn->insert(std::next(CurBB->getIterator()), BB);
5658   else
5659     CurFn->insert(CurFn->end(), BB);
5660   Builder.SetInsertPoint(BB);
5661 }
5662 
5663 void OpenMPIRBuilder::emitIfClause(Value *Cond, BodyGenCallbackTy ThenGen,
5664                                    BodyGenCallbackTy ElseGen,
5665                                    InsertPointTy AllocaIP) {
5666   // If the condition constant folds and can be elided, try to avoid emitting
5667   // the condition and the dead arm of the if/else.
5668   if (auto *CI = dyn_cast<ConstantInt>(Cond)) {
5669     auto CondConstant = CI->getSExtValue();
5670     if (CondConstant)
5671       ThenGen(AllocaIP, Builder.saveIP());
5672     else
5673       ElseGen(AllocaIP, Builder.saveIP());
5674     return;
5675   }
5676 
5677   Function *CurFn = Builder.GetInsertBlock()->getParent();
5678 
5679   // Otherwise, the condition did not fold, or we couldn't elide it.  Just
5680   // emit the conditional branch.
5681   BasicBlock *ThenBlock = BasicBlock::Create(M.getContext(), "omp_if.then");
5682   BasicBlock *ElseBlock = BasicBlock::Create(M.getContext(), "omp_if.else");
5683   BasicBlock *ContBlock = BasicBlock::Create(M.getContext(), "omp_if.end");
5684   Builder.CreateCondBr(Cond, ThenBlock, ElseBlock);
5685   // Emit the 'then' code.
5686   emitBlock(ThenBlock, CurFn);
5687   ThenGen(AllocaIP, Builder.saveIP());
5688   emitBranch(ContBlock);
5689   // Emit the 'else' code if present.
5690   // There is no need to emit line number for unconditional branch.
5691   emitBlock(ElseBlock, CurFn);
5692   ElseGen(AllocaIP, Builder.saveIP());
5693   // There is no need to emit line number for unconditional branch.
5694   emitBranch(ContBlock);
5695   // Emit the continuation block for code after the if.
5696   emitBlock(ContBlock, CurFn, /*IsFinished=*/true);
5697 }
5698 
5699 bool OpenMPIRBuilder::checkAndEmitFlushAfterAtomic(
5700     const LocationDescription &Loc, llvm::AtomicOrdering AO, AtomicKind AK) {
5701   assert(!(AO == AtomicOrdering::NotAtomic ||
5702            AO == llvm::AtomicOrdering::Unordered) &&
5703          "Unexpected Atomic Ordering.");
5704 
5705   bool Flush = false;
5706   llvm::AtomicOrdering FlushAO = AtomicOrdering::Monotonic;
5707 
5708   switch (AK) {
5709   case Read:
5710     if (AO == AtomicOrdering::Acquire || AO == AtomicOrdering::AcquireRelease ||
5711         AO == AtomicOrdering::SequentiallyConsistent) {
5712       FlushAO = AtomicOrdering::Acquire;
5713       Flush = true;
5714     }
5715     break;
5716   case Write:
5717   case Compare:
5718   case Update:
5719     if (AO == AtomicOrdering::Release || AO == AtomicOrdering::AcquireRelease ||
5720         AO == AtomicOrdering::SequentiallyConsistent) {
5721       FlushAO = AtomicOrdering::Release;
5722       Flush = true;
5723     }
5724     break;
5725   case Capture:
5726     switch (AO) {
5727     case AtomicOrdering::Acquire:
5728       FlushAO = AtomicOrdering::Acquire;
5729       Flush = true;
5730       break;
5731     case AtomicOrdering::Release:
5732       FlushAO = AtomicOrdering::Release;
5733       Flush = true;
5734       break;
5735     case AtomicOrdering::AcquireRelease:
5736     case AtomicOrdering::SequentiallyConsistent:
5737       FlushAO = AtomicOrdering::AcquireRelease;
5738       Flush = true;
5739       break;
5740     default:
5741       // do nothing - leave silently.
5742       break;
5743     }
5744   }
5745 
5746   if (Flush) {
5747     // Currently Flush RT call still doesn't take memory_ordering, so for when
5748     // that happens, this tries to do the resolution of which atomic ordering
5749     // to use with but issue the flush call
5750     // TODO: pass `FlushAO` after memory ordering support is added
5751     (void)FlushAO;
5752     emitFlush(Loc);
5753   }
5754 
5755   // for AO == AtomicOrdering::Monotonic and  all other case combinations
5756   // do nothing
5757   return Flush;
5758 }
5759 
5760 OpenMPIRBuilder::InsertPointTy
5761 OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
5762                                   AtomicOpValue &X, AtomicOpValue &V,
5763                                   AtomicOrdering AO) {
5764   if (!updateToLocation(Loc))
5765     return Loc.IP;
5766 
5767   assert(X.Var->getType()->isPointerTy() &&
5768          "OMP Atomic expects a pointer to target memory");
5769   Type *XElemTy = X.ElemTy;
5770   assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
5771           XElemTy->isPointerTy()) &&
5772          "OMP atomic read expected a scalar type");
5773 
5774   Value *XRead = nullptr;
5775 
5776   if (XElemTy->isIntegerTy()) {
5777     LoadInst *XLD =
5778         Builder.CreateLoad(XElemTy, X.Var, X.IsVolatile, "omp.atomic.read");
5779     XLD->setAtomic(AO);
5780     XRead = cast<Value>(XLD);
5781   } else {
5782     // We need to perform atomic op as integer
5783     IntegerType *IntCastTy =
5784         IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
5785     LoadInst *XLoad =
5786         Builder.CreateLoad(IntCastTy, X.Var, X.IsVolatile, "omp.atomic.load");
5787     XLoad->setAtomic(AO);
5788     if (XElemTy->isFloatingPointTy()) {
5789       XRead = Builder.CreateBitCast(XLoad, XElemTy, "atomic.flt.cast");
5790     } else {
5791       XRead = Builder.CreateIntToPtr(XLoad, XElemTy, "atomic.ptr.cast");
5792     }
5793   }
5794   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Read);
5795   Builder.CreateStore(XRead, V.Var, V.IsVolatile);
5796   return Builder.saveIP();
5797 }
5798 
5799 OpenMPIRBuilder::InsertPointTy
5800 OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
5801                                    AtomicOpValue &X, Value *Expr,
5802                                    AtomicOrdering AO) {
5803   if (!updateToLocation(Loc))
5804     return Loc.IP;
5805 
5806   assert(X.Var->getType()->isPointerTy() &&
5807          "OMP Atomic expects a pointer to target memory");
5808   Type *XElemTy = X.ElemTy;
5809   assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
5810           XElemTy->isPointerTy()) &&
5811          "OMP atomic write expected a scalar type");
5812 
5813   if (XElemTy->isIntegerTy()) {
5814     StoreInst *XSt = Builder.CreateStore(Expr, X.Var, X.IsVolatile);
5815     XSt->setAtomic(AO);
5816   } else {
5817     // We need to bitcast and perform atomic op as integers
5818     IntegerType *IntCastTy =
5819         IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
5820     Value *ExprCast =
5821         Builder.CreateBitCast(Expr, IntCastTy, "atomic.src.int.cast");
5822     StoreInst *XSt = Builder.CreateStore(ExprCast, X.Var, X.IsVolatile);
5823     XSt->setAtomic(AO);
5824   }
5825 
5826   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Write);
5827   return Builder.saveIP();
5828 }
5829 
5830 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicUpdate(
5831     const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
5832     Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
5833     AtomicUpdateCallbackTy &UpdateOp, bool IsXBinopExpr) {
5834   assert(!isConflictIP(Loc.IP, AllocaIP) && "IPs must not be ambiguous");
5835   if (!updateToLocation(Loc))
5836     return Loc.IP;
5837 
5838   LLVM_DEBUG({
5839     Type *XTy = X.Var->getType();
5840     assert(XTy->isPointerTy() &&
5841            "OMP Atomic expects a pointer to target memory");
5842     Type *XElemTy = X.ElemTy;
5843     assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
5844             XElemTy->isPointerTy()) &&
5845            "OMP atomic update expected a scalar type");
5846     assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
5847            (RMWOp != AtomicRMWInst::UMax) && (RMWOp != AtomicRMWInst::UMin) &&
5848            "OpenMP atomic does not support LT or GT operations");
5849   });
5850 
5851   emitAtomicUpdate(AllocaIP, X.Var, X.ElemTy, Expr, AO, RMWOp, UpdateOp,
5852                    X.IsVolatile, IsXBinopExpr);
5853   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Update);
5854   return Builder.saveIP();
5855 }
5856 
5857 // FIXME: Duplicating AtomicExpand
5858 Value *OpenMPIRBuilder::emitRMWOpAsInstruction(Value *Src1, Value *Src2,
5859                                                AtomicRMWInst::BinOp RMWOp) {
5860   switch (RMWOp) {
5861   case AtomicRMWInst::Add:
5862     return Builder.CreateAdd(Src1, Src2);
5863   case AtomicRMWInst::Sub:
5864     return Builder.CreateSub(Src1, Src2);
5865   case AtomicRMWInst::And:
5866     return Builder.CreateAnd(Src1, Src2);
5867   case AtomicRMWInst::Nand:
5868     return Builder.CreateNeg(Builder.CreateAnd(Src1, Src2));
5869   case AtomicRMWInst::Or:
5870     return Builder.CreateOr(Src1, Src2);
5871   case AtomicRMWInst::Xor:
5872     return Builder.CreateXor(Src1, Src2);
5873   case AtomicRMWInst::Xchg:
5874   case AtomicRMWInst::FAdd:
5875   case AtomicRMWInst::FSub:
5876   case AtomicRMWInst::BAD_BINOP:
5877   case AtomicRMWInst::Max:
5878   case AtomicRMWInst::Min:
5879   case AtomicRMWInst::UMax:
5880   case AtomicRMWInst::UMin:
5881   case AtomicRMWInst::FMax:
5882   case AtomicRMWInst::FMin:
5883   case AtomicRMWInst::UIncWrap:
5884   case AtomicRMWInst::UDecWrap:
5885     llvm_unreachable("Unsupported atomic update operation");
5886   }
5887   llvm_unreachable("Unsupported atomic update operation");
5888 }
5889 
5890 std::pair<Value *, Value *> OpenMPIRBuilder::emitAtomicUpdate(
5891     InsertPointTy AllocaIP, Value *X, Type *XElemTy, Value *Expr,
5892     AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
5893     AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, bool IsXBinopExpr) {
5894   // TODO: handle the case where XElemTy is not byte-sized or not a power of 2
5895   // or a complex datatype.
5896   bool emitRMWOp = false;
5897   switch (RMWOp) {
5898   case AtomicRMWInst::Add:
5899   case AtomicRMWInst::And:
5900   case AtomicRMWInst::Nand:
5901   case AtomicRMWInst::Or:
5902   case AtomicRMWInst::Xor:
5903   case AtomicRMWInst::Xchg:
5904     emitRMWOp = XElemTy;
5905     break;
5906   case AtomicRMWInst::Sub:
5907     emitRMWOp = (IsXBinopExpr && XElemTy);
5908     break;
5909   default:
5910     emitRMWOp = false;
5911   }
5912   emitRMWOp &= XElemTy->isIntegerTy();
5913 
5914   std::pair<Value *, Value *> Res;
5915   if (emitRMWOp) {
5916     Res.first = Builder.CreateAtomicRMW(RMWOp, X, Expr, llvm::MaybeAlign(), AO);
5917     // not needed except in case of postfix captures. Generate anyway for
5918     // consistency with the else part. Will be removed with any DCE pass.
5919     // AtomicRMWInst::Xchg does not have a coressponding instruction.
5920     if (RMWOp == AtomicRMWInst::Xchg)
5921       Res.second = Res.first;
5922     else
5923       Res.second = emitRMWOpAsInstruction(Res.first, Expr, RMWOp);
5924   } else {
5925     IntegerType *IntCastTy =
5926         IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
5927     LoadInst *OldVal =
5928         Builder.CreateLoad(IntCastTy, X, X->getName() + ".atomic.load");
5929     OldVal->setAtomic(AO);
5930     // CurBB
5931     // |     /---\
5932 		// ContBB    |
5933     // |     \---/
5934     // ExitBB
5935     BasicBlock *CurBB = Builder.GetInsertBlock();
5936     Instruction *CurBBTI = CurBB->getTerminator();
5937     CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
5938     BasicBlock *ExitBB =
5939         CurBB->splitBasicBlock(CurBBTI, X->getName() + ".atomic.exit");
5940     BasicBlock *ContBB = CurBB->splitBasicBlock(CurBB->getTerminator(),
5941                                                 X->getName() + ".atomic.cont");
5942     ContBB->getTerminator()->eraseFromParent();
5943     Builder.restoreIP(AllocaIP);
5944     AllocaInst *NewAtomicAddr = Builder.CreateAlloca(XElemTy);
5945     NewAtomicAddr->setName(X->getName() + "x.new.val");
5946     Builder.SetInsertPoint(ContBB);
5947     llvm::PHINode *PHI = Builder.CreatePHI(OldVal->getType(), 2);
5948     PHI->addIncoming(OldVal, CurBB);
5949     bool IsIntTy = XElemTy->isIntegerTy();
5950     Value *OldExprVal = PHI;
5951     if (!IsIntTy) {
5952       if (XElemTy->isFloatingPointTy()) {
5953         OldExprVal = Builder.CreateBitCast(PHI, XElemTy,
5954                                            X->getName() + ".atomic.fltCast");
5955       } else {
5956         OldExprVal = Builder.CreateIntToPtr(PHI, XElemTy,
5957                                             X->getName() + ".atomic.ptrCast");
5958       }
5959     }
5960 
5961     Value *Upd = UpdateOp(OldExprVal, Builder);
5962     Builder.CreateStore(Upd, NewAtomicAddr);
5963     LoadInst *DesiredVal = Builder.CreateLoad(IntCastTy, NewAtomicAddr);
5964     AtomicOrdering Failure =
5965         llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
5966     AtomicCmpXchgInst *Result = Builder.CreateAtomicCmpXchg(
5967         X, PHI, DesiredVal, llvm::MaybeAlign(), AO, Failure);
5968     Result->setVolatile(VolatileX);
5969     Value *PreviousVal = Builder.CreateExtractValue(Result, /*Idxs=*/0);
5970     Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
5971     PHI->addIncoming(PreviousVal, Builder.GetInsertBlock());
5972     Builder.CreateCondBr(SuccessFailureVal, ExitBB, ContBB);
5973 
5974     Res.first = OldExprVal;
5975     Res.second = Upd;
5976 
5977     // set Insertion point in exit block
5978     if (UnreachableInst *ExitTI =
5979             dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
5980       CurBBTI->eraseFromParent();
5981       Builder.SetInsertPoint(ExitBB);
5982     } else {
5983       Builder.SetInsertPoint(ExitTI);
5984     }
5985   }
5986 
5987   return Res;
5988 }
5989 
5990 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCapture(
5991     const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
5992     AtomicOpValue &V, Value *Expr, AtomicOrdering AO,
5993     AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp,
5994     bool UpdateExpr, bool IsPostfixUpdate, bool IsXBinopExpr) {
5995   if (!updateToLocation(Loc))
5996     return Loc.IP;
5997 
5998   LLVM_DEBUG({
5999     Type *XTy = X.Var->getType();
6000     assert(XTy->isPointerTy() &&
6001            "OMP Atomic expects a pointer to target memory");
6002     Type *XElemTy = X.ElemTy;
6003     assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
6004             XElemTy->isPointerTy()) &&
6005            "OMP atomic capture expected a scalar type");
6006     assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
6007            "OpenMP atomic does not support LT or GT operations");
6008   });
6009 
6010   // If UpdateExpr is 'x' updated with some `expr` not based on 'x',
6011   // 'x' is simply atomically rewritten with 'expr'.
6012   AtomicRMWInst::BinOp AtomicOp = (UpdateExpr ? RMWOp : AtomicRMWInst::Xchg);
6013   std::pair<Value *, Value *> Result =
6014       emitAtomicUpdate(AllocaIP, X.Var, X.ElemTy, Expr, AO, AtomicOp, UpdateOp,
6015                        X.IsVolatile, IsXBinopExpr);
6016 
6017   Value *CapturedVal = (IsPostfixUpdate ? Result.first : Result.second);
6018   Builder.CreateStore(CapturedVal, V.Var, V.IsVolatile);
6019 
6020   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Capture);
6021   return Builder.saveIP();
6022 }
6023 
6024 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
6025     const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
6026     AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
6027     omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
6028     bool IsFailOnly) {
6029 
6030   AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
6031   return createAtomicCompare(Loc, X, V, R, E, D, AO, Op, IsXBinopExpr,
6032                              IsPostfixUpdate, IsFailOnly, Failure);
6033 }
6034 
6035 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
6036     const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
6037     AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
6038     omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
6039     bool IsFailOnly, AtomicOrdering Failure) {
6040 
6041   if (!updateToLocation(Loc))
6042     return Loc.IP;
6043 
6044   assert(X.Var->getType()->isPointerTy() &&
6045          "OMP atomic expects a pointer to target memory");
6046   // compare capture
6047   if (V.Var) {
6048     assert(V.Var->getType()->isPointerTy() && "v.var must be of pointer type");
6049     assert(V.ElemTy == X.ElemTy && "x and v must be of same type");
6050   }
6051 
6052   bool IsInteger = E->getType()->isIntegerTy();
6053 
6054   if (Op == OMPAtomicCompareOp::EQ) {
6055     AtomicCmpXchgInst *Result = nullptr;
6056     if (!IsInteger) {
6057       IntegerType *IntCastTy =
6058           IntegerType::get(M.getContext(), X.ElemTy->getScalarSizeInBits());
6059       Value *EBCast = Builder.CreateBitCast(E, IntCastTy);
6060       Value *DBCast = Builder.CreateBitCast(D, IntCastTy);
6061       Result = Builder.CreateAtomicCmpXchg(X.Var, EBCast, DBCast, MaybeAlign(),
6062                                            AO, Failure);
6063     } else {
6064       Result =
6065           Builder.CreateAtomicCmpXchg(X.Var, E, D, MaybeAlign(), AO, Failure);
6066     }
6067 
6068     if (V.Var) {
6069       Value *OldValue = Builder.CreateExtractValue(Result, /*Idxs=*/0);
6070       if (!IsInteger)
6071         OldValue = Builder.CreateBitCast(OldValue, X.ElemTy);
6072       assert(OldValue->getType() == V.ElemTy &&
6073              "OldValue and V must be of same type");
6074       if (IsPostfixUpdate) {
6075         Builder.CreateStore(OldValue, V.Var, V.IsVolatile);
6076       } else {
6077         Value *SuccessOrFail = Builder.CreateExtractValue(Result, /*Idxs=*/1);
6078         if (IsFailOnly) {
6079           // CurBB----
6080           //   |     |
6081           //   v     |
6082           // ContBB  |
6083           //   |     |
6084           //   v     |
6085           // ExitBB <-
6086           //
6087           // where ContBB only contains the store of old value to 'v'.
6088           BasicBlock *CurBB = Builder.GetInsertBlock();
6089           Instruction *CurBBTI = CurBB->getTerminator();
6090           CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
6091           BasicBlock *ExitBB = CurBB->splitBasicBlock(
6092               CurBBTI, X.Var->getName() + ".atomic.exit");
6093           BasicBlock *ContBB = CurBB->splitBasicBlock(
6094               CurBB->getTerminator(), X.Var->getName() + ".atomic.cont");
6095           ContBB->getTerminator()->eraseFromParent();
6096           CurBB->getTerminator()->eraseFromParent();
6097 
6098           Builder.CreateCondBr(SuccessOrFail, ExitBB, ContBB);
6099 
6100           Builder.SetInsertPoint(ContBB);
6101           Builder.CreateStore(OldValue, V.Var);
6102           Builder.CreateBr(ExitBB);
6103 
6104           if (UnreachableInst *ExitTI =
6105                   dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
6106             CurBBTI->eraseFromParent();
6107             Builder.SetInsertPoint(ExitBB);
6108           } else {
6109             Builder.SetInsertPoint(ExitTI);
6110           }
6111         } else {
6112           Value *CapturedValue =
6113               Builder.CreateSelect(SuccessOrFail, E, OldValue);
6114           Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
6115         }
6116       }
6117     }
6118     // The comparison result has to be stored.
6119     if (R.Var) {
6120       assert(R.Var->getType()->isPointerTy() &&
6121              "r.var must be of pointer type");
6122       assert(R.ElemTy->isIntegerTy() && "r must be of integral type");
6123 
6124       Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
6125       Value *ResultCast = R.IsSigned
6126                               ? Builder.CreateSExt(SuccessFailureVal, R.ElemTy)
6127                               : Builder.CreateZExt(SuccessFailureVal, R.ElemTy);
6128       Builder.CreateStore(ResultCast, R.Var, R.IsVolatile);
6129     }
6130   } else {
6131     assert((Op == OMPAtomicCompareOp::MAX || Op == OMPAtomicCompareOp::MIN) &&
6132            "Op should be either max or min at this point");
6133     assert(!IsFailOnly && "IsFailOnly is only valid when the comparison is ==");
6134 
6135     // Reverse the ordop as the OpenMP forms are different from LLVM forms.
6136     // Let's take max as example.
6137     // OpenMP form:
6138     // x = x > expr ? expr : x;
6139     // LLVM form:
6140     // *ptr = *ptr > val ? *ptr : val;
6141     // We need to transform to LLVM form.
6142     // x = x <= expr ? x : expr;
6143     AtomicRMWInst::BinOp NewOp;
6144     if (IsXBinopExpr) {
6145       if (IsInteger) {
6146         if (X.IsSigned)
6147           NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
6148                                                 : AtomicRMWInst::Max;
6149         else
6150           NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
6151                                                 : AtomicRMWInst::UMax;
6152       } else {
6153         NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMin
6154                                               : AtomicRMWInst::FMax;
6155       }
6156     } else {
6157       if (IsInteger) {
6158         if (X.IsSigned)
6159           NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
6160                                                 : AtomicRMWInst::Min;
6161         else
6162           NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
6163                                                 : AtomicRMWInst::UMin;
6164       } else {
6165         NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMax
6166                                               : AtomicRMWInst::FMin;
6167       }
6168     }
6169 
6170     AtomicRMWInst *OldValue =
6171         Builder.CreateAtomicRMW(NewOp, X.Var, E, MaybeAlign(), AO);
6172     if (V.Var) {
6173       Value *CapturedValue = nullptr;
6174       if (IsPostfixUpdate) {
6175         CapturedValue = OldValue;
6176       } else {
6177         CmpInst::Predicate Pred;
6178         switch (NewOp) {
6179         case AtomicRMWInst::Max:
6180           Pred = CmpInst::ICMP_SGT;
6181           break;
6182         case AtomicRMWInst::UMax:
6183           Pred = CmpInst::ICMP_UGT;
6184           break;
6185         case AtomicRMWInst::FMax:
6186           Pred = CmpInst::FCMP_OGT;
6187           break;
6188         case AtomicRMWInst::Min:
6189           Pred = CmpInst::ICMP_SLT;
6190           break;
6191         case AtomicRMWInst::UMin:
6192           Pred = CmpInst::ICMP_ULT;
6193           break;
6194         case AtomicRMWInst::FMin:
6195           Pred = CmpInst::FCMP_OLT;
6196           break;
6197         default:
6198           llvm_unreachable("unexpected comparison op");
6199         }
6200         Value *NonAtomicCmp = Builder.CreateCmp(Pred, OldValue, E);
6201         CapturedValue = Builder.CreateSelect(NonAtomicCmp, E, OldValue);
6202       }
6203       Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
6204     }
6205   }
6206 
6207   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Compare);
6208 
6209   return Builder.saveIP();
6210 }
6211 
/// Emit an OpenMP `teams` region at \p Loc.
///
/// The current block is split into teams.alloca / teams.body / teams.exit
/// blocks and \p BodyGenCB generates the body. An OutlineInfo is registered so
/// the region is outlined later; its post-outline callback replaces the stale
/// call to the outlined function with a call to __kmpc_fork_teams. If any of
/// \p NumTeamsLower, \p NumTeamsUpper, \p ThreadLimit or \p IfExpr is given,
/// a __kmpc_push_num_teams_51 call is emitted first. Missing values default
/// to 0 (upper bound, thread limit) or to the upper bound (lower bound); a
/// false \p IfExpr forces both bounds to 1.
///
/// Returns the insertion point at the start of teams.exit, or an empty
/// InsertPointTy if updateToLocation(\p Loc) fails.
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
                             BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower,
                             Value *NumTeamsUpper, Value *ThreadLimit,
                             Value *IfExpr) {
  if (!updateToLocation(Loc))
    return InsertPointTy();

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Function *CurrentFunction = Builder.GetInsertBlock()->getParent();

  // Outer allocation basicblock is the entry block of the current function.
  BasicBlock &OuterAllocaBB = CurrentFunction->getEntryBlock();
  // If we are still in the entry block, move into a fresh "teams.entry"
  // block, presumably so the entry block (used as OuterAllocaBB) is not split
  // into the region that will be outlined.
  if (&OuterAllocaBB == Builder.GetInsertBlock()) {
    BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.entry");
    Builder.SetInsertPoint(BodyBB, BodyBB->begin());
  }

  // The current basic block is split into four basic blocks. After outlining,
  // they will be mapped as follows:
  // ```
  // def current_fn() {
  //   current_basic_block:
  //     br label %teams.exit
  //   teams.exit:
  //     ; instructions after teams
  // }
  //
  // def outlined_fn() {
  //   teams.alloca:
  //     br label %teams.body
  //   teams.body:
  //     ; instructions within teams body
  // }
  // ```
  BasicBlock *ExitBB = splitBB(Builder, /*CreateBranch=*/true, "teams.exit");
  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.body");
  BasicBlock *AllocaBB =
      splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");

  // Push num_teams
  if (NumTeamsLower || NumTeamsUpper || ThreadLimit || IfExpr) {
    assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) &&
           "if lowerbound is non-null, then upperbound must also be non-null "
           "for bounds on num_teams");

    // A missing upper bound is passed as 0.
    if (NumTeamsUpper == nullptr)
      NumTeamsUpper = Builder.getInt32(0);

    // A missing lower bound defaults to the upper bound.
    if (NumTeamsLower == nullptr)
      NumTeamsLower = NumTeamsUpper;

    if (IfExpr) {
      assert(IfExpr->getType()->isIntegerTy() &&
             "argument to if clause must be an integer value");

      // upper = ifexpr ? upper : 1
      // Normalize a wider integer condition to i1 (non-zero means true).
      if (IfExpr->getType() != Int1)
        IfExpr = Builder.CreateICmpNE(IfExpr,
                                      ConstantInt::get(IfExpr->getType(), 0));
      NumTeamsUpper = Builder.CreateSelect(
          IfExpr, NumTeamsUpper, Builder.getInt32(1), "numTeamsUpper");

      // lower = ifexpr ? lower : 1
      NumTeamsLower = Builder.CreateSelect(
          IfExpr, NumTeamsLower, Builder.getInt32(1), "numTeamsLower");
    }

    // A missing thread limit is passed as 0.
    if (ThreadLimit == nullptr)
      ThreadLimit = Builder.getInt32(0);

    Value *ThreadNum = getOrCreateThreadID(Ident);
    Builder.CreateCall(
        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams_51),
        {Ident, ThreadNum, NumTeamsLower, NumTeamsUpper, ThreadLimit});
  }
  // Generate the body of teams.
  InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
  InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
  BodyGenCB(AllocaIP, CodeGenIP);

  OutlineInfo OI;
  OI.EntryBB = AllocaBB;
  OI.ExitBB = ExitBB;
  OI.OuterAllocaBB = &OuterAllocaBB;

  // Insert fake values for global tid and bound tid.
  // They are kept out of the aggregated argument struct; presumably they
  // become the outlined function's first two parameters (named global.tid.ptr
  // and bound.tid.ptr in the callback below). The instructions created for
  // them are recorded in ToBeDeleted and erased after outlining.
  std::stack<Instruction *> ToBeDeleted;
  InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
      Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "gid", true));
  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
      Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "tid", true));

  // ToBeDeleted is captured by copy into a mutable lambda so the callback can
  // add the stale call and then erase everything after emitting the runtime
  // call.
  OI.PostOutlineCB = [this, Ident, ToBeDeleted](Function &OutlinedFn) mutable {
    // The stale call instruction will be replaced with a new call instruction
    // for runtime call with the outlined function.

    assert(OutlinedFn.getNumUses() == 1 &&
           "there must be a single user for the outlined function");
    CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
    ToBeDeleted.push(StaleCI);

    assert((OutlinedFn.arg_size() == 2 || OutlinedFn.arg_size() == 3) &&
           "Outlined function must have two or three arguments only");

    // A third argument carries the shared data of the region.
    bool HasShared = OutlinedFn.arg_size() == 3;

    OutlinedFn.getArg(0)->setName("global.tid.ptr");
    OutlinedFn.getArg(1)->setName("bound.tid.ptr");
    if (HasShared)
      OutlinedFn.getArg(2)->setName("data");

    // Call to the runtime function for teams in the current function.
    assert(StaleCI && "Error while outlining - no CallInst user found for the "
                      "outlined function.");
    Builder.SetInsertPoint(StaleCI);
    // The count passed to __kmpc_fork_teams is the number of stale-call
    // arguments beyond the two thread-id arguments.
    SmallVector<Value *> Args = {
        Ident, Builder.getInt32(StaleCI->arg_size() - 2), &OutlinedFn};
    if (HasShared)
      Args.push_back(StaleCI->getArgOperand(2));
    Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
                           omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
                       Args);

    while (!ToBeDeleted.empty()) {
      ToBeDeleted.top()->eraseFromParent();
      ToBeDeleted.pop();
    }
  };

  addOutlineInfo(std::move(OI));

  Builder.SetInsertPoint(ExitBB, ExitBB->begin());

  return Builder.saveIP();
}
6351 
6352 GlobalVariable *
6353 OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
6354                                        std::string VarName) {
6355   llvm::Constant *MapNamesArrayInit = llvm::ConstantArray::get(
6356       llvm::ArrayType::get(llvm::PointerType::getUnqual(M.getContext()),
6357                            Names.size()),
6358       Names);
6359   auto *MapNamesArrayGlobal = new llvm::GlobalVariable(
6360       M, MapNamesArrayInit->getType(),
6361       /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MapNamesArrayInit,
6362       VarName);
6363   return MapNamesArrayGlobal;
6364 }
6365 
// Create all simple and struct types exposed by the runtime and remember
// the llvm::PointerTypes of them for easy access later.
//
// The OMP_* macros defined below are expanded by including
// "llvm/Frontend/OpenMP/OMPKinds.def"; each expansion assigns the matching
// member of this builder.
// NOTE(review): the macros are not #undef'd in this function; presumably
// OMPKinds.def undefines them after use — confirm.
void OpenMPIRBuilder::initializeTypes(Module &M) {
  LLVMContext &Ctx = M.getContext();
  // Scratch variable written by every OMP_STRUCT_TYPE expansion.
  StructType *T;
  // Simple types: VarName = InitValue.
#define OMP_TYPE(VarName, InitValue) VarName = InitValue;
  // Array types: the array type goes to VarName##Ty and a pointer to it to
  // VarName##PtrTy.
#define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize)                             \
  VarName##Ty = ArrayType::get(ElemTy, ArraySize);                             \
  VarName##PtrTy = PointerType::getUnqual(VarName##Ty);
  // Function types: the signature goes to VarName and a pointer to it to
  // VarName##Ptr.
#define OMP_FUNCTION_TYPE(VarName, IsVarArg, ReturnType, ...)                  \
  VarName = FunctionType::get(ReturnType, {__VA_ARGS__}, IsVarArg);            \
  VarName##Ptr = PointerType::getUnqual(VarName);
  // Struct types: reuse a struct with this name if the context already has
  // one, otherwise create it; also record a pointer to it in VarName##Ptr.
#define OMP_STRUCT_TYPE(VarName, StructName, Packed, ...)                      \
  T = StructType::getTypeByName(Ctx, StructName);                              \
  if (!T)                                                                      \
    T = StructType::create(Ctx, {__VA_ARGS__}, StructName, Packed);            \
  VarName = T;                                                                 \
  VarName##Ptr = PointerType::getUnqual(T);
#include "llvm/Frontend/OpenMP/OMPKinds.def"
}
6386 
6387 void OpenMPIRBuilder::OutlineInfo::collectBlocks(
6388     SmallPtrSetImpl<BasicBlock *> &BlockSet,
6389     SmallVectorImpl<BasicBlock *> &BlockVector) {
6390   SmallVector<BasicBlock *, 32> Worklist;
6391   BlockSet.insert(EntryBB);
6392   BlockSet.insert(ExitBB);
6393 
6394   Worklist.push_back(EntryBB);
6395   while (!Worklist.empty()) {
6396     BasicBlock *BB = Worklist.pop_back_val();
6397     BlockVector.push_back(BB);
6398     for (BasicBlock *SuccBB : successors(BB))
6399       if (BlockSet.insert(SuccBB).second)
6400         Worklist.push_back(SuccBB);
6401   }
6402 }
6403 
6404 void OpenMPIRBuilder::createOffloadEntry(Constant *ID, Constant *Addr,
6405                                          uint64_t Size, int32_t Flags,
6406                                          GlobalValue::LinkageTypes,
6407                                          StringRef Name) {
6408   if (!Config.isGPU()) {
6409     llvm::offloading::emitOffloadingEntry(
6410         M, ID, Name.empty() ? Addr->getName() : Name, Size, Flags, /*Data=*/0,
6411         "omp_offloading_entries");
6412     return;
6413   }
6414   // TODO: Add support for global variables on the device after declare target
6415   // support.
6416   Function *Fn = dyn_cast<Function>(Addr);
6417   if (!Fn)
6418     return;
6419 
6420   Module &M = *(Fn->getParent());
6421   LLVMContext &Ctx = M.getContext();
6422 
6423   // Get "nvvm.annotations" metadata node.
6424   NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
6425 
6426   Metadata *MDVals[] = {
6427       ConstantAsMetadata::get(Fn), MDString::get(Ctx, "kernel"),
6428       ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(Ctx), 1))};
6429   // Append metadata to nvvm.annotations.
6430   MD->addOperand(MDNode::get(Ctx, MDVals));
6431 
6432   // Add a function attribute for the kernel.
6433   Fn->addFnAttr(Attribute::get(Ctx, "kernel"));
6434   if (T.isAMDGCN())
6435     Fn->addFnAttr("uniform-work-group-size", "true");
6436   Fn->addFnAttr(Attribute::MustProgress);
6437 }
6438 
// We only generate metadata for functions that contain target regions.
//
// Emits one node per registered offload entry into the "omp_offload.info"
// named metadata. Afterwards it walks the entries in their recorded order and
// creates the offloading entries via createOffloadEntry(). Entries that cannot
// be emitted are reported through ErrorFn (with an EMIT_MD_* kind and the
// entry's info) and skipped. The operand layouts written here are decoded by
// loadOffloadInfoMetadata(Module &), so the two must stay in sync.
void OpenMPIRBuilder::createOffloadEntriesAndInfoMetadata(
    EmitMetadataErrorReportFunctionTy &ErrorFn) {

  // If there are no entries, we don't need to do anything.
  if (OffloadInfoManager.empty())
    return;

  LLVMContext &C = M.getContext();
  // One slot per entry, indexed by the entry's order. The two emitter
  // callbacks below fill every slot (checked by the assert in the final loop).
  SmallVector<std::pair<const OffloadEntriesInfoManager::OffloadEntryInfo *,
                        TargetRegionEntryInfo>,
              16>
      OrderedEntries(OffloadInfoManager.size());

  // Auxiliary methods to create metadata values and strings.
  auto &&GetMDInt = [this](unsigned V) {
    return ConstantAsMetadata::get(ConstantInt::get(Builder.getInt32Ty(), V));
  };

  auto &&GetMDString = [&C](StringRef V) { return MDString::get(C, V); };

  // Create the offloading info metadata node.
  // NOTE(review): loadOffloadInfoMetadata() looks this node up through
  // ompOffloadInfoName; confirm that constant equals this literal (or use it).
  NamedMDNode *MD = M.getOrInsertNamedMetadata("omp_offload.info");
  auto &&TargetRegionMetadataEmitter =
      [&C, MD, &OrderedEntries, &GetMDInt, &GetMDString](
          const TargetRegionEntryInfo &EntryInfo,
          const OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion &E) {
        // Generate metadata for target regions. Each entry of this metadata
        // contains:
        // - Entry 0 -> Kind of this type of metadata (0).
        // - Entry 1 -> Device ID of the file where the entry was identified.
        // - Entry 2 -> File ID of the file where the entry was identified.
        // - Entry 3 -> Mangled name of the function where the entry was
        // identified.
        // - Entry 4 -> Line in the file where the entry was identified.
        // - Entry 5 -> Count of regions at this DeviceID/FilesID/Line.
        // - Entry 6 -> Order the entry was created.
        // The first element of the metadata node is the kind.
        Metadata *Ops[] = {
            GetMDInt(E.getKind()),      GetMDInt(EntryInfo.DeviceID),
            GetMDInt(EntryInfo.FileID), GetMDString(EntryInfo.ParentName),
            GetMDInt(EntryInfo.Line),   GetMDInt(EntryInfo.Count),
            GetMDInt(E.getOrder())};

        // Save this entry in the right position of the ordered entries array.
        OrderedEntries[E.getOrder()] = std::make_pair(&E, EntryInfo);

        // Add metadata to the named metadata node.
        MD->addOperand(MDNode::get(C, Ops));
      };

  OffloadInfoManager.actOnTargetRegionEntriesInfo(TargetRegionMetadataEmitter);

  // Create function that emits metadata for each device global variable entry;
  auto &&DeviceGlobalVarMetadataEmitter =
      [&C, &OrderedEntries, &GetMDInt, &GetMDString, MD](
          StringRef MangledName,
          const OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar &E) {
        // Generate metadata for global variables. Each entry of this metadata
        // contains:
        // - Entry 0 -> Kind of this type of metadata (1).
        // - Entry 1 -> Mangled name of the variable.
        // - Entry 2 -> Declare target kind.
        // - Entry 3 -> Order the entry was created.
        // The first element of the metadata node is the kind.
        Metadata *Ops[] = {GetMDInt(E.getKind()), GetMDString(MangledName),
                           GetMDInt(E.getFlags()), GetMDInt(E.getOrder())};

        // Save this entry in the right position of the ordered entries array.
        // Only the name is meaningful in the info for variables; the numeric
        // fields are zero.
        TargetRegionEntryInfo varInfo(MangledName, 0, 0, 0);
        OrderedEntries[E.getOrder()] = std::make_pair(&E, varInfo);

        // Add metadata to the named metadata node.
        MD->addOperand(MDNode::get(C, Ops));
      };

  OffloadInfoManager.actOnDeviceGlobalVarEntriesInfo(
      DeviceGlobalVarMetadataEmitter);

  // Emit the actual offloading entries, in creation order.
  for (const auto &E : OrderedEntries) {
    assert(E.first && "All ordered entries must exist!");
    if (const auto *CE =
            dyn_cast<OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion>(
                E.first)) {
      // A target region needs both an ID and an address to be emitted.
      if (!CE->getID() || !CE->getAddress()) {
        // Do not blame the entry if the parent function is not emitted.
        TargetRegionEntryInfo EntryInfo = E.second;
        StringRef FnName = EntryInfo.ParentName;
        if (!M.getNamedValue(FnName))
          continue;
        ErrorFn(EMIT_MD_TARGET_REGION_ERROR, EntryInfo);
        continue;
      }
      createOffloadEntry(CE->getID(), CE->getAddress(),
                         /*Size=*/0, CE->getFlags(),
                         GlobalValue::WeakAnyLinkage);
    } else if (const auto *CE = dyn_cast<
                   OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar>(
                   E.first)) {
      OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags =
          static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
              CE->getFlags());
      // Note: `continue` inside this switch skips to the next ordered entry.
      switch (Flags) {
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter:
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo:
        if (Config.isTargetDevice() && Config.hasRequiresUnifiedSharedMemory())
          continue;
        if (!CE->getAddress()) {
          ErrorFn(EMIT_MD_DECLARE_TARGET_ERROR, E.second);
          continue;
        }
        // The variable has no definition - no need to add the entry.
        if (CE->getVarSize() == 0)
          continue;
        break;
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink:
        // Link entries have no address on the device and an address on the
        // host.
        // NOTE(review): the assert message below reads "Declaret target link
        // address is set." — typo and misleading wording; left as-is here
        // because it is a runtime string.
        assert(((Config.isTargetDevice() && !CE->getAddress()) ||
                (!Config.isTargetDevice() && CE->getAddress())) &&
               "Declaret target link address is set.");
        if (Config.isTargetDevice())
          continue;
        if (!CE->getAddress()) {
          ErrorFn(EMIT_MD_GLOBAL_VAR_LINK_ERROR, TargetRegionEntryInfo());
          continue;
        }
        break;
      default:
        break;
      }

      // Hidden or internal symbols on the device are not externally visible.
      // We should not attempt to register them by creating an offloading
      // entry. Indirect variables are handled separately on the device.
      // NOTE(review): entries reaching here through `default:` never had
      // getAddress() checked for null, and dyn_cast asserts on null —
      // confirm such entries always carry an address.
      if (auto *GV = dyn_cast<GlobalValue>(CE->getAddress()))
        if ((GV->hasLocalLinkage() || GV->hasHiddenVisibility()) &&
            Flags != OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
          continue;

      // Indirect globals need to use a special name that doesn't match the name
      // of the associated host global.
      if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
        createOffloadEntry(CE->getAddress(), CE->getAddress(), CE->getVarSize(),
                           Flags, CE->getLinkage(), CE->getVarName());
      else
        createOffloadEntry(CE->getAddress(), CE->getAddress(), CE->getVarSize(),
                           Flags, CE->getLinkage());

    } else {
      llvm_unreachable("Unsupported entry kind.");
    }
  }
}
6591 
6592 void TargetRegionEntryInfo::getTargetRegionEntryFnName(
6593     SmallVectorImpl<char> &Name, StringRef ParentName, unsigned DeviceID,
6594     unsigned FileID, unsigned Line, unsigned Count) {
6595   raw_svector_ostream OS(Name);
6596   OS << "__omp_offloading" << llvm::format("_%x", DeviceID)
6597      << llvm::format("_%x_", FileID) << ParentName << "_l" << Line;
6598   if (Count)
6599     OS << "_" << Count;
6600 }
6601 
6602 void OffloadEntriesInfoManager::getTargetRegionEntryFnName(
6603     SmallVectorImpl<char> &Name, const TargetRegionEntryInfo &EntryInfo) {
6604   unsigned NewCount = getTargetRegionEntryInfoCount(EntryInfo);
6605   TargetRegionEntryInfo::getTargetRegionEntryFnName(
6606       Name, EntryInfo.ParentName, EntryInfo.DeviceID, EntryInfo.FileID,
6607       EntryInfo.Line, NewCount);
6608 }
6609 
6610 TargetRegionEntryInfo
6611 OpenMPIRBuilder::getTargetEntryUniqueInfo(FileIdentifierInfoCallbackTy CallBack,
6612                                           StringRef ParentName) {
6613   sys::fs::UniqueID ID;
6614   auto FileIDInfo = CallBack();
6615   if (auto EC = sys::fs::getUniqueID(std::get<0>(FileIDInfo), ID)) {
6616     report_fatal_error(("Unable to get unique ID for file, during "
6617                         "getTargetEntryUniqueInfo, error message: " +
6618                         EC.message())
6619                            .c_str());
6620   }
6621 
6622   return TargetRegionEntryInfo(ParentName, ID.getDevice(), ID.getFile(),
6623                                std::get<1>(FileIDInfo));
6624 }
6625 
6626 unsigned OpenMPIRBuilder::getFlagMemberOffset() {
6627   unsigned Offset = 0;
6628   for (uint64_t Remain =
6629            static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
6630                omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF);
6631        !(Remain & 1); Remain = Remain >> 1)
6632     Offset++;
6633   return Offset;
6634 }
6635 
6636 omp::OpenMPOffloadMappingFlags
6637 OpenMPIRBuilder::getMemberOfFlag(unsigned Position) {
6638   // Rotate by getFlagMemberOffset() bits.
6639   return static_cast<omp::OpenMPOffloadMappingFlags>(((uint64_t)Position + 1)
6640                                                      << getFlagMemberOffset());
6641 }
6642 
6643 void OpenMPIRBuilder::setCorrectMemberOfFlag(
6644     omp::OpenMPOffloadMappingFlags &Flags,
6645     omp::OpenMPOffloadMappingFlags MemberOfFlag) {
6646   // If the entry is PTR_AND_OBJ but has not been marked with the special
6647   // placeholder value 0xFFFF in the MEMBER_OF field, then it should not be
6648   // marked as MEMBER_OF.
6649   if (static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
6650           Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ) &&
6651       static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
6652           (Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF) !=
6653           omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF))
6654     return;
6655 
6656   // Reset the placeholder value to prepare the flag for the assignment of the
6657   // proper MEMBER_OF value.
6658   Flags &= ~omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
6659   Flags |= MemberOfFlag;
6660 }
6661 
6662 Constant *OpenMPIRBuilder::getAddrOfDeclareTargetVar(
6663     OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
6664     OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
6665     bool IsDeclaration, bool IsExternallyVisible,
6666     TargetRegionEntryInfo EntryInfo, StringRef MangledName,
6667     std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
6668     std::vector<Triple> TargetTriple, Type *LlvmPtrTy,
6669     std::function<Constant *()> GlobalInitializer,
6670     std::function<GlobalValue::LinkageTypes()> VariableLinkage) {
6671   // TODO: convert this to utilise the IRBuilder Config rather than
6672   // a passed down argument.
6673   if (OpenMPSIMD)
6674     return nullptr;
6675 
6676   if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink ||
6677       ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
6678         CaptureClause ==
6679             OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
6680        Config.hasRequiresUnifiedSharedMemory())) {
6681     SmallString<64> PtrName;
6682     {
6683       raw_svector_ostream OS(PtrName);
6684       OS << MangledName;
6685       if (!IsExternallyVisible)
6686         OS << format("_%x", EntryInfo.FileID);
6687       OS << "_decl_tgt_ref_ptr";
6688     }
6689 
6690     Value *Ptr = M.getNamedValue(PtrName);
6691 
6692     if (!Ptr) {
6693       GlobalValue *GlobalValue = M.getNamedValue(MangledName);
6694       Ptr = getOrCreateInternalVariable(LlvmPtrTy, PtrName);
6695 
6696       auto *GV = cast<GlobalVariable>(Ptr);
6697       GV->setLinkage(GlobalValue::WeakAnyLinkage);
6698 
6699       if (!Config.isTargetDevice()) {
6700         if (GlobalInitializer)
6701           GV->setInitializer(GlobalInitializer());
6702         else
6703           GV->setInitializer(GlobalValue);
6704       }
6705 
6706       registerTargetGlobalVariable(
6707           CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
6708           EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
6709           GlobalInitializer, VariableLinkage, LlvmPtrTy, cast<Constant>(Ptr));
6710     }
6711 
6712     return cast<Constant>(Ptr);
6713   }
6714 
6715   return nullptr;
6716 }
6717 
/// Register a declare-target global variable with the offload entries
/// manager, computing the entry's flags, name, size and linkage.
///
/// Nothing is registered when \p DeviceClause is not "any", or when compiling
/// for the host with no offload targets (\p TargetTriple empty).
///
/// "to"/"enter" variables without unified shared memory: the entry names the
/// variable itself (\p MangledName). Its size is the variable's value-type
/// size in bytes, or 0 for declarations.
///
/// All other cases ("link", or unified shared memory): the entry describes
/// the reference pointer. It gets pointer size and weak linkage.
void OpenMPIRBuilder::registerTargetGlobalVariable(
    OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
    OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
    bool IsDeclaration, bool IsExternallyVisible,
    TargetRegionEntryInfo EntryInfo, StringRef MangledName,
    std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
    std::vector<Triple> TargetTriple,
    std::function<Constant *()> GlobalInitializer,
    std::function<GlobalValue::LinkageTypes()> VariableLinkage, Type *LlvmPtrTy,
    Constant *Addr) {
  if (DeviceClause != OffloadEntriesInfoManager::OMPTargetDeviceClauseAny ||
      (TargetTriple.empty() && !Config.isTargetDevice()))
    return;

  OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags;
  StringRef VarName;
  int64_t VarSize;
  GlobalValue::LinkageTypes Linkage;

  if ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
       CaptureClause ==
           OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
      !Config.hasRequiresUnifiedSharedMemory()) {
    // "enter" is recorded with the same flag as "to".
    Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
    VarName = MangledName;
    // NOTE(review): LlvmVal is dereferenced below without a null check; this
    // assumes MangledName already names a global in M — confirm with callers.
    GlobalValue *LlvmVal = M.getNamedValue(VarName);

    // Size in bytes (bits rounded up); declarations get size 0.
    if (!IsDeclaration)
      VarSize = divideCeil(
          M.getDataLayout().getTypeSizeInBits(LlvmVal->getValueType()), 8);
    else
      VarSize = 0;
    Linkage = (VariableLinkage) ? VariableLinkage() : LlvmVal->getLinkage();

    // This is a workaround carried over from Clang which prevents undesired
    // optimisation of internal variables.
    if (Config.isTargetDevice() &&
        (!IsExternallyVisible || Linkage == GlobalValue::LinkOnceODRLinkage)) {
      // Do not create a "ref-variable" if the original is not also available
      // on the host.
      if (!OffloadInfoManager.hasDeviceGlobalVarEntryInfo(VarName))
        return;

      std::string RefName = createPlatformSpecificName({VarName, "ref"});

      // Create (once) an internal constant global initialized with Addr and
      // hand it back to the caller through GeneratedRefs.
      if (!M.getNamedValue(RefName)) {
        Constant *AddrRef =
            getOrCreateInternalVariable(Addr->getType(), RefName);
        auto *GvAddrRef = cast<GlobalVariable>(AddrRef);
        GvAddrRef->setConstant(true);
        GvAddrRef->setLinkage(GlobalValue::InternalLinkage);
        GvAddrRef->setInitializer(Addr);
        GeneratedRefs.push_back(GvAddrRef);
      }
    }
  } else {
    if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink)
      Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
    else
      Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;

    // Device: keep only the pointer's name; the entry gets no address.
    // Host: (re)obtain the reference pointer through
    // getAddrOfDeclareTargetVar(), which finds it by name if it already
    // exists.
    if (Config.isTargetDevice()) {
      VarName = (Addr) ? Addr->getName() : "";
      Addr = nullptr;
    } else {
      Addr = getAddrOfDeclareTargetVar(
          CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
          EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
          LlvmPtrTy, GlobalInitializer, VariableLinkage);
      VarName = (Addr) ? Addr->getName() : "";
    }
    VarSize = M.getDataLayout().getPointerSize();
    Linkage = GlobalValue::WeakAnyLinkage;
  }

  OffloadInfoManager.registerDeviceGlobalVarEntryInfo(VarName, Addr, VarSize,
                                                      Flags, Linkage);
}
6796 
6797 /// Loads all the offload entries information from the host IR
6798 /// metadata.
6799 void OpenMPIRBuilder::loadOffloadInfoMetadata(Module &M) {
6800   // If we are in target mode, load the metadata from the host IR. This code has
6801   // to match the metadata creation in createOffloadEntriesAndInfoMetadata().
6802 
6803   NamedMDNode *MD = M.getNamedMetadata(ompOffloadInfoName);
6804   if (!MD)
6805     return;
6806 
6807   for (MDNode *MN : MD->operands()) {
6808     auto &&GetMDInt = [MN](unsigned Idx) {
6809       auto *V = cast<ConstantAsMetadata>(MN->getOperand(Idx));
6810       return cast<ConstantInt>(V->getValue())->getZExtValue();
6811     };
6812 
6813     auto &&GetMDString = [MN](unsigned Idx) {
6814       auto *V = cast<MDString>(MN->getOperand(Idx));
6815       return V->getString();
6816     };
6817 
6818     switch (GetMDInt(0)) {
6819     default:
6820       llvm_unreachable("Unexpected metadata!");
6821       break;
6822     case OffloadEntriesInfoManager::OffloadEntryInfo::
6823         OffloadingEntryInfoTargetRegion: {
6824       TargetRegionEntryInfo EntryInfo(/*ParentName=*/GetMDString(3),
6825                                       /*DeviceID=*/GetMDInt(1),
6826                                       /*FileID=*/GetMDInt(2),
6827                                       /*Line=*/GetMDInt(4),
6828                                       /*Count=*/GetMDInt(5));
6829       OffloadInfoManager.initializeTargetRegionEntryInfo(EntryInfo,
6830                                                          /*Order=*/GetMDInt(6));
6831       break;
6832     }
6833     case OffloadEntriesInfoManager::OffloadEntryInfo::
6834         OffloadingEntryInfoDeviceGlobalVar:
6835       OffloadInfoManager.initializeDeviceGlobalVarEntryInfo(
6836           /*MangledName=*/GetMDString(1),
6837           static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
6838               /*Flags=*/GetMDInt(2)),
6839           /*Order=*/GetMDInt(3));
6840       break;
6841     }
6842   }
6843 }
6844 
6845 void OpenMPIRBuilder::loadOffloadInfoMetadata(StringRef HostFilePath) {
6846   if (HostFilePath.empty())
6847     return;
6848 
6849   auto Buf = MemoryBuffer::getFile(HostFilePath);
6850   if (std::error_code Err = Buf.getError()) {
6851     report_fatal_error(("error opening host file from host file path inside of "
6852                         "OpenMPIRBuilder: " +
6853                         Err.message())
6854                            .c_str());
6855   }
6856 
6857   LLVMContext Ctx;
6858   auto M = expectedToErrorOrAndEmitErrors(
6859       Ctx, parseBitcodeFile(Buf.get()->getMemBufferRef(), Ctx));
6860   if (std::error_code Err = M.getError()) {
6861     report_fatal_error(
6862         ("error parsing host file inside of OpenMPIRBuilder: " + Err.message())
6863             .c_str());
6864   }
6865 
6866   loadOffloadInfoMetadata(*M.get());
6867 }
6868 
6869 Function *OpenMPIRBuilder::createRegisterRequires(StringRef Name) {
6870   // Skip the creation of the registration function if this is device codegen
6871   if (Config.isTargetDevice())
6872     return nullptr;
6873 
6874   Builder.ClearInsertionPoint();
6875 
6876   // Create registration function prototype
6877   auto *RegFnTy = FunctionType::get(Builder.getVoidTy(), {});
6878   auto *RegFn = Function::Create(
6879       RegFnTy, GlobalVariable::LinkageTypes::InternalLinkage, Name, M);
6880   RegFn->setSection(".text.startup");
6881   RegFn->addFnAttr(Attribute::NoInline);
6882   RegFn->addFnAttr(Attribute::NoUnwind);
6883 
6884   // Create registration function body
6885   auto *BB = BasicBlock::Create(M.getContext(), "entry", RegFn);
6886   ConstantInt *FlagsVal =
6887       ConstantInt::getSigned(Builder.getInt64Ty(), Config.getRequiresFlags());
6888   Function *RTLRegFn = getOrCreateRuntimeFunctionPtr(
6889       omp::RuntimeFunction::OMPRTL___tgt_register_requires);
6890 
6891   Builder.SetInsertPoint(BB);
6892   Builder.CreateCall(RTLRegFn, {FlagsVal});
6893   Builder.CreateRetVoid();
6894 
6895   return RegFn;
6896 }
6897 
6898 //===----------------------------------------------------------------------===//
6899 // OffloadEntriesInfoManager
6900 //===----------------------------------------------------------------------===//
6901 
6902 bool OffloadEntriesInfoManager::empty() const {
6903   return OffloadEntriesTargetRegion.empty() &&
6904          OffloadEntriesDeviceGlobalVar.empty();
6905 }
6906 
6907 unsigned OffloadEntriesInfoManager::getTargetRegionEntryInfoCount(
6908     const TargetRegionEntryInfo &EntryInfo) const {
6909   auto It = OffloadEntriesTargetRegionCount.find(
6910       getTargetRegionEntryCountKey(EntryInfo));
6911   if (It == OffloadEntriesTargetRegionCount.end())
6912     return 0;
6913   return It->second;
6914 }
6915 
6916 void OffloadEntriesInfoManager::incrementTargetRegionEntryInfoCount(
6917     const TargetRegionEntryInfo &EntryInfo) {
6918   OffloadEntriesTargetRegionCount[getTargetRegionEntryCountKey(EntryInfo)] =
6919       EntryInfo.Count + 1;
6920 }
6921 
6922 /// Initialize target region entry.
6923 void OffloadEntriesInfoManager::initializeTargetRegionEntryInfo(
6924     const TargetRegionEntryInfo &EntryInfo, unsigned Order) {
6925   OffloadEntriesTargetRegion[EntryInfo] =
6926       OffloadEntryInfoTargetRegion(Order, /*Addr=*/nullptr, /*ID=*/nullptr,
6927                                    OMPTargetRegionEntryTargetRegion);
6928   ++OffloadingEntriesNum;
6929 }
6930 
/// Register the target region described by \p EntryInfo with address
/// \p Addr, region ID \p ID and kind \p Flags. On success, bump the count
/// recorded for the entry's location.
///
/// \p EntryInfo must arrive with Count == 0. The count currently recorded for
/// its location is substituted before lookup/insertion.
void OffloadEntriesInfoManager::registerTargetRegionEntryInfo(
    TargetRegionEntryInfo EntryInfo, Constant *Addr, Constant *ID,
    OMPTargetRegionEntryKind Flags) {
  assert(EntryInfo.Count == 0 && "expected default EntryInfo");

  // Update the EntryInfo with the next available count for this location.
  EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);

  // If we are emitting code for a target, the entry is already initialized,
  // only has to be registered.
  if (OMPBuilder->Config.isTargetDevice()) {
    // This could happen if the device compilation is invoked standalone.
    // Note that returning here also skips the count increment at the end.
    if (!hasTargetRegionEntryInfo(EntryInfo)) {
      return;
    }
    auto &Entry = OffloadEntriesTargetRegion[EntryInfo];
    Entry.setAddress(Addr);
    Entry.setID(ID);
    Entry.setFlags(Flags);
  } else {
    // A plain target region already present at this location (whether or
    // not it has an address/ID) is not registered again.
    if (Flags == OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion &&
        hasTargetRegionEntryInfo(EntryInfo, /*IgnoreAddressId*/ true))
      return;
    // NOTE(review): hasTargetRegionEntryInfo() without IgnoreAddressId
    // (presumably defaulting to false) returns false for an entry that
    // already has an address or ID. So this assert does not catch
    // re-registration of such entries — confirm intent.
    assert(!hasTargetRegionEntryInfo(EntryInfo) &&
           "Target region entry already registered!");
    // New host entries are ordered by registration.
    OffloadEntryInfoTargetRegion Entry(OffloadingEntriesNum, Addr, ID, Flags);
    OffloadEntriesTargetRegion[EntryInfo] = Entry;
    ++OffloadingEntriesNum;
  }
  incrementTargetRegionEntryInfoCount(EntryInfo);
}
6962 
6963 bool OffloadEntriesInfoManager::hasTargetRegionEntryInfo(
6964     TargetRegionEntryInfo EntryInfo, bool IgnoreAddressId) const {
6965 
6966   // Update the EntryInfo with the next available count for this location.
6967   EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);
6968 
6969   auto It = OffloadEntriesTargetRegion.find(EntryInfo);
6970   if (It == OffloadEntriesTargetRegion.end()) {
6971     return false;
6972   }
6973   // Fail if this entry is already registered.
6974   if (!IgnoreAddressId && (It->second.getAddress() || It->second.getID()))
6975     return false;
6976   return true;
6977 }
6978 
6979 void OffloadEntriesInfoManager::actOnTargetRegionEntriesInfo(
6980     const OffloadTargetRegionEntryInfoActTy &Action) {
6981   // Scan all target region entries and perform the provided action.
6982   for (const auto &It : OffloadEntriesTargetRegion) {
6983     Action(It.first, It.second);
6984   }
6985 }
6986 
6987 void OffloadEntriesInfoManager::initializeDeviceGlobalVarEntryInfo(
6988     StringRef Name, OMPTargetGlobalVarEntryKind Flags, unsigned Order) {
6989   OffloadEntriesDeviceGlobalVar.try_emplace(Name, Order, Flags);
6990   ++OffloadingEntriesNum;
6991 }
6992 
6993 void OffloadEntriesInfoManager::registerDeviceGlobalVarEntryInfo(
6994     StringRef VarName, Constant *Addr, int64_t VarSize,
6995     OMPTargetGlobalVarEntryKind Flags, GlobalValue::LinkageTypes Linkage) {
6996   if (OMPBuilder->Config.isTargetDevice()) {
6997     // This could happen if the device compilation is invoked standalone.
6998     if (!hasDeviceGlobalVarEntryInfo(VarName))
6999       return;
7000     auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
7001     if (Entry.getAddress() && hasDeviceGlobalVarEntryInfo(VarName)) {
7002       if (Entry.getVarSize() == 0) {
7003         Entry.setVarSize(VarSize);
7004         Entry.setLinkage(Linkage);
7005       }
7006       return;
7007     }
7008     Entry.setVarSize(VarSize);
7009     Entry.setLinkage(Linkage);
7010     Entry.setAddress(Addr);
7011   } else {
7012     if (hasDeviceGlobalVarEntryInfo(VarName)) {
7013       auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
7014       assert(Entry.isValid() && Entry.getFlags() == Flags &&
7015              "Entry not initialized!");
7016       if (Entry.getVarSize() == 0) {
7017         Entry.setVarSize(VarSize);
7018         Entry.setLinkage(Linkage);
7019       }
7020       return;
7021     }
7022     if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
7023       OffloadEntriesDeviceGlobalVar.try_emplace(VarName, OffloadingEntriesNum,
7024                                                 Addr, VarSize, Flags, Linkage,
7025                                                 VarName.str());
7026     else
7027       OffloadEntriesDeviceGlobalVar.try_emplace(
7028           VarName, OffloadingEntriesNum, Addr, VarSize, Flags, Linkage, "");
7029     ++OffloadingEntriesNum;
7030   }
7031 }
7032 
7033 void OffloadEntriesInfoManager::actOnDeviceGlobalVarEntriesInfo(
7034     const OffloadDeviceGlobalVarEntryInfoActTy &Action) {
7035   // Scan all target region entries and perform the provided action.
7036   for (const auto &E : OffloadEntriesDeviceGlobalVar)
7037     Action(E.getKey(), E.getValue());
7038 }
7039 
7040 //===----------------------------------------------------------------------===//
7041 // CanonicalLoopInfo
7042 //===----------------------------------------------------------------------===//
7043 
7044 void CanonicalLoopInfo::collectControlBlocks(
7045     SmallVectorImpl<BasicBlock *> &BBs) {
7046   // We only count those BBs as control block for which we do not need to
7047   // reverse the CFG, i.e. not the loop body which can contain arbitrary control
7048   // flow. For consistency, this also means we do not add the Body block, which
7049   // is just the entry to the body code.
7050   BBs.reserve(BBs.size() + 6);
7051   BBs.append({getPreheader(), Header, Cond, Latch, Exit, getAfter()});
7052 }
7053 
7054 BasicBlock *CanonicalLoopInfo::getPreheader() const {
7055   assert(isValid() && "Requires a valid canonical loop");
7056   for (BasicBlock *Pred : predecessors(Header)) {
7057     if (Pred != Latch)
7058       return Pred;
7059   }
7060   llvm_unreachable("Missing preheader");
7061 }
7062 
7063 void CanonicalLoopInfo::setTripCount(Value *TripCount) {
7064   assert(isValid() && "Requires a valid canonical loop");
7065 
7066   Instruction *CmpI = &getCond()->front();
7067   assert(isa<CmpInst>(CmpI) && "First inst must compare IV with TripCount");
7068   CmpI->setOperand(1, TripCount);
7069 
7070 #ifndef NDEBUG
7071   assertOK();
7072 #endif
7073 }
7074 
7075 void CanonicalLoopInfo::mapIndVar(
7076     llvm::function_ref<Value *(Instruction *)> Updater) {
7077   assert(isValid() && "Requires a valid canonical loop");
7078 
7079   Instruction *OldIV = getIndVar();
7080 
7081   // Record all uses excluding those introduced by the updater. Uses by the
7082   // CanonicalLoopInfo itself to keep track of the number of iterations are
7083   // excluded.
7084   SmallVector<Use *> ReplacableUses;
7085   for (Use &U : OldIV->uses()) {
7086     auto *User = dyn_cast<Instruction>(U.getUser());
7087     if (!User)
7088       continue;
7089     if (User->getParent() == getCond())
7090       continue;
7091     if (User->getParent() == getLatch())
7092       continue;
7093     ReplacableUses.push_back(&U);
7094   }
7095 
7096   // Run the updater that may introduce new uses
7097   Value *NewIV = Updater(OldIV);
7098 
7099   // Replace the old uses with the value returned by the updater.
7100   for (Use *U : ReplacableUses)
7101     U->set(NewIV);
7102 
7103 #ifndef NDEBUG
7104   assertOK();
7105 #endif
7106 }
7107 
7108 void CanonicalLoopInfo::assertOK() const {
7109 #ifndef NDEBUG
7110   // No constraints if this object currently does not describe a loop.
7111   if (!isValid())
7112     return;
7113 
7114   BasicBlock *Preheader = getPreheader();
7115   BasicBlock *Body = getBody();
7116   BasicBlock *After = getAfter();
7117 
7118   // Verify standard control-flow we use for OpenMP loops.
7119   assert(Preheader);
7120   assert(isa<BranchInst>(Preheader->getTerminator()) &&
7121          "Preheader must terminate with unconditional branch");
7122   assert(Preheader->getSingleSuccessor() == Header &&
7123          "Preheader must jump to header");
7124 
7125   assert(Header);
7126   assert(isa<BranchInst>(Header->getTerminator()) &&
7127          "Header must terminate with unconditional branch");
7128   assert(Header->getSingleSuccessor() == Cond &&
7129          "Header must jump to exiting block");
7130 
7131   assert(Cond);
7132   assert(Cond->getSinglePredecessor() == Header &&
7133          "Exiting block only reachable from header");
7134 
7135   assert(isa<BranchInst>(Cond->getTerminator()) &&
7136          "Exiting block must terminate with conditional branch");
7137   assert(size(successors(Cond)) == 2 &&
7138          "Exiting block must have two successors");
7139   assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(0) == Body &&
7140          "Exiting block's first successor jump to the body");
7141   assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(1) == Exit &&
7142          "Exiting block's second successor must exit the loop");
7143 
7144   assert(Body);
7145   assert(Body->getSinglePredecessor() == Cond &&
7146          "Body only reachable from exiting block");
7147   assert(!isa<PHINode>(Body->front()));
7148 
7149   assert(Latch);
7150   assert(isa<BranchInst>(Latch->getTerminator()) &&
7151          "Latch must terminate with unconditional branch");
7152   assert(Latch->getSingleSuccessor() == Header && "Latch must jump to header");
7153   // TODO: To support simple redirecting of the end of the body code that has
7154   // multiple; introduce another auxiliary basic block like preheader and after.
7155   assert(Latch->getSinglePredecessor() != nullptr);
7156   assert(!isa<PHINode>(Latch->front()));
7157 
7158   assert(Exit);
7159   assert(isa<BranchInst>(Exit->getTerminator()) &&
7160          "Exit block must terminate with unconditional branch");
7161   assert(Exit->getSingleSuccessor() == After &&
7162          "Exit block must jump to after block");
7163 
7164   assert(After);
7165   assert(After->getSinglePredecessor() == Exit &&
7166          "After block only reachable from exit block");
7167   assert(After->empty() || !isa<PHINode>(After->front()));
7168 
7169   Instruction *IndVar = getIndVar();
7170   assert(IndVar && "Canonical induction variable not found?");
7171   assert(isa<IntegerType>(IndVar->getType()) &&
7172          "Induction variable must be an integer");
7173   assert(cast<PHINode>(IndVar)->getParent() == Header &&
7174          "Induction variable must be a PHI in the loop header");
7175   assert(cast<PHINode>(IndVar)->getIncomingBlock(0) == Preheader);
7176   assert(
7177       cast<ConstantInt>(cast<PHINode>(IndVar)->getIncomingValue(0))->isZero());
7178   assert(cast<PHINode>(IndVar)->getIncomingBlock(1) == Latch);
7179 
7180   auto *NextIndVar = cast<PHINode>(IndVar)->getIncomingValue(1);
7181   assert(cast<Instruction>(NextIndVar)->getParent() == Latch);
7182   assert(cast<BinaryOperator>(NextIndVar)->getOpcode() == BinaryOperator::Add);
7183   assert(cast<BinaryOperator>(NextIndVar)->getOperand(0) == IndVar);
7184   assert(cast<ConstantInt>(cast<BinaryOperator>(NextIndVar)->getOperand(1))
7185              ->isOne());
7186 
7187   Value *TripCount = getTripCount();
7188   assert(TripCount && "Loop trip count not found?");
7189   assert(IndVar->getType() == TripCount->getType() &&
7190          "Trip count and induction variable must have the same type");
7191 
7192   auto *CmpI = cast<CmpInst>(&Cond->front());
7193   assert(CmpI->getPredicate() == CmpInst::ICMP_ULT &&
7194          "Exit condition must be a signed less-than comparison");
7195   assert(CmpI->getOperand(0) == IndVar &&
7196          "Exit condition must compare the induction variable");
7197   assert(CmpI->getOperand(1) == TripCount &&
7198          "Exit condition must compare with the trip count");
7199 #endif
7200 }
7201 
7202 void CanonicalLoopInfo::invalidate() {
7203   Header = nullptr;
7204   Cond = nullptr;
7205   Latch = nullptr;
7206   Exit = nullptr;
7207 }
7208