xref: /freebsd/contrib/llvm-project/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- OpenMPIRBuilder.cpp - Builder for LLVM-IR for OpenMP directives ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 ///
10 /// This file implements the OpenMPIRBuilder class, which is used as a
11 /// convenient way to create LLVM instructions for OpenMP directives.
12 ///
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
16 #include "llvm/ADT/SmallBitVector.h"
17 #include "llvm/ADT/SmallSet.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/ADT/StringRef.h"
20 #include "llvm/Analysis/AssumptionCache.h"
21 #include "llvm/Analysis/CodeMetrics.h"
22 #include "llvm/Analysis/LoopInfo.h"
23 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
24 #include "llvm/Analysis/ScalarEvolution.h"
25 #include "llvm/Analysis/TargetLibraryInfo.h"
26 #include "llvm/Bitcode/BitcodeReader.h"
27 #include "llvm/Frontend/Offloading/Utility.h"
28 #include "llvm/Frontend/OpenMP/OMPGridValues.h"
29 #include "llvm/IR/Attributes.h"
30 #include "llvm/IR/BasicBlock.h"
31 #include "llvm/IR/CFG.h"
32 #include "llvm/IR/CallingConv.h"
33 #include "llvm/IR/Constant.h"
34 #include "llvm/IR/Constants.h"
35 #include "llvm/IR/DIBuilder.h"
36 #include "llvm/IR/DebugInfoMetadata.h"
37 #include "llvm/IR/DerivedTypes.h"
38 #include "llvm/IR/Function.h"
39 #include "llvm/IR/GlobalVariable.h"
40 #include "llvm/IR/IRBuilder.h"
41 #include "llvm/IR/InstIterator.h"
42 #include "llvm/IR/IntrinsicInst.h"
43 #include "llvm/IR/LLVMContext.h"
44 #include "llvm/IR/MDBuilder.h"
45 #include "llvm/IR/Metadata.h"
46 #include "llvm/IR/PassInstrumentation.h"
47 #include "llvm/IR/PassManager.h"
48 #include "llvm/IR/ReplaceConstant.h"
49 #include "llvm/IR/Value.h"
50 #include "llvm/MC/TargetRegistry.h"
51 #include "llvm/Support/CommandLine.h"
52 #include "llvm/Support/ErrorHandling.h"
53 #include "llvm/Support/FileSystem.h"
54 #include "llvm/Target/TargetMachine.h"
55 #include "llvm/Target/TargetOptions.h"
56 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
57 #include "llvm/Transforms/Utils/Cloning.h"
58 #include "llvm/Transforms/Utils/CodeExtractor.h"
59 #include "llvm/Transforms/Utils/LoopPeel.h"
60 #include "llvm/Transforms/Utils/UnrollLoop.h"
61 
62 #include <cstdint>
63 #include <optional>
64 
65 #define DEBUG_TYPE "openmp-ir-builder"
66 
67 using namespace llvm;
68 using namespace omp;
69 
70 static cl::opt<bool>
71     OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
72                          cl::desc("Use optimistic attributes describing "
73                                   "'as-if' properties of runtime calls."),
74                          cl::init(false));
75 
76 static cl::opt<double> UnrollThresholdFactor(
77     "openmp-ir-builder-unroll-threshold-factor", cl::Hidden,
78     cl::desc("Factor for the unroll threshold to account for code "
79              "simplifications still taking place"),
80     cl::init(1.5));
81 
82 #ifndef NDEBUG
83 /// Return whether IP1 and IP2 are ambiguous, i.e. that inserting instructions
84 /// at position IP1 may change the meaning of IP2 or vice-versa. This is because
85 /// an InsertPoint stores the instruction before something is inserted. For
86 /// instance, if both point to the same instruction, two IRBuilders alternating
87 /// creating instruction will cause the instructions to be interleaved.
isConflictIP(IRBuilder<>::InsertPoint IP1,IRBuilder<>::InsertPoint IP2)88 static bool isConflictIP(IRBuilder<>::InsertPoint IP1,
89                          IRBuilder<>::InsertPoint IP2) {
90   if (!IP1.isSet() || !IP2.isSet())
91     return false;
92   return IP1.getBlock() == IP2.getBlock() && IP1.getPoint() == IP2.getPoint();
93 }
94 
isValidWorkshareLoopScheduleType(OMPScheduleType SchedType)95 static bool isValidWorkshareLoopScheduleType(OMPScheduleType SchedType) {
96   // Valid ordered/unordered and base algorithm combinations.
97   switch (SchedType & ~OMPScheduleType::MonotonicityMask) {
98   case OMPScheduleType::UnorderedStaticChunked:
99   case OMPScheduleType::UnorderedStatic:
100   case OMPScheduleType::UnorderedDynamicChunked:
101   case OMPScheduleType::UnorderedGuidedChunked:
102   case OMPScheduleType::UnorderedRuntime:
103   case OMPScheduleType::UnorderedAuto:
104   case OMPScheduleType::UnorderedTrapezoidal:
105   case OMPScheduleType::UnorderedGreedy:
106   case OMPScheduleType::UnorderedBalanced:
107   case OMPScheduleType::UnorderedGuidedIterativeChunked:
108   case OMPScheduleType::UnorderedGuidedAnalyticalChunked:
109   case OMPScheduleType::UnorderedSteal:
110   case OMPScheduleType::UnorderedStaticBalancedChunked:
111   case OMPScheduleType::UnorderedGuidedSimd:
112   case OMPScheduleType::UnorderedRuntimeSimd:
113   case OMPScheduleType::OrderedStaticChunked:
114   case OMPScheduleType::OrderedStatic:
115   case OMPScheduleType::OrderedDynamicChunked:
116   case OMPScheduleType::OrderedGuidedChunked:
117   case OMPScheduleType::OrderedRuntime:
118   case OMPScheduleType::OrderedAuto:
119   case OMPScheduleType::OrderdTrapezoidal:
120   case OMPScheduleType::NomergeUnorderedStaticChunked:
121   case OMPScheduleType::NomergeUnorderedStatic:
122   case OMPScheduleType::NomergeUnorderedDynamicChunked:
123   case OMPScheduleType::NomergeUnorderedGuidedChunked:
124   case OMPScheduleType::NomergeUnorderedRuntime:
125   case OMPScheduleType::NomergeUnorderedAuto:
126   case OMPScheduleType::NomergeUnorderedTrapezoidal:
127   case OMPScheduleType::NomergeUnorderedGreedy:
128   case OMPScheduleType::NomergeUnorderedBalanced:
129   case OMPScheduleType::NomergeUnorderedGuidedIterativeChunked:
130   case OMPScheduleType::NomergeUnorderedGuidedAnalyticalChunked:
131   case OMPScheduleType::NomergeUnorderedSteal:
132   case OMPScheduleType::NomergeOrderedStaticChunked:
133   case OMPScheduleType::NomergeOrderedStatic:
134   case OMPScheduleType::NomergeOrderedDynamicChunked:
135   case OMPScheduleType::NomergeOrderedGuidedChunked:
136   case OMPScheduleType::NomergeOrderedRuntime:
137   case OMPScheduleType::NomergeOrderedAuto:
138   case OMPScheduleType::NomergeOrderedTrapezoidal:
139     break;
140   default:
141     return false;
142   }
143 
144   // Must not set both monotonicity modifiers at the same time.
145   OMPScheduleType MonotonicityFlags =
146       SchedType & OMPScheduleType::MonotonicityMask;
147   if (MonotonicityFlags == OMPScheduleType::MonotonicityMask)
148     return false;
149 
150   return true;
151 }
152 #endif
153 
getGridValue(const Triple & T,Function * Kernel)154 static const omp::GV &getGridValue(const Triple &T, Function *Kernel) {
155   if (T.isAMDGPU()) {
156     StringRef Features =
157         Kernel->getFnAttribute("target-features").getValueAsString();
158     if (Features.count("+wavefrontsize64"))
159       return omp::getAMDGPUGridValues<64>();
160     return omp::getAMDGPUGridValues<32>();
161   }
162   if (T.isNVPTX())
163     return omp::NVPTXGridValues;
164   if (T.isSPIRV())
165     return omp::SPIRVGridValues;
166   llvm_unreachable("No grid value available for this architecture!");
167 }
168 
169 /// Determine which scheduling algorithm to use, determined from schedule clause
170 /// arguments.
171 static OMPScheduleType
getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind,bool HasChunks,bool HasSimdModifier)172 getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
173                           bool HasSimdModifier) {
174   // Currently, the default schedule it static.
175   switch (ClauseKind) {
176   case OMP_SCHEDULE_Default:
177   case OMP_SCHEDULE_Static:
178     return HasChunks ? OMPScheduleType::BaseStaticChunked
179                      : OMPScheduleType::BaseStatic;
180   case OMP_SCHEDULE_Dynamic:
181     return OMPScheduleType::BaseDynamicChunked;
182   case OMP_SCHEDULE_Guided:
183     return HasSimdModifier ? OMPScheduleType::BaseGuidedSimd
184                            : OMPScheduleType::BaseGuidedChunked;
185   case OMP_SCHEDULE_Auto:
186     return llvm::omp::OMPScheduleType::BaseAuto;
187   case OMP_SCHEDULE_Runtime:
188     return HasSimdModifier ? OMPScheduleType::BaseRuntimeSimd
189                            : OMPScheduleType::BaseRuntime;
190   }
191   llvm_unreachable("unhandled schedule clause argument");
192 }
193 
194 /// Adds ordering modifier flags to schedule type.
195 static OMPScheduleType
getOpenMPOrderingScheduleType(OMPScheduleType BaseScheduleType,bool HasOrderedClause)196 getOpenMPOrderingScheduleType(OMPScheduleType BaseScheduleType,
197                               bool HasOrderedClause) {
198   assert((BaseScheduleType & OMPScheduleType::ModifierMask) ==
199              OMPScheduleType::None &&
200          "Must not have ordering nor monotonicity flags already set");
201 
202   OMPScheduleType OrderingModifier = HasOrderedClause
203                                          ? OMPScheduleType::ModifierOrdered
204                                          : OMPScheduleType::ModifierUnordered;
205   OMPScheduleType OrderingScheduleType = BaseScheduleType | OrderingModifier;
206 
207   // Unsupported combinations
208   if (OrderingScheduleType ==
209       (OMPScheduleType::BaseGuidedSimd | OMPScheduleType::ModifierOrdered))
210     return OMPScheduleType::OrderedGuidedChunked;
211   else if (OrderingScheduleType == (OMPScheduleType::BaseRuntimeSimd |
212                                     OMPScheduleType::ModifierOrdered))
213     return OMPScheduleType::OrderedRuntime;
214 
215   return OrderingScheduleType;
216 }
217 
218 /// Adds monotonicity modifier flags to schedule type.
219 static OMPScheduleType
getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,bool HasSimdModifier,bool HasMonotonic,bool HasNonmonotonic,bool HasOrderedClause)220 getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,
221                                   bool HasSimdModifier, bool HasMonotonic,
222                                   bool HasNonmonotonic, bool HasOrderedClause) {
223   assert((ScheduleType & OMPScheduleType::MonotonicityMask) ==
224              OMPScheduleType::None &&
225          "Must not have monotonicity flags already set");
226   assert((!HasMonotonic || !HasNonmonotonic) &&
227          "Monotonic and Nonmonotonic are contradicting each other");
228 
229   if (HasMonotonic) {
230     return ScheduleType | OMPScheduleType::ModifierMonotonic;
231   } else if (HasNonmonotonic) {
232     return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
233   } else {
234     // OpenMP 5.1, 2.11.4 Worksharing-Loop Construct, Description.
235     // If the static schedule kind is specified or if the ordered clause is
236     // specified, and if the nonmonotonic modifier is not specified, the
237     // effect is as if the monotonic modifier is specified. Otherwise, unless
238     // the monotonic modifier is specified, the effect is as if the
239     // nonmonotonic modifier is specified.
240     OMPScheduleType BaseScheduleType =
241         ScheduleType & ~OMPScheduleType::ModifierMask;
242     if ((BaseScheduleType == OMPScheduleType::BaseStatic) ||
243         (BaseScheduleType == OMPScheduleType::BaseStaticChunked) ||
244         HasOrderedClause) {
245       // The monotonic is used by default in openmp runtime library, so no need
246       // to set it.
247       return ScheduleType;
248     } else {
249       return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
250     }
251   }
252 }
253 
254 /// Determine the schedule type using schedule and ordering clause arguments.
255 static OMPScheduleType
computeOpenMPScheduleType(ScheduleKind ClauseKind,bool HasChunks,bool HasSimdModifier,bool HasMonotonicModifier,bool HasNonmonotonicModifier,bool HasOrderedClause)256 computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
257                           bool HasSimdModifier, bool HasMonotonicModifier,
258                           bool HasNonmonotonicModifier, bool HasOrderedClause) {
259   OMPScheduleType BaseSchedule =
260       getOpenMPBaseScheduleType(ClauseKind, HasChunks, HasSimdModifier);
261   OMPScheduleType OrderedSchedule =
262       getOpenMPOrderingScheduleType(BaseSchedule, HasOrderedClause);
263   OMPScheduleType Result = getOpenMPMonotonicityScheduleType(
264       OrderedSchedule, HasSimdModifier, HasMonotonicModifier,
265       HasNonmonotonicModifier, HasOrderedClause);
266 
267   assert(isValidWorkshareLoopScheduleType(Result));
268   return Result;
269 }
270 
271 /// Make \p Source branch to \p Target.
272 ///
273 /// Handles two situations:
274 /// * \p Source already has an unconditional branch.
275 /// * \p Source is a degenerate block (no terminator because the BB is
276 ///             the current head of the IR construction).
redirectTo(BasicBlock * Source,BasicBlock * Target,DebugLoc DL)277 static void redirectTo(BasicBlock *Source, BasicBlock *Target, DebugLoc DL) {
278   if (Instruction *Term = Source->getTerminator()) {
279     auto *Br = cast<BranchInst>(Term);
280     assert(!Br->isConditional() &&
281            "BB's terminator must be an unconditional branch (or degenerate)");
282     BasicBlock *Succ = Br->getSuccessor(0);
283     Succ->removePredecessor(Source, /*KeepOneInputPHIs=*/true);
284     Br->setSuccessor(0, Target);
285     return;
286   }
287 
288   auto *NewBr = BranchInst::Create(Target, Source);
289   NewBr->setDebugLoc(DL);
290 }
291 
spliceBB(IRBuilderBase::InsertPoint IP,BasicBlock * New,bool CreateBranch,DebugLoc DL)292 void llvm::spliceBB(IRBuilderBase::InsertPoint IP, BasicBlock *New,
293                     bool CreateBranch, DebugLoc DL) {
294   assert(New->getFirstInsertionPt() == New->begin() &&
295          "Target BB must not have PHI nodes");
296 
297   // Move instructions to new block.
298   BasicBlock *Old = IP.getBlock();
299   New->splice(New->begin(), Old, IP.getPoint(), Old->end());
300 
301   if (CreateBranch) {
302     auto *NewBr = BranchInst::Create(New, Old);
303     NewBr->setDebugLoc(DL);
304   }
305 }
306 
spliceBB(IRBuilder<> & Builder,BasicBlock * New,bool CreateBranch)307 void llvm::spliceBB(IRBuilder<> &Builder, BasicBlock *New, bool CreateBranch) {
308   DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
309   BasicBlock *Old = Builder.GetInsertBlock();
310 
311   spliceBB(Builder.saveIP(), New, CreateBranch, DebugLoc);
312   if (CreateBranch)
313     Builder.SetInsertPoint(Old->getTerminator());
314   else
315     Builder.SetInsertPoint(Old);
316 
317   // SetInsertPoint also updates the Builder's debug location, but we want to
318   // keep the one the Builder was configured to use.
319   Builder.SetCurrentDebugLocation(DebugLoc);
320 }
321 
splitBB(IRBuilderBase::InsertPoint IP,bool CreateBranch,DebugLoc DL,llvm::Twine Name)322 BasicBlock *llvm::splitBB(IRBuilderBase::InsertPoint IP, bool CreateBranch,
323                           DebugLoc DL, llvm::Twine Name) {
324   BasicBlock *Old = IP.getBlock();
325   BasicBlock *New = BasicBlock::Create(
326       Old->getContext(), Name.isTriviallyEmpty() ? Old->getName() : Name,
327       Old->getParent(), Old->getNextNode());
328   spliceBB(IP, New, CreateBranch, DL);
329   New->replaceSuccessorsPhiUsesWith(Old, New);
330   return New;
331 }
332 
splitBB(IRBuilderBase & Builder,bool CreateBranch,llvm::Twine Name)333 BasicBlock *llvm::splitBB(IRBuilderBase &Builder, bool CreateBranch,
334                           llvm::Twine Name) {
335   DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
336   BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, DebugLoc, Name);
337   if (CreateBranch)
338     Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
339   else
340     Builder.SetInsertPoint(Builder.GetInsertBlock());
341   // SetInsertPoint also updates the Builder's debug location, but we want to
342   // keep the one the Builder was configured to use.
343   Builder.SetCurrentDebugLocation(DebugLoc);
344   return New;
345 }
346 
splitBB(IRBuilder<> & Builder,bool CreateBranch,llvm::Twine Name)347 BasicBlock *llvm::splitBB(IRBuilder<> &Builder, bool CreateBranch,
348                           llvm::Twine Name) {
349   DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
350   BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, DebugLoc, Name);
351   if (CreateBranch)
352     Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
353   else
354     Builder.SetInsertPoint(Builder.GetInsertBlock());
355   // SetInsertPoint also updates the Builder's debug location, but we want to
356   // keep the one the Builder was configured to use.
357   Builder.SetCurrentDebugLocation(DebugLoc);
358   return New;
359 }
360 
splitBBWithSuffix(IRBuilderBase & Builder,bool CreateBranch,llvm::Twine Suffix)361 BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch,
362                                     llvm::Twine Suffix) {
363   BasicBlock *Old = Builder.GetInsertBlock();
364   return splitBB(Builder, CreateBranch, Old->getName() + Suffix);
365 }
366 
367 // This function creates a fake integer value and a fake use for the integer
368 // value. It returns the fake value created. This is useful in modeling the
369 // extra arguments to the outlined functions.
createFakeIntVal(IRBuilderBase & Builder,OpenMPIRBuilder::InsertPointTy OuterAllocaIP,llvm::SmallVectorImpl<Instruction * > & ToBeDeleted,OpenMPIRBuilder::InsertPointTy InnerAllocaIP,const Twine & Name="",bool AsPtr=true)370 Value *createFakeIntVal(IRBuilderBase &Builder,
371                         OpenMPIRBuilder::InsertPointTy OuterAllocaIP,
372                         llvm::SmallVectorImpl<Instruction *> &ToBeDeleted,
373                         OpenMPIRBuilder::InsertPointTy InnerAllocaIP,
374                         const Twine &Name = "", bool AsPtr = true) {
375   Builder.restoreIP(OuterAllocaIP);
376   Instruction *FakeVal;
377   AllocaInst *FakeValAddr =
378       Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, Name + ".addr");
379   ToBeDeleted.push_back(FakeValAddr);
380 
381   if (AsPtr) {
382     FakeVal = FakeValAddr;
383   } else {
384     FakeVal =
385         Builder.CreateLoad(Builder.getInt32Ty(), FakeValAddr, Name + ".val");
386     ToBeDeleted.push_back(FakeVal);
387   }
388 
389   // Generate a fake use of this value
390   Builder.restoreIP(InnerAllocaIP);
391   Instruction *UseFakeVal;
392   if (AsPtr) {
393     UseFakeVal =
394         Builder.CreateLoad(Builder.getInt32Ty(), FakeVal, Name + ".use");
395   } else {
396     UseFakeVal =
397         cast<BinaryOperator>(Builder.CreateAdd(FakeVal, Builder.getInt32(10)));
398   }
399   ToBeDeleted.push_back(UseFakeVal);
400   return FakeVal;
401 }
402 
403 //===----------------------------------------------------------------------===//
404 // OpenMPIRBuilderConfig
405 //===----------------------------------------------------------------------===//
406 
407 namespace {
408 LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
409 /// Values for bit flags for marking which requires clauses have been used.
410 enum OpenMPOffloadingRequiresDirFlags {
411   /// flag undefined.
412   OMP_REQ_UNDEFINED = 0x000,
413   /// no requires directive present.
414   OMP_REQ_NONE = 0x001,
415   /// reverse_offload clause.
416   OMP_REQ_REVERSE_OFFLOAD = 0x002,
417   /// unified_address clause.
418   OMP_REQ_UNIFIED_ADDRESS = 0x004,
419   /// unified_shared_memory clause.
420   OMP_REQ_UNIFIED_SHARED_MEMORY = 0x008,
421   /// dynamic_allocators clause.
422   OMP_REQ_DYNAMIC_ALLOCATORS = 0x010,
423   LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/OMP_REQ_DYNAMIC_ALLOCATORS)
424 };
425 
426 } // anonymous namespace
427 
OpenMPIRBuilderConfig()428 OpenMPIRBuilderConfig::OpenMPIRBuilderConfig()
429     : RequiresFlags(OMP_REQ_UNDEFINED) {}
430 
OpenMPIRBuilderConfig(bool IsTargetDevice,bool IsGPU,bool OpenMPOffloadMandatory,bool HasRequiresReverseOffload,bool HasRequiresUnifiedAddress,bool HasRequiresUnifiedSharedMemory,bool HasRequiresDynamicAllocators)431 OpenMPIRBuilderConfig::OpenMPIRBuilderConfig(
432     bool IsTargetDevice, bool IsGPU, bool OpenMPOffloadMandatory,
433     bool HasRequiresReverseOffload, bool HasRequiresUnifiedAddress,
434     bool HasRequiresUnifiedSharedMemory, bool HasRequiresDynamicAllocators)
435     : IsTargetDevice(IsTargetDevice), IsGPU(IsGPU),
436       OpenMPOffloadMandatory(OpenMPOffloadMandatory),
437       RequiresFlags(OMP_REQ_UNDEFINED) {
438   if (HasRequiresReverseOffload)
439     RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
440   if (HasRequiresUnifiedAddress)
441     RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
442   if (HasRequiresUnifiedSharedMemory)
443     RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
444   if (HasRequiresDynamicAllocators)
445     RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
446 }
447 
hasRequiresReverseOffload() const448 bool OpenMPIRBuilderConfig::hasRequiresReverseOffload() const {
449   return RequiresFlags & OMP_REQ_REVERSE_OFFLOAD;
450 }
451 
hasRequiresUnifiedAddress() const452 bool OpenMPIRBuilderConfig::hasRequiresUnifiedAddress() const {
453   return RequiresFlags & OMP_REQ_UNIFIED_ADDRESS;
454 }
455 
hasRequiresUnifiedSharedMemory() const456 bool OpenMPIRBuilderConfig::hasRequiresUnifiedSharedMemory() const {
457   return RequiresFlags & OMP_REQ_UNIFIED_SHARED_MEMORY;
458 }
459 
hasRequiresDynamicAllocators() const460 bool OpenMPIRBuilderConfig::hasRequiresDynamicAllocators() const {
461   return RequiresFlags & OMP_REQ_DYNAMIC_ALLOCATORS;
462 }
463 
getRequiresFlags() const464 int64_t OpenMPIRBuilderConfig::getRequiresFlags() const {
465   return hasRequiresFlags() ? RequiresFlags
466                             : static_cast<int64_t>(OMP_REQ_NONE);
467 }
468 
setHasRequiresReverseOffload(bool Value)469 void OpenMPIRBuilderConfig::setHasRequiresReverseOffload(bool Value) {
470   if (Value)
471     RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
472   else
473     RequiresFlags &= ~OMP_REQ_REVERSE_OFFLOAD;
474 }
475 
setHasRequiresUnifiedAddress(bool Value)476 void OpenMPIRBuilderConfig::setHasRequiresUnifiedAddress(bool Value) {
477   if (Value)
478     RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
479   else
480     RequiresFlags &= ~OMP_REQ_UNIFIED_ADDRESS;
481 }
482 
setHasRequiresUnifiedSharedMemory(bool Value)483 void OpenMPIRBuilderConfig::setHasRequiresUnifiedSharedMemory(bool Value) {
484   if (Value)
485     RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
486   else
487     RequiresFlags &= ~OMP_REQ_UNIFIED_SHARED_MEMORY;
488 }
489 
setHasRequiresDynamicAllocators(bool Value)490 void OpenMPIRBuilderConfig::setHasRequiresDynamicAllocators(bool Value) {
491   if (Value)
492     RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
493   else
494     RequiresFlags &= ~OMP_REQ_DYNAMIC_ALLOCATORS;
495 }
496 
497 //===----------------------------------------------------------------------===//
498 // OpenMPIRBuilder
499 //===----------------------------------------------------------------------===//
500 
getKernelArgsVector(TargetKernelArgs & KernelArgs,IRBuilderBase & Builder,SmallVector<Value * > & ArgsVector)501 void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
502                                           IRBuilderBase &Builder,
503                                           SmallVector<Value *> &ArgsVector) {
504   Value *Version = Builder.getInt32(OMP_KERNEL_ARG_VERSION);
505   Value *PointerNum = Builder.getInt32(KernelArgs.NumTargetItems);
506   auto Int32Ty = Type::getInt32Ty(Builder.getContext());
507   constexpr const size_t MaxDim = 3;
508   Value *ZeroArray = Constant::getNullValue(ArrayType::get(Int32Ty, MaxDim));
509   Value *Flags = Builder.getInt64(KernelArgs.HasNoWait);
510 
511   assert(!KernelArgs.NumTeams.empty() && !KernelArgs.NumThreads.empty());
512 
513   Value *NumTeams3D =
514       Builder.CreateInsertValue(ZeroArray, KernelArgs.NumTeams[0], {0});
515   Value *NumThreads3D =
516       Builder.CreateInsertValue(ZeroArray, KernelArgs.NumThreads[0], {0});
517   for (unsigned I :
518        seq<unsigned>(1, std::min(KernelArgs.NumTeams.size(), MaxDim)))
519     NumTeams3D =
520         Builder.CreateInsertValue(NumTeams3D, KernelArgs.NumTeams[I], {I});
521   for (unsigned I :
522        seq<unsigned>(1, std::min(KernelArgs.NumThreads.size(), MaxDim)))
523     NumThreads3D =
524         Builder.CreateInsertValue(NumThreads3D, KernelArgs.NumThreads[I], {I});
525 
526   ArgsVector = {Version,
527                 PointerNum,
528                 KernelArgs.RTArgs.BasePointersArray,
529                 KernelArgs.RTArgs.PointersArray,
530                 KernelArgs.RTArgs.SizesArray,
531                 KernelArgs.RTArgs.MapTypesArray,
532                 KernelArgs.RTArgs.MapNamesArray,
533                 KernelArgs.RTArgs.MappersArray,
534                 KernelArgs.NumIterations,
535                 Flags,
536                 NumTeams3D,
537                 NumThreads3D,
538                 KernelArgs.DynCGGroupMem};
539 }
540 
addAttributes(omp::RuntimeFunction FnID,Function & Fn)541 void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
542   LLVMContext &Ctx = Fn.getContext();
543 
544   // Get the function's current attributes.
545   auto Attrs = Fn.getAttributes();
546   auto FnAttrs = Attrs.getFnAttrs();
547   auto RetAttrs = Attrs.getRetAttrs();
548   SmallVector<AttributeSet, 4> ArgAttrs;
549   for (size_t ArgNo = 0; ArgNo < Fn.arg_size(); ++ArgNo)
550     ArgAttrs.emplace_back(Attrs.getParamAttrs(ArgNo));
551 
552   // Add AS to FnAS while taking special care with integer extensions.
553   auto addAttrSet = [&](AttributeSet &FnAS, const AttributeSet &AS,
554                         bool Param = true) -> void {
555     bool HasSignExt = AS.hasAttribute(Attribute::SExt);
556     bool HasZeroExt = AS.hasAttribute(Attribute::ZExt);
557     if (HasSignExt || HasZeroExt) {
558       assert(AS.getNumAttributes() == 1 &&
559              "Currently not handling extension attr combined with others.");
560       if (Param) {
561         if (auto AK = TargetLibraryInfo::getExtAttrForI32Param(T, HasSignExt))
562           FnAS = FnAS.addAttribute(Ctx, AK);
563       } else if (auto AK =
564                      TargetLibraryInfo::getExtAttrForI32Return(T, HasSignExt))
565         FnAS = FnAS.addAttribute(Ctx, AK);
566     } else {
567       FnAS = FnAS.addAttributes(Ctx, AS);
568     }
569   };
570 
571 #define OMP_ATTRS_SET(VarName, AttrSet) AttributeSet VarName = AttrSet;
572 #include "llvm/Frontend/OpenMP/OMPKinds.def"
573 
574   // Add attributes to the function declaration.
575   switch (FnID) {
576 #define OMP_RTL_ATTRS(Enum, FnAttrSet, RetAttrSet, ArgAttrSets)                \
577   case Enum:                                                                   \
578     FnAttrs = FnAttrs.addAttributes(Ctx, FnAttrSet);                           \
579     addAttrSet(RetAttrs, RetAttrSet, /*Param*/ false);                         \
580     for (size_t ArgNo = 0; ArgNo < ArgAttrSets.size(); ++ArgNo)                \
581       addAttrSet(ArgAttrs[ArgNo], ArgAttrSets[ArgNo]);                         \
582     Fn.setAttributes(AttributeList::get(Ctx, FnAttrs, RetAttrs, ArgAttrs));    \
583     break;
584 #include "llvm/Frontend/OpenMP/OMPKinds.def"
585   default:
586     // Attributes are optional.
587     break;
588   }
589 }
590 
591 FunctionCallee
getOrCreateRuntimeFunction(Module & M,RuntimeFunction FnID)592 OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
593   FunctionType *FnTy = nullptr;
594   Function *Fn = nullptr;
595 
596   // Try to find the declation in the module first.
597   switch (FnID) {
598 #define OMP_RTL(Enum, Str, IsVarArg, ReturnType, ...)                          \
599   case Enum:                                                                   \
600     FnTy = FunctionType::get(ReturnType, ArrayRef<Type *>{__VA_ARGS__},        \
601                              IsVarArg);                                        \
602     Fn = M.getFunction(Str);                                                   \
603     break;
604 #include "llvm/Frontend/OpenMP/OMPKinds.def"
605   }
606 
607   if (!Fn) {
608     // Create a new declaration if we need one.
609     switch (FnID) {
610 #define OMP_RTL(Enum, Str, ...)                                                \
611   case Enum:                                                                   \
612     Fn = Function::Create(FnTy, GlobalValue::ExternalLinkage, Str, M);         \
613     break;
614 #include "llvm/Frontend/OpenMP/OMPKinds.def"
615     }
616 
617     // Add information if the runtime function takes a callback function
618     if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) {
619       if (!Fn->hasMetadata(LLVMContext::MD_callback)) {
620         LLVMContext &Ctx = Fn->getContext();
621         MDBuilder MDB(Ctx);
622         // Annotate the callback behavior of the runtime function:
623         //  - The callback callee is argument number 2 (microtask).
624         //  - The first two arguments of the callback callee are unknown (-1).
625         //  - All variadic arguments to the runtime function are passed to the
626         //    callback callee.
627         Fn->addMetadata(
628             LLVMContext::MD_callback,
629             *MDNode::get(Ctx, {MDB.createCallbackEncoding(
630                                   2, {-1, -1}, /* VarArgsArePassed */ true)}));
631       }
632     }
633 
634     LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()
635                       << " with type " << *Fn->getFunctionType() << "\n");
636     addAttributes(FnID, *Fn);
637 
638   } else {
639     LLVM_DEBUG(dbgs() << "Found OpenMP runtime function " << Fn->getName()
640                       << " with type " << *Fn->getFunctionType() << "\n");
641   }
642 
643   assert(Fn && "Failed to create OpenMP runtime function");
644 
645   return {FnTy, Fn};
646 }
647 
getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID)648 Function *OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID) {
649   FunctionCallee RTLFn = getOrCreateRuntimeFunction(M, FnID);
650   auto *Fn = dyn_cast<llvm::Function>(RTLFn.getCallee());
651   assert(Fn && "Failed to create OpenMP runtime function pointer");
652   return Fn;
653 }
654 
initialize()655 void OpenMPIRBuilder::initialize() { initializeTypes(M); }
656 
raiseUserConstantDataAllocasToEntryBlock(IRBuilderBase & Builder,Function * Function)657 static void raiseUserConstantDataAllocasToEntryBlock(IRBuilderBase &Builder,
658                                                      Function *Function) {
659   BasicBlock &EntryBlock = Function->getEntryBlock();
660   BasicBlock::iterator MoveLocInst = EntryBlock.getFirstNonPHIIt();
661 
662   // Loop over blocks looking for constant allocas, skipping the entry block
663   // as any allocas there are already in the desired location.
664   for (auto Block = std::next(Function->begin(), 1); Block != Function->end();
665        Block++) {
666     for (auto Inst = Block->getReverseIterator()->begin();
667          Inst != Block->getReverseIterator()->end();) {
668       if (auto *AllocaInst = dyn_cast_if_present<llvm::AllocaInst>(Inst)) {
669         Inst++;
670         if (!isa<ConstantData>(AllocaInst->getArraySize()))
671           continue;
672         AllocaInst->moveBeforePreserving(MoveLocInst);
673       } else {
674         Inst++;
675       }
676     }
677   }
678 }
679 
finalize(Function * Fn)680 void OpenMPIRBuilder::finalize(Function *Fn) {
681   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
682   SmallVector<BasicBlock *, 32> Blocks;
683   SmallVector<OutlineInfo, 16> DeferredOutlines;
684   for (OutlineInfo &OI : OutlineInfos) {
685     // Skip functions that have not finalized yet; may happen with nested
686     // function generation.
687     if (Fn && OI.getFunction() != Fn) {
688       DeferredOutlines.push_back(OI);
689       continue;
690     }
691 
692     ParallelRegionBlockSet.clear();
693     Blocks.clear();
694     OI.collectBlocks(ParallelRegionBlockSet, Blocks);
695 
696     Function *OuterFn = OI.getFunction();
697     CodeExtractorAnalysisCache CEAC(*OuterFn);
698     // If we generate code for the target device, we need to allocate
699     // struct for aggregate params in the device default alloca address space.
700     // OpenMP runtime requires that the params of the extracted functions are
701     // passed as zero address space pointers. This flag ensures that
702     // CodeExtractor generates correct code for extracted functions
703     // which are used by OpenMP runtime.
704     bool ArgsInZeroAddressSpace = Config.isTargetDevice();
705     CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
706                             /* AggregateArgs */ true,
707                             /* BlockFrequencyInfo */ nullptr,
708                             /* BranchProbabilityInfo */ nullptr,
709                             /* AssumptionCache */ nullptr,
710                             /* AllowVarArgs */ true,
711                             /* AllowAlloca */ true,
712                             /* AllocaBlock*/ OI.OuterAllocaBB,
713                             /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
714 
715     LLVM_DEBUG(dbgs() << "Before     outlining: " << *OuterFn << "\n");
716     LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
717                       << " Exit: " << OI.ExitBB->getName() << "\n");
718     assert(Extractor.isEligible() &&
719            "Expected OpenMP outlining to be possible!");
720 
721     for (auto *V : OI.ExcludeArgsFromAggregate)
722       Extractor.excludeArgFromAggregate(V);
723 
724     Function *OutlinedFn = Extractor.extractCodeRegion(CEAC);
725 
726     // Forward target-cpu, target-features attributes to the outlined function.
727     auto TargetCpuAttr = OuterFn->getFnAttribute("target-cpu");
728     if (TargetCpuAttr.isStringAttribute())
729       OutlinedFn->addFnAttr(TargetCpuAttr);
730 
731     auto TargetFeaturesAttr = OuterFn->getFnAttribute("target-features");
732     if (TargetFeaturesAttr.isStringAttribute())
733       OutlinedFn->addFnAttr(TargetFeaturesAttr);
734 
735     LLVM_DEBUG(dbgs() << "After      outlining: " << *OuterFn << "\n");
736     LLVM_DEBUG(dbgs() << "   Outlined function: " << *OutlinedFn << "\n");
737     assert(OutlinedFn->getReturnType()->isVoidTy() &&
738            "OpenMP outlined functions should not return a value!");
739 
740     // For compability with the clang CG we move the outlined function after the
741     // one with the parallel region.
742     OutlinedFn->removeFromParent();
743     M.getFunctionList().insertAfter(OuterFn->getIterator(), OutlinedFn);
744 
745     // Remove the artificial entry introduced by the extractor right away, we
746     // made our own entry block after all.
747     {
748       BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
749       assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
750       assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
751       // Move instructions from the to-be-deleted ArtificialEntry to the entry
752       // basic block of the parallel region. CodeExtractor generates
753       // instructions to unwrap the aggregate argument and may sink
754       // allocas/bitcasts for values that are solely used in the outlined region
755       // and do not escape.
756       assert(!ArtificialEntry.empty() &&
757              "Expected instructions to add in the outlined region entry");
758       for (BasicBlock::reverse_iterator It = ArtificialEntry.rbegin(),
759                                         End = ArtificialEntry.rend();
760            It != End;) {
761         Instruction &I = *It;
762         It++;
763 
764         if (I.isTerminator()) {
765           // Absorb any debug value that terminator may have
766           if (OI.EntryBB->getTerminator())
767             OI.EntryBB->getTerminator()->adoptDbgRecords(
768                 &ArtificialEntry, I.getIterator(), false);
769           continue;
770         }
771 
772         I.moveBeforePreserving(*OI.EntryBB, OI.EntryBB->getFirstInsertionPt());
773       }
774 
775       OI.EntryBB->moveBefore(&ArtificialEntry);
776       ArtificialEntry.eraseFromParent();
777     }
778     assert(&OutlinedFn->getEntryBlock() == OI.EntryBB);
779     assert(OutlinedFn && OutlinedFn->hasNUses(1));
780 
781     // Run a user callback, e.g. to add attributes.
782     if (OI.PostOutlineCB)
783       OI.PostOutlineCB(*OutlinedFn);
784   }
785 
786   // Remove work items that have been completed.
787   OutlineInfos = std::move(DeferredOutlines);
788 
789   // The createTarget functions embeds user written code into
790   // the target region which may inject allocas which need to
791   // be moved to the entry block of our target or risk malformed
792   // optimisations by later passes, this is only relevant for
793   // the device pass which appears to be a little more delicate
794   // when it comes to optimisations (however, we do not block on
795   // that here, it's up to the inserter to the list to do so).
796   // This notbaly has to occur after the OutlinedInfo candidates
797   // have been extracted so we have an end product that will not
798   // be implicitly adversely affected by any raises unless
799   // intentionally appended to the list.
800   // NOTE: This only does so for ConstantData, it could be extended
801   // to ConstantExpr's with further effort, however, they should
802   // largely be folded when they get here. Extending it to runtime
803   // defined/read+writeable allocation sizes would be non-trivial
804   // (need to factor in movement of any stores to variables the
805   // allocation size depends on, as well as the usual loads,
806   // otherwise it'll yield the wrong result after movement) and
807   // likely be more suitable as an LLVM optimisation pass.
808   for (Function *F : ConstantAllocaRaiseCandidates)
809     raiseUserConstantDataAllocasToEntryBlock(Builder, F);
810 
811   EmitMetadataErrorReportFunctionTy &&ErrorReportFn =
812       [](EmitMetadataErrorKind Kind,
813          const TargetRegionEntryInfo &EntryInfo) -> void {
814     errs() << "Error of kind: " << Kind
815            << " when emitting offload entries and metadata during "
816               "OMPIRBuilder finalization \n";
817   };
818 
819   if (!OffloadInfoManager.empty())
820     createOffloadEntriesAndInfoMetadata(ErrorReportFn);
821 
822   if (Config.EmitLLVMUsedMetaInfo.value_or(false)) {
823     std::vector<WeakTrackingVH> LLVMCompilerUsed = {
824         M.getGlobalVariable("__openmp_nvptx_data_transfer_temporary_storage")};
825     emitUsed("llvm.compiler.used", LLVMCompilerUsed);
826   }
827 
828   IsFinalized = true;
829 }
830 
isFinalized()831 bool OpenMPIRBuilder::isFinalized() { return IsFinalized; }
832 
~OpenMPIRBuilder()833 OpenMPIRBuilder::~OpenMPIRBuilder() {
834   assert(OutlineInfos.empty() && "There must be no outstanding outlinings");
835 }
836 
createGlobalFlag(unsigned Value,StringRef Name)837 GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
838   IntegerType *I32Ty = Type::getInt32Ty(M.getContext());
839   auto *GV =
840       new GlobalVariable(M, I32Ty,
841                          /* isConstant = */ true, GlobalValue::WeakODRLinkage,
842                          ConstantInt::get(I32Ty, Value), Name);
843   GV->setVisibility(GlobalValue::HiddenVisibility);
844 
845   return GV;
846 }
847 
emitUsed(StringRef Name,ArrayRef<WeakTrackingVH> List)848 void OpenMPIRBuilder::emitUsed(StringRef Name, ArrayRef<WeakTrackingVH> List) {
849   if (List.empty())
850     return;
851 
852   // Convert List to what ConstantArray needs.
853   SmallVector<Constant *, 8> UsedArray;
854   UsedArray.resize(List.size());
855   for (unsigned I = 0, E = List.size(); I != E; ++I)
856     UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
857         cast<Constant>(&*List[I]), Builder.getPtrTy());
858 
859   if (UsedArray.empty())
860     return;
861   ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());
862 
863   auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
864                                 ConstantArray::get(ATy, UsedArray), Name);
865 
866   GV->setSection("llvm.metadata");
867 }
868 
869 GlobalVariable *
emitKernelExecutionMode(StringRef KernelName,OMPTgtExecModeFlags Mode)870 OpenMPIRBuilder::emitKernelExecutionMode(StringRef KernelName,
871                                          OMPTgtExecModeFlags Mode) {
872   auto *Int8Ty = Builder.getInt8Ty();
873   auto *GVMode = new GlobalVariable(
874       M, Int8Ty, /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
875       ConstantInt::get(Int8Ty, Mode), Twine(KernelName, "_exec_mode"));
876   GVMode->setVisibility(GlobalVariable::ProtectedVisibility);
877   return GVMode;
878 }
879 
getOrCreateIdent(Constant * SrcLocStr,uint32_t SrcLocStrSize,IdentFlag LocFlags,unsigned Reserve2Flags)880 Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
881                                             uint32_t SrcLocStrSize,
882                                             IdentFlag LocFlags,
883                                             unsigned Reserve2Flags) {
884   // Enable "C-mode".
885   LocFlags |= OMP_IDENT_FLAG_KMPC;
886 
887   Constant *&Ident =
888       IdentMap[{SrcLocStr, uint64_t(LocFlags) << 31 | Reserve2Flags}];
889   if (!Ident) {
890     Constant *I32Null = ConstantInt::getNullValue(Int32);
891     Constant *IdentData[] = {I32Null,
892                              ConstantInt::get(Int32, uint32_t(LocFlags)),
893                              ConstantInt::get(Int32, Reserve2Flags),
894                              ConstantInt::get(Int32, SrcLocStrSize), SrcLocStr};
895     Constant *Initializer =
896         ConstantStruct::get(OpenMPIRBuilder::Ident, IdentData);
897 
898     // Look for existing encoding of the location + flags, not needed but
899     // minimizes the difference to the existing solution while we transition.
900     for (GlobalVariable &GV : M.globals())
901       if (GV.getValueType() == OpenMPIRBuilder::Ident && GV.hasInitializer())
902         if (GV.getInitializer() == Initializer)
903           Ident = &GV;
904 
905     if (!Ident) {
906       auto *GV = new GlobalVariable(
907           M, OpenMPIRBuilder::Ident,
908           /* isConstant = */ true, GlobalValue::PrivateLinkage, Initializer, "",
909           nullptr, GlobalValue::NotThreadLocal,
910           M.getDataLayout().getDefaultGlobalsAddressSpace());
911       GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
912       GV->setAlignment(Align(8));
913       Ident = GV;
914     }
915   }
916 
917   return ConstantExpr::getPointerBitCastOrAddrSpaceCast(Ident, IdentPtr);
918 }
919 
getOrCreateSrcLocStr(StringRef LocStr,uint32_t & SrcLocStrSize)920 Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef LocStr,
921                                                 uint32_t &SrcLocStrSize) {
922   SrcLocStrSize = LocStr.size();
923   Constant *&SrcLocStr = SrcLocStrMap[LocStr];
924   if (!SrcLocStr) {
925     Constant *Initializer =
926         ConstantDataArray::getString(M.getContext(), LocStr);
927 
928     // Look for existing encoding of the location, not needed but minimizes the
929     // difference to the existing solution while we transition.
930     for (GlobalVariable &GV : M.globals())
931       if (GV.isConstant() && GV.hasInitializer() &&
932           GV.getInitializer() == Initializer)
933         return SrcLocStr = ConstantExpr::getPointerCast(&GV, Int8Ptr);
934 
935     SrcLocStr = Builder.CreateGlobalString(LocStr, /* Name */ "",
936                                            /* AddressSpace */ 0, &M);
937   }
938   return SrcLocStr;
939 }
940 
getOrCreateSrcLocStr(StringRef FunctionName,StringRef FileName,unsigned Line,unsigned Column,uint32_t & SrcLocStrSize)941 Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef FunctionName,
942                                                 StringRef FileName,
943                                                 unsigned Line, unsigned Column,
944                                                 uint32_t &SrcLocStrSize) {
945   SmallString<128> Buffer;
946   Buffer.push_back(';');
947   Buffer.append(FileName);
948   Buffer.push_back(';');
949   Buffer.append(FunctionName);
950   Buffer.push_back(';');
951   Buffer.append(std::to_string(Line));
952   Buffer.push_back(';');
953   Buffer.append(std::to_string(Column));
954   Buffer.push_back(';');
955   Buffer.push_back(';');
956   return getOrCreateSrcLocStr(Buffer.str(), SrcLocStrSize);
957 }
958 
959 Constant *
getOrCreateDefaultSrcLocStr(uint32_t & SrcLocStrSize)960 OpenMPIRBuilder::getOrCreateDefaultSrcLocStr(uint32_t &SrcLocStrSize) {
961   StringRef UnknownLoc = ";unknown;unknown;0;0;;";
962   return getOrCreateSrcLocStr(UnknownLoc, SrcLocStrSize);
963 }
964 
getOrCreateSrcLocStr(DebugLoc DL,uint32_t & SrcLocStrSize,Function * F)965 Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(DebugLoc DL,
966                                                 uint32_t &SrcLocStrSize,
967                                                 Function *F) {
968   DILocation *DIL = DL.get();
969   if (!DIL)
970     return getOrCreateDefaultSrcLocStr(SrcLocStrSize);
971   StringRef FileName = M.getName();
972   if (DIFile *DIF = DIL->getFile())
973     if (std::optional<StringRef> Source = DIF->getSource())
974       FileName = *Source;
975   StringRef Function = DIL->getScope()->getSubprogram()->getName();
976   if (Function.empty() && F)
977     Function = F->getName();
978   return getOrCreateSrcLocStr(Function, FileName, DIL->getLine(),
979                               DIL->getColumn(), SrcLocStrSize);
980 }
981 
getOrCreateSrcLocStr(const LocationDescription & Loc,uint32_t & SrcLocStrSize)982 Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(const LocationDescription &Loc,
983                                                 uint32_t &SrcLocStrSize) {
984   return getOrCreateSrcLocStr(Loc.DL, SrcLocStrSize,
985                               Loc.IP.getBlock()->getParent());
986 }
987 
getOrCreateThreadID(Value * Ident)988 Value *OpenMPIRBuilder::getOrCreateThreadID(Value *Ident) {
989   return Builder.CreateCall(
990       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num), Ident,
991       "omp_global_thread_num");
992 }
993 
994 OpenMPIRBuilder::InsertPointOrErrorTy
createBarrier(const LocationDescription & Loc,Directive Kind,bool ForceSimpleCall,bool CheckCancelFlag)995 OpenMPIRBuilder::createBarrier(const LocationDescription &Loc, Directive Kind,
996                                bool ForceSimpleCall, bool CheckCancelFlag) {
997   if (!updateToLocation(Loc))
998     return Loc.IP;
999 
1000   // Build call __kmpc_cancel_barrier(loc, thread_id) or
1001   //            __kmpc_barrier(loc, thread_id);
1002 
1003   IdentFlag BarrierLocFlags;
1004   switch (Kind) {
1005   case OMPD_for:
1006     BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_FOR;
1007     break;
1008   case OMPD_sections:
1009     BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SECTIONS;
1010     break;
1011   case OMPD_single:
1012     BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SINGLE;
1013     break;
1014   case OMPD_barrier:
1015     BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_EXPL;
1016     break;
1017   default:
1018     BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL;
1019     break;
1020   }
1021 
1022   uint32_t SrcLocStrSize;
1023   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1024   Value *Args[] = {
1025       getOrCreateIdent(SrcLocStr, SrcLocStrSize, BarrierLocFlags),
1026       getOrCreateThreadID(getOrCreateIdent(SrcLocStr, SrcLocStrSize))};
1027 
1028   // If we are in a cancellable parallel region, barriers are cancellation
1029   // points.
1030   // TODO: Check why we would force simple calls or to ignore the cancel flag.
1031   bool UseCancelBarrier =
1032       !ForceSimpleCall && isLastFinalizationInfoCancellable(OMPD_parallel);
1033 
1034   Value *Result =
1035       Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
1036                              UseCancelBarrier ? OMPRTL___kmpc_cancel_barrier
1037                                               : OMPRTL___kmpc_barrier),
1038                          Args);
1039 
1040   if (UseCancelBarrier && CheckCancelFlag)
1041     if (Error Err = emitCancelationCheckImpl(Result, OMPD_parallel))
1042       return Err;
1043 
1044   return Builder.saveIP();
1045 }
1046 
1047 OpenMPIRBuilder::InsertPointOrErrorTy
createCancel(const LocationDescription & Loc,Value * IfCondition,omp::Directive CanceledDirective)1048 OpenMPIRBuilder::createCancel(const LocationDescription &Loc,
1049                               Value *IfCondition,
1050                               omp::Directive CanceledDirective) {
1051   if (!updateToLocation(Loc))
1052     return Loc.IP;
1053 
1054   // LLVM utilities like blocks with terminators.
1055   auto *UI = Builder.CreateUnreachable();
1056 
1057   Instruction *ThenTI = UI, *ElseTI = nullptr;
1058   if (IfCondition)
1059     SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI);
1060   Builder.SetInsertPoint(ThenTI);
1061 
1062   Value *CancelKind = nullptr;
1063   switch (CanceledDirective) {
1064 #define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value)                       \
1065   case DirectiveEnum:                                                          \
1066     CancelKind = Builder.getInt32(Value);                                      \
1067     break;
1068 #include "llvm/Frontend/OpenMP/OMPKinds.def"
1069   default:
1070     llvm_unreachable("Unknown cancel kind!");
1071   }
1072 
1073   uint32_t SrcLocStrSize;
1074   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1075   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1076   Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
1077   Value *Result = Builder.CreateCall(
1078       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_cancel), Args);
1079   auto ExitCB = [this, CanceledDirective, Loc](InsertPointTy IP) -> Error {
1080     if (CanceledDirective == OMPD_parallel) {
1081       IRBuilder<>::InsertPointGuard IPG(Builder);
1082       Builder.restoreIP(IP);
1083       return createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
1084                            omp::Directive::OMPD_unknown,
1085                            /* ForceSimpleCall */ false,
1086                            /* CheckCancelFlag */ false)
1087           .takeError();
1088     }
1089     return Error::success();
1090   };
1091 
1092   // The actual cancel logic is shared with others, e.g., cancel_barriers.
1093   if (Error Err = emitCancelationCheckImpl(Result, CanceledDirective, ExitCB))
1094     return Err;
1095 
1096   // Update the insertion point and remove the terminator we introduced.
1097   Builder.SetInsertPoint(UI->getParent());
1098   UI->eraseFromParent();
1099 
1100   return Builder.saveIP();
1101 }
1102 
1103 OpenMPIRBuilder::InsertPointOrErrorTy
createCancellationPoint(const LocationDescription & Loc,omp::Directive CanceledDirective)1104 OpenMPIRBuilder::createCancellationPoint(const LocationDescription &Loc,
1105                                          omp::Directive CanceledDirective) {
1106   if (!updateToLocation(Loc))
1107     return Loc.IP;
1108 
1109   // LLVM utilities like blocks with terminators.
1110   auto *UI = Builder.CreateUnreachable();
1111   Builder.SetInsertPoint(UI);
1112 
1113   Value *CancelKind = nullptr;
1114   switch (CanceledDirective) {
1115 #define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value)                       \
1116   case DirectiveEnum:                                                          \
1117     CancelKind = Builder.getInt32(Value);                                      \
1118     break;
1119 #include "llvm/Frontend/OpenMP/OMPKinds.def"
1120   default:
1121     llvm_unreachable("Unknown cancel kind!");
1122   }
1123 
1124   uint32_t SrcLocStrSize;
1125   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1126   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1127   Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
1128   Value *Result = Builder.CreateCall(
1129       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_cancellationpoint), Args);
1130   auto ExitCB = [this, CanceledDirective, Loc](InsertPointTy IP) -> Error {
1131     if (CanceledDirective == OMPD_parallel) {
1132       IRBuilder<>::InsertPointGuard IPG(Builder);
1133       Builder.restoreIP(IP);
1134       return createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
1135                            omp::Directive::OMPD_unknown,
1136                            /* ForceSimpleCall */ false,
1137                            /* CheckCancelFlag */ false)
1138           .takeError();
1139     }
1140     return Error::success();
1141   };
1142 
1143   // The actual cancel logic is shared with others, e.g., cancel_barriers.
1144   if (Error Err = emitCancelationCheckImpl(Result, CanceledDirective, ExitCB))
1145     return Err;
1146 
1147   // Update the insertion point and remove the terminator we introduced.
1148   Builder.SetInsertPoint(UI->getParent());
1149   UI->eraseFromParent();
1150 
1151   return Builder.saveIP();
1152 }
1153 
emitTargetKernel(const LocationDescription & Loc,InsertPointTy AllocaIP,Value * & Return,Value * Ident,Value * DeviceID,Value * NumTeams,Value * NumThreads,Value * HostPtr,ArrayRef<Value * > KernelArgs)1154 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
1155     const LocationDescription &Loc, InsertPointTy AllocaIP, Value *&Return,
1156     Value *Ident, Value *DeviceID, Value *NumTeams, Value *NumThreads,
1157     Value *HostPtr, ArrayRef<Value *> KernelArgs) {
1158   if (!updateToLocation(Loc))
1159     return Loc.IP;
1160 
1161   Builder.restoreIP(AllocaIP);
1162   auto *KernelArgsPtr =
1163       Builder.CreateAlloca(OpenMPIRBuilder::KernelArgs, nullptr, "kernel_args");
1164   Builder.restoreIP(Loc.IP);
1165 
1166   for (unsigned I = 0, Size = KernelArgs.size(); I != Size; ++I) {
1167     llvm::Value *Arg =
1168         Builder.CreateStructGEP(OpenMPIRBuilder::KernelArgs, KernelArgsPtr, I);
1169     Builder.CreateAlignedStore(
1170         KernelArgs[I], Arg,
1171         M.getDataLayout().getPrefTypeAlign(KernelArgs[I]->getType()));
1172   }
1173 
1174   SmallVector<Value *> OffloadingArgs{Ident,      DeviceID, NumTeams,
1175                                       NumThreads, HostPtr,  KernelArgsPtr};
1176 
1177   Return = Builder.CreateCall(
1178       getOrCreateRuntimeFunction(M, OMPRTL___tgt_target_kernel),
1179       OffloadingArgs);
1180 
1181   return Builder.saveIP();
1182 }
1183 
emitKernelLaunch(const LocationDescription & Loc,Value * OutlinedFnID,EmitFallbackCallbackTy EmitTargetCallFallbackCB,TargetKernelArgs & Args,Value * DeviceID,Value * RTLoc,InsertPointTy AllocaIP)1184 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitKernelLaunch(
1185     const LocationDescription &Loc, Value *OutlinedFnID,
1186     EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
1187     Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP) {
1188 
1189   if (!updateToLocation(Loc))
1190     return Loc.IP;
1191 
1192   Builder.restoreIP(Loc.IP);
1193   // On top of the arrays that were filled up, the target offloading call
1194   // takes as arguments the device id as well as the host pointer. The host
1195   // pointer is used by the runtime library to identify the current target
1196   // region, so it only has to be unique and not necessarily point to
1197   // anything. It could be the pointer to the outlined function that
1198   // implements the target region, but we aren't using that so that the
1199   // compiler doesn't need to keep that, and could therefore inline the host
1200   // function if proven worthwhile during optimization.
1201 
1202   // From this point on, we need to have an ID of the target region defined.
1203   assert(OutlinedFnID && "Invalid outlined function ID!");
1204   (void)OutlinedFnID;
1205 
1206   // Return value of the runtime offloading call.
1207   Value *Return = nullptr;
1208 
1209   // Arguments for the target kernel.
1210   SmallVector<Value *> ArgsVector;
1211   getKernelArgsVector(Args, Builder, ArgsVector);
1212 
1213   // The target region is an outlined function launched by the runtime
1214   // via calls to __tgt_target_kernel().
1215   //
1216   // Note that on the host and CPU targets, the runtime implementation of
1217   // these calls simply call the outlined function without forking threads.
1218   // The outlined functions themselves have runtime calls to
1219   // __kmpc_fork_teams() and __kmpc_fork() for this purpose, codegen'd by
1220   // the compiler in emitTeamsCall() and emitParallelCall().
1221   //
1222   // In contrast, on the NVPTX target, the implementation of
1223   // __tgt_target_teams() launches a GPU kernel with the requested number
1224   // of teams and threads so no additional calls to the runtime are required.
1225   // Check the error code and execute the host version if required.
1226   Builder.restoreIP(emitTargetKernel(
1227       Builder, AllocaIP, Return, RTLoc, DeviceID, Args.NumTeams.front(),
1228       Args.NumThreads.front(), OutlinedFnID, ArgsVector));
1229 
1230   BasicBlock *OffloadFailedBlock =
1231       BasicBlock::Create(Builder.getContext(), "omp_offload.failed");
1232   BasicBlock *OffloadContBlock =
1233       BasicBlock::Create(Builder.getContext(), "omp_offload.cont");
1234   Value *Failed = Builder.CreateIsNotNull(Return);
1235   Builder.CreateCondBr(Failed, OffloadFailedBlock, OffloadContBlock);
1236 
1237   auto CurFn = Builder.GetInsertBlock()->getParent();
1238   emitBlock(OffloadFailedBlock, CurFn);
1239   InsertPointOrErrorTy AfterIP = EmitTargetCallFallbackCB(Builder.saveIP());
1240   if (!AfterIP)
1241     return AfterIP.takeError();
1242   Builder.restoreIP(*AfterIP);
1243   emitBranch(OffloadContBlock);
1244   emitBlock(OffloadContBlock, CurFn, /*IsFinished=*/true);
1245   return Builder.saveIP();
1246 }
1247 
emitCancelationCheckImpl(Value * CancelFlag,omp::Directive CanceledDirective,FinalizeCallbackTy ExitCB)1248 Error OpenMPIRBuilder::emitCancelationCheckImpl(
1249     Value *CancelFlag, omp::Directive CanceledDirective,
1250     FinalizeCallbackTy ExitCB) {
1251   assert(isLastFinalizationInfoCancellable(CanceledDirective) &&
1252          "Unexpected cancellation!");
1253 
1254   // For a cancel barrier we create two new blocks.
1255   BasicBlock *BB = Builder.GetInsertBlock();
1256   BasicBlock *NonCancellationBlock;
1257   if (Builder.GetInsertPoint() == BB->end()) {
1258     // TODO: This branch will not be needed once we moved to the
1259     // OpenMPIRBuilder codegen completely.
1260     NonCancellationBlock = BasicBlock::Create(
1261         BB->getContext(), BB->getName() + ".cont", BB->getParent());
1262   } else {
1263     NonCancellationBlock = SplitBlock(BB, &*Builder.GetInsertPoint());
1264     BB->getTerminator()->eraseFromParent();
1265     Builder.SetInsertPoint(BB);
1266   }
1267   BasicBlock *CancellationBlock = BasicBlock::Create(
1268       BB->getContext(), BB->getName() + ".cncl", BB->getParent());
1269 
1270   // Jump to them based on the return value.
1271   Value *Cmp = Builder.CreateIsNull(CancelFlag);
1272   Builder.CreateCondBr(Cmp, NonCancellationBlock, CancellationBlock,
1273                        /* TODO weight */ nullptr, nullptr);
1274 
1275   // From the cancellation block we finalize all variables and go to the
1276   // post finalization block that is known to the FiniCB callback.
1277   Builder.SetInsertPoint(CancellationBlock);
1278   if (ExitCB)
1279     if (Error Err = ExitCB(Builder.saveIP()))
1280       return Err;
1281   auto &FI = FinalizationStack.back();
1282   if (Error Err = FI.FiniCB(Builder.saveIP()))
1283     return Err;
1284 
1285   // The continuation block is where code generation continues.
1286   Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
1287   return Error::success();
1288 }
1289 
1290 // Callback used to create OpenMP runtime calls to support
1291 // omp parallel clause for the device.
1292 // We need to use this callback to replace call to the OutlinedFn in OuterFn
1293 // by the call to the OpenMP DeviceRTL runtime function (kmpc_parallel_51)
targetParallelCallback(OpenMPIRBuilder * OMPIRBuilder,Function & OutlinedFn,Function * OuterFn,BasicBlock * OuterAllocaBB,Value * Ident,Value * IfCondition,Value * NumThreads,Instruction * PrivTID,AllocaInst * PrivTIDAddr,Value * ThreadID,const SmallVector<Instruction *,4> & ToBeDeleted)1294 static void targetParallelCallback(
1295     OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn, Function *OuterFn,
1296     BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
1297     Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
1298     Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
1299   // Add some known attributes.
1300   IRBuilder<> &Builder = OMPIRBuilder->Builder;
1301   OutlinedFn.addParamAttr(0, Attribute::NoAlias);
1302   OutlinedFn.addParamAttr(1, Attribute::NoAlias);
1303   OutlinedFn.addParamAttr(0, Attribute::NoUndef);
1304   OutlinedFn.addParamAttr(1, Attribute::NoUndef);
1305   OutlinedFn.addFnAttr(Attribute::NoUnwind);
1306 
1307   assert(OutlinedFn.arg_size() >= 2 &&
1308          "Expected at least tid and bounded tid as arguments");
1309   unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1310 
1311   CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
1312   assert(CI && "Expected call instruction to outlined function");
1313   CI->getParent()->setName("omp_parallel");
1314 
1315   Builder.SetInsertPoint(CI);
1316   Type *PtrTy = OMPIRBuilder->VoidPtr;
1317   Value *NullPtrValue = Constant::getNullValue(PtrTy);
1318 
1319   // Add alloca for kernel args
1320   OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP();
1321   Builder.SetInsertPoint(OuterAllocaBB, OuterAllocaBB->getFirstInsertionPt());
1322   AllocaInst *ArgsAlloca =
1323       Builder.CreateAlloca(ArrayType::get(PtrTy, NumCapturedVars));
1324   Value *Args = ArgsAlloca;
1325   // Add address space cast if array for storing arguments is not allocated
1326   // in address space 0
1327   if (ArgsAlloca->getAddressSpace())
1328     Args = Builder.CreatePointerCast(ArgsAlloca, PtrTy);
1329   Builder.restoreIP(CurrentIP);
1330 
1331   // Store captured vars which are used by kmpc_parallel_51
1332   for (unsigned Idx = 0; Idx < NumCapturedVars; Idx++) {
1333     Value *V = *(CI->arg_begin() + 2 + Idx);
1334     Value *StoreAddress = Builder.CreateConstInBoundsGEP2_64(
1335         ArrayType::get(PtrTy, NumCapturedVars), Args, 0, Idx);
1336     Builder.CreateStore(V, StoreAddress);
1337   }
1338 
1339   Value *Cond =
1340       IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32)
1341                   : Builder.getInt32(1);
1342 
1343   // Build kmpc_parallel_51 call
1344   Value *Parallel51CallArgs[] = {
1345       /* identifier*/ Ident,
1346       /* global thread num*/ ThreadID,
1347       /* if expression */ Cond,
1348       /* number of threads */ NumThreads ? NumThreads : Builder.getInt32(-1),
1349       /* Proc bind */ Builder.getInt32(-1),
1350       /* outlined function */ &OutlinedFn,
1351       /* wrapper function */ NullPtrValue,
1352       /* arguments of the outlined funciton*/ Args,
1353       /* number of arguments */ Builder.getInt64(NumCapturedVars)};
1354 
1355   FunctionCallee RTLFn =
1356       OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_parallel_51);
1357 
1358   Builder.CreateCall(RTLFn, Parallel51CallArgs);
1359 
1360   LLVM_DEBUG(dbgs() << "With kmpc_parallel_51 placed: "
1361                     << *Builder.GetInsertBlock()->getParent() << "\n");
1362 
1363   // Initialize the local TID stack location with the argument value.
1364   Builder.SetInsertPoint(PrivTID);
1365   Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
1366   Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
1367                       PrivTIDAddr);
1368 
1369   // Remove redundant call to the outlined function.
1370   CI->eraseFromParent();
1371 
1372   for (Instruction *I : ToBeDeleted) {
1373     I->eraseFromParent();
1374   }
1375 }
1376 
1377 // Callback used to create OpenMP runtime calls to support
1378 // omp parallel clause for the host.
1379 // We need to use this callback to replace call to the OutlinedFn in OuterFn
1380 // by the call to the OpenMP host runtime function ( __kmpc_fork_call[_if])
1381 static void
hostParallelCallback(OpenMPIRBuilder * OMPIRBuilder,Function & OutlinedFn,Function * OuterFn,Value * Ident,Value * IfCondition,Instruction * PrivTID,AllocaInst * PrivTIDAddr,const SmallVector<Instruction *,4> & ToBeDeleted)1382 hostParallelCallback(OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn,
1383                      Function *OuterFn, Value *Ident, Value *IfCondition,
1384                      Instruction *PrivTID, AllocaInst *PrivTIDAddr,
1385                      const SmallVector<Instruction *, 4> &ToBeDeleted) {
1386   IRBuilder<> &Builder = OMPIRBuilder->Builder;
1387   FunctionCallee RTLFn;
1388   if (IfCondition) {
1389     RTLFn =
1390         OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call_if);
1391   } else {
1392     RTLFn =
1393         OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call);
1394   }
1395   if (auto *F = dyn_cast<Function>(RTLFn.getCallee())) {
1396     if (!F->hasMetadata(LLVMContext::MD_callback)) {
1397       LLVMContext &Ctx = F->getContext();
1398       MDBuilder MDB(Ctx);
1399       // Annotate the callback behavior of the __kmpc_fork_call:
1400       //  - The callback callee is argument number 2 (microtask).
1401       //  - The first two arguments of the callback callee are unknown (-1).
1402       //  - All variadic arguments to the __kmpc_fork_call are passed to the
1403       //    callback callee.
1404       F->addMetadata(LLVMContext::MD_callback,
1405                      *MDNode::get(Ctx, {MDB.createCallbackEncoding(
1406                                            2, {-1, -1},
1407                                            /* VarArgsArePassed */ true)}));
1408     }
1409   }
1410   // Add some known attributes.
1411   OutlinedFn.addParamAttr(0, Attribute::NoAlias);
1412   OutlinedFn.addParamAttr(1, Attribute::NoAlias);
1413   OutlinedFn.addFnAttr(Attribute::NoUnwind);
1414 
1415   assert(OutlinedFn.arg_size() >= 2 &&
1416          "Expected at least tid and bounded tid as arguments");
1417   unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1418 
1419   CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
1420   CI->getParent()->setName("omp_parallel");
1421   Builder.SetInsertPoint(CI);
1422 
1423   // Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
1424   Value *ForkCallArgs[] = {Ident, Builder.getInt32(NumCapturedVars),
1425                            &OutlinedFn};
1426 
1427   SmallVector<Value *, 16> RealArgs;
1428   RealArgs.append(std::begin(ForkCallArgs), std::end(ForkCallArgs));
1429   if (IfCondition) {
1430     Value *Cond = Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32);
1431     RealArgs.push_back(Cond);
1432   }
1433   RealArgs.append(CI->arg_begin() + /* tid & bound tid */ 2, CI->arg_end());
1434 
1435   // __kmpc_fork_call_if always expects a void ptr as the last argument
1436   // If there are no arguments, pass a null pointer.
1437   auto PtrTy = OMPIRBuilder->VoidPtr;
1438   if (IfCondition && NumCapturedVars == 0) {
1439     Value *NullPtrValue = Constant::getNullValue(PtrTy);
1440     RealArgs.push_back(NullPtrValue);
1441   }
1442 
1443   Builder.CreateCall(RTLFn, RealArgs);
1444 
1445   LLVM_DEBUG(dbgs() << "With fork_call placed: "
1446                     << *Builder.GetInsertBlock()->getParent() << "\n");
1447 
1448   // Initialize the local TID stack location with the argument value.
1449   Builder.SetInsertPoint(PrivTID);
1450   Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
1451   Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
1452                       PrivTIDAddr);
1453 
1454   // Remove redundant call to the outlined function.
1455   CI->eraseFromParent();
1456 
1457   for (Instruction *I : ToBeDeleted) {
1458     I->eraseFromParent();
1459   }
1460 }
1461 
// Emit the scaffolding for an OpenMP parallel region at Loc.
//
// The function
//  1. emits the __kmpc_push_num_threads (host only) and __kmpc_push_proc_bind
//     runtime calls where requested,
//  2. splits the current block into entry / region / pre-finalize / exit
//     blocks and lets BodyGenCB fill in the region body,
//  3. runs PrivCB on every value that flows into the region,
//  4. calls FiniCB at the pre-finalize block, and
//  5. registers an OutlineInfo whose PostOutlineCB replaces the call to the
//     outlined function by the host or target-device runtime call.
//
// \param Loc            Location (insertion point) of the construct.
// \param OuterAllocaIP  Insertion point for allocas in the enclosing function;
//                       must not be ambiguous with Loc.IP.
// \param BodyGenCB      Callback that generates the region body.
// \param PrivCB         Callback that privatizes a captured value and reports
//                       the value to use inside the region.
// \param FiniCB         Callback that emits finalization code.
// \param IfCondition    Optional if-clause condition (may be nullptr).
// \param NumThreads     Optional num_threads value (may be nullptr).
// \param ProcBind       proc_bind kind; OMP_PROC_BIND_default emits no call.
// \param IsCancellable  Recorded on the finalization stack entry.
//
// \returns Loc.IP if updateToLocation(Loc) reports false, the error of a
//          failing callback, or otherwise the insertion point at the end of
//          the block that held the temporary unreachable instruction.
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
    const LocationDescription &Loc, InsertPointTy OuterAllocaIP,
    BodyGenCallbackTy BodyGenCB, PrivatizeCallbackTy PrivCB,
    FinalizeCallbackTy FiniCB, Value *IfCondition, Value *NumThreads,
    omp::ProcBindKind ProcBind, bool IsCancellable) {
  assert(!isConflictIP(Loc.IP, OuterAllocaIP) && "IPs must not be ambiguous");

  if (!updateToLocation(Loc))
    return Loc.IP;

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadID = getOrCreateThreadID(Ident);
  // If we generate code for the target device, we need to allocate
  // struct for aggregate params in the device default alloca address space.
  // OpenMP runtime requires that the params of the extracted functions are
  // passed as zero address space pointers. This flag ensures that extracted
  // function arguments are declared in zero address space
  bool ArgsInZeroAddressSpace = Config.isTargetDevice();

  // Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
  // only if we compile for host side.
  if (NumThreads && !Config.isTargetDevice()) {
    Value *Args[] = {
        Ident, ThreadID,
        Builder.CreateIntCast(NumThreads, Int32, /*isSigned*/ false)};
    Builder.CreateCall(
        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_threads), Args);
  }

  if (ProcBind != OMP_PROC_BIND_default) {
    // Build call __kmpc_push_proc_bind(&Ident, global_tid, proc_bind)
    Value *Args[] = {
        Ident, ThreadID,
        ConstantInt::get(Int32, unsigned(ProcBind), /*isSigned=*/true)};
    Builder.CreateCall(
        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_proc_bind), Args);
  }

  BasicBlock *InsertBB = Builder.GetInsertBlock();
  Function *OuterFn = InsertBB->getParent();

  // Save the outer alloca block because the insertion iterator may get
  // invalidated and we still need this later.
  BasicBlock *OuterAllocaBlock = OuterAllocaIP.getBlock();

  // Vector to remember instructions we used only during the modeling but which
  // we want to delete at the end.
  SmallVector<Instruction *, 4> ToBeDeleted;

  // Change the location to the outer alloca insertion point to create and
  // initialize the allocas we pass into the parallel region.
  InsertPointTy NewOuter(OuterAllocaBlock, OuterAllocaBlock->begin());
  Builder.restoreIP(NewOuter);
  AllocaInst *TIDAddrAlloca = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
  AllocaInst *ZeroAddrAlloca =
      Builder.CreateAlloca(Int32, nullptr, "zero.addr");
  Instruction *TIDAddr = TIDAddrAlloca;
  Instruction *ZeroAddr = ZeroAddrAlloca;
  if (ArgsInZeroAddressSpace && M.getDataLayout().getAllocaAddrSpace() != 0) {
    // Add additional casts to enforce pointers in zero address space
    TIDAddr = new AddrSpaceCastInst(
        TIDAddrAlloca, PointerType ::get(M.getContext(), 0), "tid.addr.ascast");
    TIDAddr->insertAfter(TIDAddrAlloca->getIterator());
    ToBeDeleted.push_back(TIDAddr);
    ZeroAddr = new AddrSpaceCastInst(ZeroAddrAlloca,
                                     PointerType ::get(M.getContext(), 0),
                                     "zero.addr.ascast");
    ZeroAddr->insertAfter(ZeroAddrAlloca->getIterator());
    ToBeDeleted.push_back(ZeroAddr);
  }

  // We only need TIDAddr and ZeroAddr for modeling purposes to get the
  // associated arguments in the outlined function, so we delete them later.
  // The casts (if any) were queued before the allocas, so they are erased
  // before the allocas they use.
  ToBeDeleted.push_back(TIDAddrAlloca);
  ToBeDeleted.push_back(ZeroAddrAlloca);

  // Create an artificial insertion point that will also ensure the blocks we
  // are about to split are not degenerated.
  auto *UI = new UnreachableInst(Builder.getContext(), InsertBB);

  // Every split happens at UI, so UI ends up in the last block created
  // (omp.par.exit) and each earlier block ends in a branch to its successor.
  BasicBlock *EntryBB = UI->getParent();
  BasicBlock *PRegEntryBB = EntryBB->splitBasicBlock(UI, "omp.par.entry");
  BasicBlock *PRegBodyBB = PRegEntryBB->splitBasicBlock(UI, "omp.par.region");
  BasicBlock *PRegPreFiniBB =
      PRegBodyBB->splitBasicBlock(UI, "omp.par.pre_finalize");
  BasicBlock *PRegExitBB = PRegPreFiniBB->splitBasicBlock(UI, "omp.par.exit");

  // Wrapper around FiniCB that is pushed on the finalization stack while the
  // body is generated.
  auto FiniCBWrapper = [&](InsertPointTy IP) {
    // Hide "open-ended" blocks from the given FiniCB by setting the right jump
    // target to the region exit block.
    if (IP.getBlock()->end() == IP.getPoint()) {
      IRBuilder<>::InsertPointGuard IPG(Builder);
      Builder.restoreIP(IP);
      Instruction *I = Builder.CreateBr(PRegExitBB);
      IP = InsertPointTy(I->getParent(), I->getIterator());
    }
    assert(IP.getBlock()->getTerminator()->getNumSuccessors() == 1 &&
           IP.getBlock()->getTerminator()->getSuccessor(0) == PRegExitBB &&
           "Unexpected insertion point for finalization call!");
    return FiniCB(IP);
  };

  // NOTE(review): this entry is only popped further below; the early error
  // returns in between leave it on the stack -- confirm callers treat the
  // builder as unusable after an error.
  FinalizationStack.push_back({FiniCBWrapper, OMPD_parallel, IsCancellable});

  // Generate the privatization allocas in the block that will become the entry
  // of the outlined function.
  Builder.SetInsertPoint(PRegEntryBB->getTerminator());
  InsertPointTy InnerAllocaIP = Builder.saveIP();

  AllocaInst *PrivTIDAddr =
      Builder.CreateAlloca(Int32, nullptr, "tid.addr.local");
  Instruction *PrivTID = Builder.CreateLoad(Int32, PrivTIDAddr, "tid");

  // Add some fake uses for OpenMP provided arguments.
  ToBeDeleted.push_back(Builder.CreateLoad(Int32, TIDAddr, "tid.addr.use"));
  Instruction *ZeroAddrUse =
      Builder.CreateLoad(Int32, ZeroAddr, "zero.addr.use");
  ToBeDeleted.push_back(ZeroAddrUse);

  // EntryBB
  //   |
  //   V
  // PRegionEntryBB         <- Privatization allocas are placed here.
  //   |
  //   V
  // PRegionBodyBB          <- BodyGen is invoked here.
  //   |
  //   V
  // PRegPreFiniBB          <- The block we will start finalization from.
  //   |
  //   V
  // PRegionExitBB          <- A common exit to simplify block collection.
  //

  LLVM_DEBUG(dbgs() << "Before body codegen: " << *OuterFn << "\n");

  // Let the caller create the body.
  assert(BodyGenCB && "Expected body generation callback!");
  InsertPointTy CodeGenIP(PRegBodyBB, PRegBodyBB->begin());
  if (Error Err = BodyGenCB(InnerAllocaIP, CodeGenIP))
    return Err;

  LLVM_DEBUG(dbgs() << "After  body codegen: " << *OuterFn << "\n");

  // The callbacks below are stored in OI.PostOutlineCB. ToBeDeleted is moved
  // into the closure, so it must not be used again in this function.
  OutlineInfo OI;
  if (Config.isTargetDevice()) {
    // Generate OpenMP target specific runtime call
    OI.PostOutlineCB = [=, ToBeDeletedVec =
                               std::move(ToBeDeleted)](Function &OutlinedFn) {
      targetParallelCallback(this, OutlinedFn, OuterFn, OuterAllocaBlock, Ident,
                             IfCondition, NumThreads, PrivTID, PrivTIDAddr,
                             ThreadID, ToBeDeletedVec);
    };
  } else {
    // Generate OpenMP host runtime call
    OI.PostOutlineCB = [=, ToBeDeletedVec =
                               std::move(ToBeDeleted)](Function &OutlinedFn) {
      hostParallelCallback(this, OutlinedFn, OuterFn, Ident, IfCondition,
                           PrivTID, PrivTIDAddr, ToBeDeletedVec);
    };
  }

  OI.OuterAllocaBB = OuterAllocaBlock;
  OI.EntryBB = PRegEntryBB;
  OI.ExitBB = PRegExitBB;

  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  OI.collectBlocks(ParallelRegionBlockSet, Blocks);

  // This extractor is only used here to compute the region's inputs and
  // outputs; this function does not call it to perform the extraction.
  CodeExtractorAnalysisCache CEAC(*OuterFn);
  CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
                          /* AggregateArgs */ false,
                          /* BlockFrequencyInfo */ nullptr,
                          /* BranchProbabilityInfo */ nullptr,
                          /* AssumptionCache */ nullptr,
                          /* AllowVarArgs */ true,
                          /* AllowAlloca */ true,
                          /* AllocationBlock */ OuterAllocaBlock,
                          /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);

  // Find inputs to, outputs from the code region.
  BasicBlock *CommonExit = nullptr;
  SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
  Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);

  Extractor.findInputsOutputs(Inputs, Outputs, SinkingCands,
                              /*CollectGlobalInputs=*/true);

  // Global variables whose value type is the ident struct are not treated as
  // captured inputs.
  Inputs.remove_if([&](Value *I) {
    if (auto *GV = dyn_cast_if_present<GlobalVariable>(I))
      return GV->getValueType() == OpenMPIRBuilder::Ident;

    return false;
  });

  LLVM_DEBUG(dbgs() << "Before privatization: " << *OuterFn << "\n");

  FunctionCallee TIDRTLFn =
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num);

  // Privatize a single region input V: rewrite its uses inside the region to
  // the replacement value chosen by PrivCB (or to PrivTID when V is a call to
  // __kmpc_global_thread_num).
  auto PrivHelper = [&](Value &V) -> Error {
    // The modeling-only tid/zero addresses are excluded from argument
    // aggregation and otherwise left untouched.
    if (&V == TIDAddr || &V == ZeroAddr) {
      OI.ExcludeArgsFromAggregate.push_back(&V);
      return Error::success();
    }

    // Collect the uses first; the code emitted below adds new uses of V that
    // must not be rewritten.
    SetVector<Use *> Uses;
    for (Use &U : V.uses())
      if (auto *UserI = dyn_cast<Instruction>(U.getUser()))
        if (ParallelRegionBlockSet.count(UserI->getParent()))
          Uses.insert(&U);

    // __kmpc_fork_call expects extra arguments as pointers. If the input
    // already has a pointer type, everything is fine. Otherwise, store the
    // value onto stack and load it back inside the to-be-outlined region. This
    // will ensure only the pointer will be passed to the function.
    // FIXME: if there are more than 15 trailing arguments, they must be
    // additionally packed in a struct.
    Value *Inner = &V;
    if (!V.getType()->isPointerTy()) {
      IRBuilder<>::InsertPointGuard Guard(Builder);
      LLVM_DEBUG(llvm::dbgs() << "Forwarding input as pointer: " << V << "\n");

      Builder.restoreIP(OuterAllocaIP);
      Value *Ptr =
          Builder.CreateAlloca(V.getType(), nullptr, V.getName() + ".reloaded");

      // Store to stack at end of the block that currently branches to the entry
      // block of the to-be-outlined region.
      Builder.SetInsertPoint(InsertBB,
                             InsertBB->getTerminator()->getIterator());
      Builder.CreateStore(&V, Ptr);

      // Load back next to allocations in the to-be-outlined region.
      Builder.restoreIP(InnerAllocaIP);
      Inner = Builder.CreateLoad(V.getType(), Ptr);
    }

    Value *ReplacementValue = nullptr;
    CallInst *CI = dyn_cast<CallInst>(&V);
    if (CI && CI->getCalledFunction() == TIDRTLFn.getCallee()) {
      // A thread id computed outside the region is replaced by the private
      // thread id of the region.
      ReplacementValue = PrivTID;
    } else {
      InsertPointOrErrorTy AfterIP =
          PrivCB(InnerAllocaIP, Builder.saveIP(), V, *Inner, ReplacementValue);
      if (!AfterIP)
        return AfterIP.takeError();
      Builder.restoreIP(*AfterIP);
      // Re-anchor the inner alloca insertion point at the terminator of its
      // block for the next input.
      InnerAllocaIP = {
          InnerAllocaIP.getBlock(),
          InnerAllocaIP.getBlock()->getTerminator()->getIterator()};

      assert(ReplacementValue &&
             "Expected copy/create callback to set replacement value!");
      // Nothing to rewrite if the callback kept the original value.
      if (ReplacementValue == &V)
        return Error::success();
    }

    for (Use *UPtr : Uses)
      UPtr->set(ReplacementValue);

    return Error::success();
  };

  // Reset the inner alloca insertion as it will be used for loading the values
  // wrapped into pointers before passing them into the to-be-outlined region.
  // Configure it to insert immediately after the fake use of zero address so
  // that they are available in the generated body and so that the
  // OpenMP-related values (thread ID and zero address pointers) remain leading
  // in the argument list.
  InnerAllocaIP = IRBuilder<>::InsertPoint(
      ZeroAddrUse->getParent(), ZeroAddrUse->getNextNode()->getIterator());

  // Reset the outer alloca insertion point to the entry of the relevant block
  // in case it was invalidated.
  OuterAllocaIP = IRBuilder<>::InsertPoint(
      OuterAllocaBlock, OuterAllocaBlock->getFirstInsertionPt());

  for (Value *Input : Inputs) {
    LLVM_DEBUG(dbgs() << "Captured input: " << *Input << "\n");
    if (Error Err = PrivHelper(*Input))
      return Err;
  }
  // NOTE(review): the inner LLVM_DEBUG is nested inside an outer LLVM_DEBUG
  // and looks redundant -- a plain dbgs() stream would presumably do.
  LLVM_DEBUG({
    for (Value *Output : Outputs)
      LLVM_DEBUG(dbgs() << "Captured output: " << *Output << "\n");
  });
  assert(Outputs.empty() &&
         "OpenMP outlining should not produce live-out values!");

  LLVM_DEBUG(dbgs() << "After  privatization: " << *OuterFn << "\n");
  LLVM_DEBUG({
    for (auto *BB : Blocks)
      dbgs() << " PBR: " << BB->getName() << "\n";
  });

  // Adjust the finalization stack, verify the adjustment, and call the
  // finalize function a last time to finalize values between the pre-fini
  // block and the exit block if we left the parallel "the normal way".
  auto FiniInfo = FinalizationStack.pop_back_val();
  (void)FiniInfo;
  assert(FiniInfo.DK == OMPD_parallel &&
         "Unexpected finalization stack state!");

  Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator();

  // Note that FiniCB is called directly here, not through FiniCBWrapper.
  InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
  if (Error Err = FiniCB(PreFiniIP))
    return Err;

  // Register the outlined info.
  addOutlineInfo(std::move(OI));

  // Compute the resulting insertion point (end of the block holding the
  // artificial unreachable) before removing that instruction.
  InsertPointTy AfterIP(UI->getParent(), UI->getParent()->end());
  UI->eraseFromParent();

  return AfterIP;
}
1783 
emitFlush(const LocationDescription & Loc)1784 void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) {
1785   // Build call void __kmpc_flush(ident_t *loc)
1786   uint32_t SrcLocStrSize;
1787   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1788   Value *Args[] = {getOrCreateIdent(SrcLocStr, SrcLocStrSize)};
1789 
1790   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_flush), Args);
1791 }
1792 
createFlush(const LocationDescription & Loc)1793 void OpenMPIRBuilder::createFlush(const LocationDescription &Loc) {
1794   if (!updateToLocation(Loc))
1795     return;
1796   emitFlush(Loc);
1797 }
1798 
emitTaskwaitImpl(const LocationDescription & Loc)1799 void OpenMPIRBuilder::emitTaskwaitImpl(const LocationDescription &Loc) {
1800   // Build call kmp_int32 __kmpc_omp_taskwait(ident_t *loc, kmp_int32
1801   // global_tid);
1802   uint32_t SrcLocStrSize;
1803   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1804   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1805   Value *Args[] = {Ident, getOrCreateThreadID(Ident)};
1806 
1807   // Ignore return result until untied tasks are supported.
1808   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_taskwait),
1809                      Args);
1810 }
1811 
createTaskwait(const LocationDescription & Loc)1812 void OpenMPIRBuilder::createTaskwait(const LocationDescription &Loc) {
1813   if (!updateToLocation(Loc))
1814     return;
1815   emitTaskwaitImpl(Loc);
1816 }
1817 
emitTaskyieldImpl(const LocationDescription & Loc)1818 void OpenMPIRBuilder::emitTaskyieldImpl(const LocationDescription &Loc) {
1819   // Build call __kmpc_omp_taskyield(loc, thread_id, 0);
1820   uint32_t SrcLocStrSize;
1821   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1822   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1823   Constant *I32Null = ConstantInt::getNullValue(Int32);
1824   Value *Args[] = {Ident, getOrCreateThreadID(Ident), I32Null};
1825 
1826   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_taskyield),
1827                      Args);
1828 }
1829 
createTaskyield(const LocationDescription & Loc)1830 void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
1831   if (!updateToLocation(Loc))
1832     return;
1833   emitTaskyieldImpl(Loc);
1834 }
1835 
1836 // Processes the dependencies in Dependencies and does the following
1837 // - Allocates space on the stack of an array of DependInfo objects
1838 // - Populates each DependInfo object with relevant information of
1839 //   the corresponding dependence.
1840 // - All code is inserted in the entry block of the current function.
emitTaskDependencies(OpenMPIRBuilder & OMPBuilder,const SmallVectorImpl<OpenMPIRBuilder::DependData> & Dependencies)1841 static Value *emitTaskDependencies(
1842     OpenMPIRBuilder &OMPBuilder,
1843     const SmallVectorImpl<OpenMPIRBuilder::DependData> &Dependencies) {
1844   // Early return if we have no dependencies to process
1845   if (Dependencies.empty())
1846     return nullptr;
1847 
1848   // Given a vector of DependData objects, in this function we create an
1849   // array on the stack that holds kmp_dep_info objects corresponding
1850   // to each dependency. This is then passed to the OpenMP runtime.
1851   // For example, if there are 'n' dependencies then the following psedo
1852   // code is generated. Assume the first dependence is on a variable 'a'
1853   //
1854   // \code{c}
1855   // DepArray = alloc(n x sizeof(kmp_depend_info);
1856   // idx = 0;
1857   // DepArray[idx].base_addr = ptrtoint(&a);
1858   // DepArray[idx].len = 8;
1859   // DepArray[idx].flags = Dep.DepKind; /*(See OMPContants.h for DepKind)*/
1860   // ++idx;
1861   // DepArray[idx].base_addr = ...;
1862   // \endcode
1863 
1864   IRBuilderBase &Builder = OMPBuilder.Builder;
1865   Type *DependInfo = OMPBuilder.DependInfo;
1866   Module &M = OMPBuilder.M;
1867 
1868   Value *DepArray = nullptr;
1869   OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
1870   Builder.SetInsertPoint(
1871       OldIP.getBlock()->getParent()->getEntryBlock().getTerminator());
1872 
1873   Type *DepArrayTy = ArrayType::get(DependInfo, Dependencies.size());
1874   DepArray = Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");
1875 
1876   Builder.restoreIP(OldIP);
1877 
1878   for (const auto &[DepIdx, Dep] : enumerate(Dependencies)) {
1879     Value *Base =
1880         Builder.CreateConstInBoundsGEP2_64(DepArrayTy, DepArray, 0, DepIdx);
1881     // Store the pointer to the variable
1882     Value *Addr = Builder.CreateStructGEP(
1883         DependInfo, Base,
1884         static_cast<unsigned int>(RTLDependInfoFields::BaseAddr));
1885     Value *DepValPtr = Builder.CreatePtrToInt(Dep.DepVal, Builder.getInt64Ty());
1886     Builder.CreateStore(DepValPtr, Addr);
1887     // Store the size of the variable
1888     Value *Size = Builder.CreateStructGEP(
1889         DependInfo, Base, static_cast<unsigned int>(RTLDependInfoFields::Len));
1890     Builder.CreateStore(
1891         Builder.getInt64(M.getDataLayout().getTypeStoreSize(Dep.DepValueType)),
1892         Size);
1893     // Store the dependency kind
1894     Value *Flags = Builder.CreateStructGEP(
1895         DependInfo, Base,
1896         static_cast<unsigned int>(RTLDependInfoFields::Flags));
1897     Builder.CreateStore(
1898         ConstantInt::get(Builder.getInt8Ty(),
1899                          static_cast<unsigned int>(Dep.DepKind)),
1900         Flags);
1901   }
1902   return DepArray;
1903 }
1904 
createTask(const LocationDescription & Loc,InsertPointTy AllocaIP,BodyGenCallbackTy BodyGenCB,bool Tied,Value * Final,Value * IfCondition,SmallVector<DependData> Dependencies,bool Mergeable,Value * EventHandle,Value * Priority)1905 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
1906     const LocationDescription &Loc, InsertPointTy AllocaIP,
1907     BodyGenCallbackTy BodyGenCB, bool Tied, Value *Final, Value *IfCondition,
1908     SmallVector<DependData> Dependencies, bool Mergeable, Value *EventHandle,
1909     Value *Priority) {
1910 
1911   if (!updateToLocation(Loc))
1912     return InsertPointTy();
1913 
1914   uint32_t SrcLocStrSize;
1915   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1916   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1917   // The current basic block is split into four basic blocks. After outlining,
1918   // they will be mapped as follows:
1919   // ```
1920   // def current_fn() {
1921   //   current_basic_block:
1922   //     br label %task.exit
1923   //   task.exit:
1924   //     ; instructions after task
1925   // }
1926   // def outlined_fn() {
1927   //   task.alloca:
1928   //     br label %task.body
1929   //   task.body:
1930   //     ret void
1931   // }
1932   // ```
1933   BasicBlock *TaskExitBB = splitBB(Builder, /*CreateBranch=*/true, "task.exit");
1934   BasicBlock *TaskBodyBB = splitBB(Builder, /*CreateBranch=*/true, "task.body");
1935   BasicBlock *TaskAllocaBB =
1936       splitBB(Builder, /*CreateBranch=*/true, "task.alloca");
1937 
1938   InsertPointTy TaskAllocaIP =
1939       InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
1940   InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
1941   if (Error Err = BodyGenCB(TaskAllocaIP, TaskBodyIP))
1942     return Err;
1943 
1944   OutlineInfo OI;
1945   OI.EntryBB = TaskAllocaBB;
1946   OI.OuterAllocaBB = AllocaIP.getBlock();
1947   OI.ExitBB = TaskExitBB;
1948 
1949   // Add the thread ID argument.
1950   SmallVector<Instruction *, 4> ToBeDeleted;
1951   OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
1952       Builder, AllocaIP, ToBeDeleted, TaskAllocaIP, "global.tid", false));
1953 
1954   OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
1955                       Mergeable, Priority, EventHandle, TaskAllocaBB,
1956                       ToBeDeleted](Function &OutlinedFn) mutable {
1957     // Replace the Stale CI by appropriate RTL function call.
1958     assert(OutlinedFn.hasOneUse() &&
1959            "there must be a single user for the outlined function");
1960     CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
1961 
1962     // HasShareds is true if any variables are captured in the outlined region,
1963     // false otherwise.
1964     bool HasShareds = StaleCI->arg_size() > 1;
1965     Builder.SetInsertPoint(StaleCI);
1966 
1967     // Gather the arguments for emitting the runtime call for
1968     // @__kmpc_omp_task_alloc
1969     Function *TaskAllocFn =
1970         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);
1971 
1972     // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID)
1973     // call.
1974     Value *ThreadID = getOrCreateThreadID(Ident);
1975 
1976     // Argument - `flags`
1977     // Task is tied iff (Flags & 1) == 1.
1978     // Task is untied iff (Flags & 1) == 0.
1979     // Task is final iff (Flags & 2) == 2.
1980     // Task is not final iff (Flags & 2) == 0.
1981     // Task is mergeable iff (Flags & 4) == 4.
1982     // Task is not mergeable iff (Flags & 4) == 0.
1983     // Task is priority iff (Flags & 32) == 32.
1984     // Task is not priority iff (Flags & 32) == 0.
1985     // TODO: Handle the other flags.
1986     Value *Flags = Builder.getInt32(Tied);
1987     if (Final) {
1988       Value *FinalFlag =
1989           Builder.CreateSelect(Final, Builder.getInt32(2), Builder.getInt32(0));
1990       Flags = Builder.CreateOr(FinalFlag, Flags);
1991     }
1992 
1993     if (Mergeable)
1994       Flags = Builder.CreateOr(Builder.getInt32(4), Flags);
1995     if (Priority)
1996       Flags = Builder.CreateOr(Builder.getInt32(32), Flags);
1997 
1998     // Argument - `sizeof_kmp_task_t` (TaskSize)
1999     // Tasksize refers to the size in bytes of kmp_task_t data structure
2000     // including private vars accessed in task.
2001     // TODO: add kmp_task_t_with_privates (privates)
2002     Value *TaskSize = Builder.getInt64(
2003         divideCeil(M.getDataLayout().getTypeSizeInBits(Task), 8));
2004 
2005     // Argument - `sizeof_shareds` (SharedsSize)
2006     // SharedsSize refers to the shareds array size in the kmp_task_t data
2007     // structure.
2008     Value *SharedsSize = Builder.getInt64(0);
2009     if (HasShareds) {
2010       AllocaInst *ArgStructAlloca =
2011           dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
2012       assert(ArgStructAlloca &&
2013              "Unable to find the alloca instruction corresponding to arguments "
2014              "for extracted function");
2015       StructType *ArgStructType =
2016           dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
2017       assert(ArgStructType && "Unable to find struct type corresponding to "
2018                               "arguments for extracted function");
2019       SharedsSize =
2020           Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
2021     }
2022     // Emit the @__kmpc_omp_task_alloc runtime call
2023     // The runtime call returns a pointer to an area where the task captured
2024     // variables must be copied before the task is run (TaskData)
2025     CallInst *TaskData = Builder.CreateCall(
2026         TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
2027                       /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
2028                       /*task_func=*/&OutlinedFn});
2029 
2030     // Emit detach clause initialization.
2031     // evt = (typeof(evt))__kmpc_task_allow_completion_event(loc, tid,
2032     // task_descriptor);
2033     if (EventHandle) {
2034       Function *TaskDetachFn = getOrCreateRuntimeFunctionPtr(
2035           OMPRTL___kmpc_task_allow_completion_event);
2036       llvm::Value *EventVal =
2037           Builder.CreateCall(TaskDetachFn, {Ident, ThreadID, TaskData});
2038       llvm::Value *EventHandleAddr =
2039           Builder.CreatePointerBitCastOrAddrSpaceCast(EventHandle,
2040                                                       Builder.getPtrTy(0));
2041       EventVal = Builder.CreatePtrToInt(EventVal, Builder.getInt64Ty());
2042       Builder.CreateStore(EventVal, EventHandleAddr);
2043     }
2044     // Copy the arguments for outlined function
2045     if (HasShareds) {
2046       Value *Shareds = StaleCI->getArgOperand(1);
2047       Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
2048       Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData);
2049       Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
2050                            SharedsSize);
2051     }
2052 
2053     if (Priority) {
2054       //
2055       // The return type of "__kmpc_omp_task_alloc" is "kmp_task_t *",
2056       // we populate the priority information into the "kmp_task_t" here
2057       //
2058       // The struct "kmp_task_t" definition is available in kmp.h
2059       // kmp_task_t = { shareds, routine, part_id, data1, data2 }
2060       // data2 is used for priority
2061       //
2062       Type *Int32Ty = Builder.getInt32Ty();
2063       Constant *Zero = ConstantInt::get(Int32Ty, 0);
2064       // kmp_task_t* => { ptr }
2065       Type *TaskPtr = StructType::get(VoidPtr);
2066       Value *TaskGEP =
2067           Builder.CreateInBoundsGEP(TaskPtr, TaskData, {Zero, Zero});
2068       // kmp_task_t => { ptr, ptr, i32, ptr, ptr }
2069       Type *TaskStructType = StructType::get(
2070           VoidPtr, VoidPtr, Builder.getInt32Ty(), VoidPtr, VoidPtr);
2071       Value *PriorityData = Builder.CreateInBoundsGEP(
2072           TaskStructType, TaskGEP, {Zero, ConstantInt::get(Int32Ty, 4)});
2073       // kmp_cmplrdata_t => { ptr, ptr }
2074       Type *CmplrStructType = StructType::get(VoidPtr, VoidPtr);
2075       Value *CmplrData = Builder.CreateInBoundsGEP(CmplrStructType,
2076                                                    PriorityData, {Zero, Zero});
2077       Builder.CreateStore(Priority, CmplrData);
2078     }
2079 
2080     Value *DepArray = emitTaskDependencies(*this, Dependencies);
2081 
2082     // In the presence of the `if` clause, the following IR is generated:
2083     //    ...
2084     //    %data = call @__kmpc_omp_task_alloc(...)
2085     //    br i1 %if_condition, label %then, label %else
2086     //  then:
2087     //    call @__kmpc_omp_task(...)
2088     //    br label %exit
2089     //  else:
2090     //    ;; Wait for resolution of dependencies, if any, before
2091     //    ;; beginning the task
2092     //    call @__kmpc_omp_wait_deps(...)
2093     //    call @__kmpc_omp_task_begin_if0(...)
2094     //    call @outlined_fn(...)
2095     //    call @__kmpc_omp_task_complete_if0(...)
2096     //    br label %exit
2097     //  exit:
2098     //    ...
2099     if (IfCondition) {
2100       // `SplitBlockAndInsertIfThenElse` requires the block to have a
2101       // terminator.
2102       splitBB(Builder, /*CreateBranch=*/true, "if.end");
2103       Instruction *IfTerminator =
2104           Builder.GetInsertPoint()->getParent()->getTerminator();
2105       Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
2106       Builder.SetInsertPoint(IfTerminator);
2107       SplitBlockAndInsertIfThenElse(IfCondition, IfTerminator, &ThenTI,
2108                                     &ElseTI);
2109       Builder.SetInsertPoint(ElseTI);
2110 
2111       if (Dependencies.size()) {
2112         Function *TaskWaitFn =
2113             getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_wait_deps);
2114         Builder.CreateCall(
2115             TaskWaitFn,
2116             {Ident, ThreadID, Builder.getInt32(Dependencies.size()), DepArray,
2117              ConstantInt::get(Builder.getInt32Ty(), 0),
2118              ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
2119       }
2120       Function *TaskBeginFn =
2121           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0);
2122       Function *TaskCompleteFn =
2123           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
2124       Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
2125       CallInst *CI = nullptr;
2126       if (HasShareds)
2127         CI = Builder.CreateCall(&OutlinedFn, {ThreadID, TaskData});
2128       else
2129         CI = Builder.CreateCall(&OutlinedFn, {ThreadID});
2130       CI->setDebugLoc(StaleCI->getDebugLoc());
2131       Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
2132       Builder.SetInsertPoint(ThenTI);
2133     }
2134 
2135     if (Dependencies.size()) {
2136       Function *TaskFn =
2137           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
2138       Builder.CreateCall(
2139           TaskFn,
2140           {Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()),
2141            DepArray, ConstantInt::get(Builder.getInt32Ty(), 0),
2142            ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
2143 
2144     } else {
2145       // Emit the @__kmpc_omp_task runtime call to spawn the task
2146       Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
2147       Builder.CreateCall(TaskFn, {Ident, ThreadID, TaskData});
2148     }
2149 
2150     StaleCI->eraseFromParent();
2151 
2152     Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin());
2153     if (HasShareds) {
2154       LoadInst *Shareds = Builder.CreateLoad(VoidPtr, OutlinedFn.getArg(1));
2155       OutlinedFn.getArg(1)->replaceUsesWithIf(
2156           Shareds, [Shareds](Use &U) { return U.getUser() != Shareds; });
2157     }
2158 
2159     for (Instruction *I : llvm::reverse(ToBeDeleted))
2160       I->eraseFromParent();
2161   };
2162 
2163   addOutlineInfo(std::move(OI));
2164   Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin());
2165 
2166   return Builder.saveIP();
2167 }
2168 
2169 OpenMPIRBuilder::InsertPointOrErrorTy
createTaskgroup(const LocationDescription & Loc,InsertPointTy AllocaIP,BodyGenCallbackTy BodyGenCB)2170 OpenMPIRBuilder::createTaskgroup(const LocationDescription &Loc,
2171                                  InsertPointTy AllocaIP,
2172                                  BodyGenCallbackTy BodyGenCB) {
2173   if (!updateToLocation(Loc))
2174     return InsertPointTy();
2175 
2176   uint32_t SrcLocStrSize;
2177   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2178   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2179   Value *ThreadID = getOrCreateThreadID(Ident);
2180 
2181   // Emit the @__kmpc_taskgroup runtime call to start the taskgroup
2182   Function *TaskgroupFn =
2183       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_taskgroup);
2184   Builder.CreateCall(TaskgroupFn, {Ident, ThreadID});
2185 
2186   BasicBlock *TaskgroupExitBB = splitBB(Builder, true, "taskgroup.exit");
2187   if (Error Err = BodyGenCB(AllocaIP, Builder.saveIP()))
2188     return Err;
2189 
2190   Builder.SetInsertPoint(TaskgroupExitBB);
2191   // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup
2192   Function *EndTaskgroupFn =
2193       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_taskgroup);
2194   Builder.CreateCall(EndTaskgroupFn, {Ident, ThreadID});
2195 
2196   return Builder.saveIP();
2197 }
2198 
createSections(const LocationDescription & Loc,InsertPointTy AllocaIP,ArrayRef<StorableBodyGenCallbackTy> SectionCBs,PrivatizeCallbackTy PrivCB,FinalizeCallbackTy FiniCB,bool IsCancellable,bool IsNowait)2199 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createSections(
2200     const LocationDescription &Loc, InsertPointTy AllocaIP,
2201     ArrayRef<StorableBodyGenCallbackTy> SectionCBs, PrivatizeCallbackTy PrivCB,
2202     FinalizeCallbackTy FiniCB, bool IsCancellable, bool IsNowait) {
2203   assert(!isConflictIP(AllocaIP, Loc.IP) && "Dedicated IP allocas required");
2204 
2205   if (!updateToLocation(Loc))
2206     return Loc.IP;
2207 
2208   // FiniCBWrapper needs to create a branch to the loop finalization block, but
2209   // this has not been created yet at some times when this callback runs.
2210   SmallVector<BranchInst *> CancellationBranches;
2211   auto FiniCBWrapper = [&](InsertPointTy IP) {
2212     if (IP.getBlock()->end() != IP.getPoint())
2213       return FiniCB(IP);
2214     // This must be done otherwise any nested constructs using FinalizeOMPRegion
2215     // will fail because that function requires the Finalization Basic Block to
2216     // have a terminator, which is already removed by EmitOMPRegionBody.
2217     // IP is currently at cancelation block.
2218     BranchInst *DummyBranch = Builder.CreateBr(IP.getBlock());
2219     IP = InsertPointTy(DummyBranch->getParent(), DummyBranch->getIterator());
2220     CancellationBranches.push_back(DummyBranch);
2221     return FiniCB(IP);
2222   };
2223 
2224   FinalizationStack.push_back({FiniCBWrapper, OMPD_sections, IsCancellable});
2225 
2226   // Each section is emitted as a switch case
2227   // Each finalization callback is handled from clang.EmitOMPSectionDirective()
2228   // -> OMP.createSection() which generates the IR for each section
2229   // Iterate through all sections and emit a switch construct:
2230   // switch (IV) {
2231   //   case 0:
2232   //     <SectionStmt[0]>;
2233   //     break;
2234   // ...
2235   //   case <NumSection> - 1:
2236   //     <SectionStmt[<NumSection> - 1]>;
2237   //     break;
2238   // }
2239   // ...
2240   // section_loop.after:
2241   // <FiniCB>;
2242   auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, Value *IndVar) -> Error {
2243     Builder.restoreIP(CodeGenIP);
2244     BasicBlock *Continue =
2245         splitBBWithSuffix(Builder, /*CreateBranch=*/false, ".sections.after");
2246     Function *CurFn = Continue->getParent();
2247     SwitchInst *SwitchStmt = Builder.CreateSwitch(IndVar, Continue);
2248 
2249     unsigned CaseNumber = 0;
2250     for (auto SectionCB : SectionCBs) {
2251       BasicBlock *CaseBB = BasicBlock::Create(
2252           M.getContext(), "omp_section_loop.body.case", CurFn, Continue);
2253       SwitchStmt->addCase(Builder.getInt32(CaseNumber), CaseBB);
2254       Builder.SetInsertPoint(CaseBB);
2255       BranchInst *CaseEndBr = Builder.CreateBr(Continue);
2256       if (Error Err = SectionCB(InsertPointTy(), {CaseEndBr->getParent(),
2257                                                   CaseEndBr->getIterator()}))
2258         return Err;
2259       CaseNumber++;
2260     }
2261     // remove the existing terminator from body BB since there can be no
2262     // terminators after switch/case
2263     return Error::success();
2264   };
2265   // Loop body ends here
2266   // LowerBound, UpperBound, and STride for createCanonicalLoop
2267   Type *I32Ty = Type::getInt32Ty(M.getContext());
2268   Value *LB = ConstantInt::get(I32Ty, 0);
2269   Value *UB = ConstantInt::get(I32Ty, SectionCBs.size());
2270   Value *ST = ConstantInt::get(I32Ty, 1);
2271   Expected<CanonicalLoopInfo *> LoopInfo = createCanonicalLoop(
2272       Loc, LoopBodyGenCB, LB, UB, ST, true, false, AllocaIP, "section_loop");
2273   if (!LoopInfo)
2274     return LoopInfo.takeError();
2275 
2276   InsertPointOrErrorTy WsloopIP =
2277       applyStaticWorkshareLoop(Loc.DL, *LoopInfo, AllocaIP,
2278                                WorksharingLoopType::ForStaticLoop, !IsNowait);
2279   if (!WsloopIP)
2280     return WsloopIP.takeError();
2281   InsertPointTy AfterIP = *WsloopIP;
2282 
2283   BasicBlock *LoopFini = AfterIP.getBlock()->getSinglePredecessor();
2284   assert(LoopFini && "Bad structure of static workshare loop finalization");
2285 
2286   // Apply the finalization callback in LoopAfterBB
2287   auto FiniInfo = FinalizationStack.pop_back_val();
2288   assert(FiniInfo.DK == OMPD_sections &&
2289          "Unexpected finalization stack state!");
2290   if (FinalizeCallbackTy &CB = FiniInfo.FiniCB) {
2291     Builder.restoreIP(AfterIP);
2292     BasicBlock *FiniBB =
2293         splitBBWithSuffix(Builder, /*CreateBranch=*/true, "sections.fini");
2294     if (Error Err = CB(Builder.saveIP()))
2295       return Err;
2296     AfterIP = {FiniBB, FiniBB->begin()};
2297   }
2298 
2299   // Now we can fix the dummy branch to point to the right place
2300   for (BranchInst *DummyBranch : CancellationBranches) {
2301     assert(DummyBranch->getNumSuccessors() == 1);
2302     DummyBranch->setSuccessor(0, LoopFini);
2303   }
2304 
2305   return AfterIP;
2306 }
2307 
2308 OpenMPIRBuilder::InsertPointOrErrorTy
createSection(const LocationDescription & Loc,BodyGenCallbackTy BodyGenCB,FinalizeCallbackTy FiniCB)2309 OpenMPIRBuilder::createSection(const LocationDescription &Loc,
2310                                BodyGenCallbackTy BodyGenCB,
2311                                FinalizeCallbackTy FiniCB) {
2312   if (!updateToLocation(Loc))
2313     return Loc.IP;
2314 
2315   auto FiniCBWrapper = [&](InsertPointTy IP) {
2316     if (IP.getBlock()->end() != IP.getPoint())
2317       return FiniCB(IP);
2318     // This must be done otherwise any nested constructs using FinalizeOMPRegion
2319     // will fail because that function requires the Finalization Basic Block to
2320     // have a terminator, which is already removed by EmitOMPRegionBody.
2321     // IP is currently at cancelation block.
2322     // We need to backtrack to the condition block to fetch
2323     // the exit block and create a branch from cancelation
2324     // to exit block.
2325     IRBuilder<>::InsertPointGuard IPG(Builder);
2326     Builder.restoreIP(IP);
2327     auto *CaseBB = Loc.IP.getBlock();
2328     auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
2329     auto *ExitBB = CondBB->getTerminator()->getSuccessor(1);
2330     Instruction *I = Builder.CreateBr(ExitBB);
2331     IP = InsertPointTy(I->getParent(), I->getIterator());
2332     return FiniCB(IP);
2333   };
2334 
2335   Directive OMPD = Directive::OMPD_sections;
2336   // Since we are using Finalization Callback here, HasFinalize
2337   // and IsCancellable have to be true
2338   return EmitOMPInlinedRegion(OMPD, nullptr, nullptr, BodyGenCB, FiniCBWrapper,
2339                               /*Conditional*/ false, /*hasFinalize*/ true,
2340                               /*IsCancellable*/ true);
2341 }
2342 
getInsertPointAfterInstr(Instruction * I)2343 static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
2344   BasicBlock::iterator IT(I);
2345   IT++;
2346   return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT);
2347 }
2348 
getGPUThreadID()2349 Value *OpenMPIRBuilder::getGPUThreadID() {
2350   return Builder.CreateCall(
2351       getOrCreateRuntimeFunction(M,
2352                                  OMPRTL___kmpc_get_hardware_thread_id_in_block),
2353       {});
2354 }
2355 
getGPUWarpSize()2356 Value *OpenMPIRBuilder::getGPUWarpSize() {
2357   return Builder.CreateCall(
2358       getOrCreateRuntimeFunction(M, OMPRTL___kmpc_get_warp_size), {});
2359 }
2360 
getNVPTXWarpID()2361 Value *OpenMPIRBuilder::getNVPTXWarpID() {
2362   unsigned LaneIDBits = Log2_32(Config.getGridValue().GV_Warp_Size);
2363   return Builder.CreateAShr(getGPUThreadID(), LaneIDBits, "nvptx_warp_id");
2364 }
2365 
getNVPTXLaneID()2366 Value *OpenMPIRBuilder::getNVPTXLaneID() {
2367   unsigned LaneIDBits = Log2_32(Config.getGridValue().GV_Warp_Size);
2368   assert(LaneIDBits < 32 && "Invalid LaneIDBits size in NVPTX device.");
2369   unsigned LaneIDMask = ~0u >> (32u - LaneIDBits);
2370   return Builder.CreateAnd(getGPUThreadID(), Builder.getInt32(LaneIDMask),
2371                            "nvptx_lane_id");
2372 }
2373 
castValueToType(InsertPointTy AllocaIP,Value * From,Type * ToType)2374 Value *OpenMPIRBuilder::castValueToType(InsertPointTy AllocaIP, Value *From,
2375                                         Type *ToType) {
2376   Type *FromType = From->getType();
2377   uint64_t FromSize = M.getDataLayout().getTypeStoreSize(FromType);
2378   uint64_t ToSize = M.getDataLayout().getTypeStoreSize(ToType);
2379   assert(FromSize > 0 && "From size must be greater than zero");
2380   assert(ToSize > 0 && "To size must be greater than zero");
2381   if (FromType == ToType)
2382     return From;
2383   if (FromSize == ToSize)
2384     return Builder.CreateBitCast(From, ToType);
2385   if (ToType->isIntegerTy() && FromType->isIntegerTy())
2386     return Builder.CreateIntCast(From, ToType, /*isSigned*/ true);
2387   InsertPointTy SaveIP = Builder.saveIP();
2388   Builder.restoreIP(AllocaIP);
2389   Value *CastItem = Builder.CreateAlloca(ToType);
2390   Builder.restoreIP(SaveIP);
2391 
2392   Value *ValCastItem = Builder.CreatePointerBitCastOrAddrSpaceCast(
2393       CastItem, Builder.getPtrTy(0));
2394   Builder.CreateStore(From, ValCastItem);
2395   return Builder.CreateLoad(ToType, CastItem);
2396 }
2397 
createRuntimeShuffleFunction(InsertPointTy AllocaIP,Value * Element,Type * ElementType,Value * Offset)2398 Value *OpenMPIRBuilder::createRuntimeShuffleFunction(InsertPointTy AllocaIP,
2399                                                      Value *Element,
2400                                                      Type *ElementType,
2401                                                      Value *Offset) {
2402   uint64_t Size = M.getDataLayout().getTypeStoreSize(ElementType);
2403   assert(Size <= 8 && "Unsupported bitwidth in shuffle instruction");
2404 
2405   // Cast all types to 32- or 64-bit values before calling shuffle routines.
2406   Type *CastTy = Builder.getIntNTy(Size <= 4 ? 32 : 64);
2407   Value *ElemCast = castValueToType(AllocaIP, Element, CastTy);
2408   Value *WarpSize =
2409       Builder.CreateIntCast(getGPUWarpSize(), Builder.getInt16Ty(), true);
2410   Function *ShuffleFunc = getOrCreateRuntimeFunctionPtr(
2411       Size <= 4 ? RuntimeFunction::OMPRTL___kmpc_shuffle_int32
2412                 : RuntimeFunction::OMPRTL___kmpc_shuffle_int64);
2413   Value *WarpSizeCast =
2414       Builder.CreateIntCast(WarpSize, Builder.getInt16Ty(), /*isSigned=*/true);
2415   Value *ShuffleCall =
2416       Builder.CreateCall(ShuffleFunc, {ElemCast, Offset, WarpSizeCast});
2417   return castValueToType(AllocaIP, ShuffleCall, CastTy);
2418 }
2419 
shuffleAndStore(InsertPointTy AllocaIP,Value * SrcAddr,Value * DstAddr,Type * ElemType,Value * Offset,Type * ReductionArrayTy)2420 void OpenMPIRBuilder::shuffleAndStore(InsertPointTy AllocaIP, Value *SrcAddr,
2421                                       Value *DstAddr, Type *ElemType,
2422                                       Value *Offset, Type *ReductionArrayTy) {
2423   uint64_t Size = M.getDataLayout().getTypeStoreSize(ElemType);
2424   // Create the loop over the big sized data.
2425   // ptr = (void*)Elem;
2426   // ptrEnd = (void*) Elem + 1;
2427   // Step = 8;
2428   // while (ptr + Step < ptrEnd)
2429   //   shuffle((int64_t)*ptr);
2430   // Step = 4;
2431   // while (ptr + Step < ptrEnd)
2432   //   shuffle((int32_t)*ptr);
2433   // ...
2434   Type *IndexTy = Builder.getIndexTy(
2435       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
2436   Value *ElemPtr = DstAddr;
2437   Value *Ptr = SrcAddr;
2438   for (unsigned IntSize = 8; IntSize >= 1; IntSize /= 2) {
2439     if (Size < IntSize)
2440       continue;
2441     Type *IntType = Builder.getIntNTy(IntSize * 8);
2442     Ptr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2443         Ptr, Builder.getPtrTy(0), Ptr->getName() + ".ascast");
2444     Value *SrcAddrGEP =
2445         Builder.CreateGEP(ElemType, SrcAddr, {ConstantInt::get(IndexTy, 1)});
2446     ElemPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2447         ElemPtr, Builder.getPtrTy(0), ElemPtr->getName() + ".ascast");
2448 
2449     Function *CurFunc = Builder.GetInsertBlock()->getParent();
2450     if ((Size / IntSize) > 1) {
2451       Value *PtrEnd = Builder.CreatePointerBitCastOrAddrSpaceCast(
2452           SrcAddrGEP, Builder.getPtrTy());
2453       BasicBlock *PreCondBB =
2454           BasicBlock::Create(M.getContext(), ".shuffle.pre_cond");
2455       BasicBlock *ThenBB = BasicBlock::Create(M.getContext(), ".shuffle.then");
2456       BasicBlock *ExitBB = BasicBlock::Create(M.getContext(), ".shuffle.exit");
2457       BasicBlock *CurrentBB = Builder.GetInsertBlock();
2458       emitBlock(PreCondBB, CurFunc);
2459       PHINode *PhiSrc =
2460           Builder.CreatePHI(Ptr->getType(), /*NumReservedValues=*/2);
2461       PhiSrc->addIncoming(Ptr, CurrentBB);
2462       PHINode *PhiDest =
2463           Builder.CreatePHI(ElemPtr->getType(), /*NumReservedValues=*/2);
2464       PhiDest->addIncoming(ElemPtr, CurrentBB);
2465       Ptr = PhiSrc;
2466       ElemPtr = PhiDest;
2467       Value *PtrDiff = Builder.CreatePtrDiff(
2468           Builder.getInt8Ty(), PtrEnd,
2469           Builder.CreatePointerBitCastOrAddrSpaceCast(Ptr, Builder.getPtrTy()));
2470       Builder.CreateCondBr(
2471           Builder.CreateICmpSGT(PtrDiff, Builder.getInt64(IntSize - 1)), ThenBB,
2472           ExitBB);
2473       emitBlock(ThenBB, CurFunc);
2474       Value *Res = createRuntimeShuffleFunction(
2475           AllocaIP,
2476           Builder.CreateAlignedLoad(
2477               IntType, Ptr, M.getDataLayout().getPrefTypeAlign(ElemType)),
2478           IntType, Offset);
2479       Builder.CreateAlignedStore(Res, ElemPtr,
2480                                  M.getDataLayout().getPrefTypeAlign(ElemType));
2481       Value *LocalPtr =
2482           Builder.CreateGEP(IntType, Ptr, {ConstantInt::get(IndexTy, 1)});
2483       Value *LocalElemPtr =
2484           Builder.CreateGEP(IntType, ElemPtr, {ConstantInt::get(IndexTy, 1)});
2485       PhiSrc->addIncoming(LocalPtr, ThenBB);
2486       PhiDest->addIncoming(LocalElemPtr, ThenBB);
2487       emitBranch(PreCondBB);
2488       emitBlock(ExitBB, CurFunc);
2489     } else {
2490       Value *Res = createRuntimeShuffleFunction(
2491           AllocaIP, Builder.CreateLoad(IntType, Ptr), IntType, Offset);
2492       if (ElemType->isIntegerTy() && ElemType->getScalarSizeInBits() <
2493                                          Res->getType()->getScalarSizeInBits())
2494         Res = Builder.CreateTrunc(Res, ElemType);
2495       Builder.CreateStore(Res, ElemPtr);
2496       Ptr = Builder.CreateGEP(IntType, Ptr, {ConstantInt::get(IndexTy, 1)});
2497       ElemPtr =
2498           Builder.CreateGEP(IntType, ElemPtr, {ConstantInt::get(IndexTy, 1)});
2499     }
2500     Size = Size % IntSize;
2501   }
2502 }
2503 
emitReductionListCopy(InsertPointTy AllocaIP,CopyAction Action,Type * ReductionArrayTy,ArrayRef<ReductionInfo> ReductionInfos,Value * SrcBase,Value * DestBase,CopyOptionsTy CopyOptions)2504 void OpenMPIRBuilder::emitReductionListCopy(
2505     InsertPointTy AllocaIP, CopyAction Action, Type *ReductionArrayTy,
2506     ArrayRef<ReductionInfo> ReductionInfos, Value *SrcBase, Value *DestBase,
2507     CopyOptionsTy CopyOptions) {
2508   Type *IndexTy = Builder.getIndexTy(
2509       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
2510   Value *RemoteLaneOffset = CopyOptions.RemoteLaneOffset;
2511 
2512   // Iterates, element-by-element, through the source Reduce list and
2513   // make a copy.
2514   for (auto En : enumerate(ReductionInfos)) {
2515     const ReductionInfo &RI = En.value();
2516     Value *SrcElementAddr = nullptr;
2517     Value *DestElementAddr = nullptr;
2518     Value *DestElementPtrAddr = nullptr;
2519     // Should we shuffle in an element from a remote lane?
2520     bool ShuffleInElement = false;
2521     // Set to true to update the pointer in the dest Reduce list to a
2522     // newly created element.
2523     bool UpdateDestListPtr = false;
2524 
2525     // Step 1.1: Get the address for the src element in the Reduce list.
2526     Value *SrcElementPtrAddr = Builder.CreateInBoundsGEP(
2527         ReductionArrayTy, SrcBase,
2528         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
2529     SrcElementAddr = Builder.CreateLoad(Builder.getPtrTy(), SrcElementPtrAddr);
2530 
2531     // Step 1.2: Create a temporary to store the element in the destination
2532     // Reduce list.
2533     DestElementPtrAddr = Builder.CreateInBoundsGEP(
2534         ReductionArrayTy, DestBase,
2535         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
2536     switch (Action) {
2537     case CopyAction::RemoteLaneToThread: {
2538       InsertPointTy CurIP = Builder.saveIP();
2539       Builder.restoreIP(AllocaIP);
2540       AllocaInst *DestAlloca = Builder.CreateAlloca(RI.ElementType, nullptr,
2541                                                     ".omp.reduction.element");
2542       DestAlloca->setAlignment(
2543           M.getDataLayout().getPrefTypeAlign(RI.ElementType));
2544       DestElementAddr = DestAlloca;
2545       DestElementAddr =
2546           Builder.CreateAddrSpaceCast(DestElementAddr, Builder.getPtrTy(),
2547                                       DestElementAddr->getName() + ".ascast");
2548       Builder.restoreIP(CurIP);
2549       ShuffleInElement = true;
2550       UpdateDestListPtr = true;
2551       break;
2552     }
2553     case CopyAction::ThreadCopy: {
2554       DestElementAddr =
2555           Builder.CreateLoad(Builder.getPtrTy(), DestElementPtrAddr);
2556       break;
2557     }
2558     }
2559 
2560     // Now that all active lanes have read the element in the
2561     // Reduce list, shuffle over the value from the remote lane.
2562     if (ShuffleInElement) {
2563       shuffleAndStore(AllocaIP, SrcElementAddr, DestElementAddr, RI.ElementType,
2564                       RemoteLaneOffset, ReductionArrayTy);
2565     } else {
2566       switch (RI.EvaluationKind) {
2567       case EvalKind::Scalar: {
2568         Value *Elem = Builder.CreateLoad(RI.ElementType, SrcElementAddr);
2569         // Store the source element value to the dest element address.
2570         Builder.CreateStore(Elem, DestElementAddr);
2571         break;
2572       }
2573       case EvalKind::Complex: {
2574         Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
2575             RI.ElementType, SrcElementAddr, 0, 0, ".realp");
2576         Value *SrcReal = Builder.CreateLoad(
2577             RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
2578         Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
2579             RI.ElementType, SrcElementAddr, 0, 1, ".imagp");
2580         Value *SrcImg = Builder.CreateLoad(
2581             RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");
2582 
2583         Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
2584             RI.ElementType, DestElementAddr, 0, 0, ".realp");
2585         Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
2586             RI.ElementType, DestElementAddr, 0, 1, ".imagp");
2587         Builder.CreateStore(SrcReal, DestRealPtr);
2588         Builder.CreateStore(SrcImg, DestImgPtr);
2589         break;
2590       }
2591       case EvalKind::Aggregate: {
2592         Value *SizeVal = Builder.getInt64(
2593             M.getDataLayout().getTypeStoreSize(RI.ElementType));
2594         Builder.CreateMemCpy(
2595             DestElementAddr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
2596             SrcElementAddr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
2597             SizeVal, false);
2598         break;
2599       }
2600       };
2601     }
2602 
2603     // Step 3.1: Modify reference in dest Reduce list as needed.
2604     // Modifying the reference in Reduce list to point to the newly
2605     // created element.  The element is live in the current function
2606     // scope and that of functions it invokes (i.e., reduce_function).
2607     // RemoteReduceData[i] = (void*)&RemoteElem
2608     if (UpdateDestListPtr) {
2609       Value *CastDestAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2610           DestElementAddr, Builder.getPtrTy(),
2611           DestElementAddr->getName() + ".ascast");
2612       Builder.CreateStore(CastDestAddr, DestElementPtrAddr);
2613     }
2614   }
2615 }
2616 
emitInterWarpCopyFunction(const LocationDescription & Loc,ArrayRef<ReductionInfo> ReductionInfos,AttributeList FuncAttrs)2617 Expected<Function *> OpenMPIRBuilder::emitInterWarpCopyFunction(
2618     const LocationDescription &Loc, ArrayRef<ReductionInfo> ReductionInfos,
2619     AttributeList FuncAttrs) {
2620   IRBuilder<>::InsertPointGuard IPG(Builder);
2621   LLVMContext &Ctx = M.getContext();
2622   FunctionType *FuncTy = FunctionType::get(
2623       Builder.getVoidTy(), {Builder.getPtrTy(), Builder.getInt32Ty()},
2624       /* IsVarArg */ false);
2625   Function *WcFunc =
2626       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
2627                        "_omp_reduction_inter_warp_copy_func", &M);
2628   WcFunc->setAttributes(FuncAttrs);
2629   WcFunc->addParamAttr(0, Attribute::NoUndef);
2630   WcFunc->addParamAttr(1, Attribute::NoUndef);
2631   BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", WcFunc);
2632   Builder.SetInsertPoint(EntryBB);
2633   Builder.SetCurrentDebugLocation(llvm::DebugLoc());
2634 
2635   // ReduceList: thread local Reduce list.
2636   // At the stage of the computation when this function is called, partially
2637   // aggregated values reside in the first lane of every active warp.
2638   Argument *ReduceListArg = WcFunc->getArg(0);
2639   // NumWarps: number of warps active in the parallel region.  This could
2640   // be smaller than 32 (max warps in a CTA) for partial block reduction.
2641   Argument *NumWarpsArg = WcFunc->getArg(1);
2642 
2643   // This array is used as a medium to transfer, one reduce element at a time,
2644   // the data from the first lane of every warp to lanes in the first warp
2645   // in order to perform the final step of a reduction in a parallel region
2646   // (reduction across warps).  The array is placed in NVPTX __shared__ memory
2647   // for reduced latency, as well as to have a distinct copy for concurrently
2648   // executing target regions.  The array is declared with common linkage so
2649   // as to be shared across compilation units.
2650   StringRef TransferMediumName =
2651       "__openmp_nvptx_data_transfer_temporary_storage";
2652   GlobalVariable *TransferMedium = M.getGlobalVariable(TransferMediumName);
2653   unsigned WarpSize = Config.getGridValue().GV_Warp_Size;
2654   ArrayType *ArrayTy = ArrayType::get(Builder.getInt32Ty(), WarpSize);
2655   if (!TransferMedium) {
2656     TransferMedium = new GlobalVariable(
2657         M, ArrayTy, /*isConstant=*/false, GlobalVariable::WeakAnyLinkage,
2658         UndefValue::get(ArrayTy), TransferMediumName,
2659         /*InsertBefore=*/nullptr, GlobalVariable::NotThreadLocal,
2660         /*AddressSpace=*/3);
2661   }
2662 
2663   // Get the CUDA thread id of the current OpenMP thread on the GPU.
2664   Value *GPUThreadID = getGPUThreadID();
2665   // nvptx_lane_id = nvptx_id % warpsize
2666   Value *LaneID = getNVPTXLaneID();
2667   // nvptx_warp_id = nvptx_id / warpsize
2668   Value *WarpID = getNVPTXWarpID();
2669 
2670   InsertPointTy AllocaIP =
2671       InsertPointTy(Builder.GetInsertBlock(),
2672                     Builder.GetInsertBlock()->getFirstInsertionPt());
2673   Type *Arg0Type = ReduceListArg->getType();
2674   Type *Arg1Type = NumWarpsArg->getType();
2675   Builder.restoreIP(AllocaIP);
2676   AllocaInst *ReduceListAlloca = Builder.CreateAlloca(
2677       Arg0Type, nullptr, ReduceListArg->getName() + ".addr");
2678   AllocaInst *NumWarpsAlloca =
2679       Builder.CreateAlloca(Arg1Type, nullptr, NumWarpsArg->getName() + ".addr");
2680   Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2681       ReduceListAlloca, Arg0Type, ReduceListAlloca->getName() + ".ascast");
2682   Value *NumWarpsAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2683       NumWarpsAlloca, Builder.getPtrTy(0),
2684       NumWarpsAlloca->getName() + ".ascast");
2685   Builder.CreateStore(ReduceListArg, ReduceListAddrCast);
2686   Builder.CreateStore(NumWarpsArg, NumWarpsAddrCast);
2687   AllocaIP = getInsertPointAfterInstr(NumWarpsAlloca);
2688   InsertPointTy CodeGenIP =
2689       getInsertPointAfterInstr(&Builder.GetInsertBlock()->back());
2690   Builder.restoreIP(CodeGenIP);
2691 
2692   Value *ReduceList =
2693       Builder.CreateLoad(Builder.getPtrTy(), ReduceListAddrCast);
2694 
2695   for (auto En : enumerate(ReductionInfos)) {
2696     //
2697     // Warp master copies reduce element to transfer medium in __shared__
2698     // memory.
2699     //
2700     const ReductionInfo &RI = En.value();
2701     unsigned RealTySize = M.getDataLayout().getTypeAllocSize(RI.ElementType);
2702     for (unsigned TySize = 4; TySize > 0 && RealTySize > 0; TySize /= 2) {
2703       Type *CType = Builder.getIntNTy(TySize * 8);
2704 
2705       unsigned NumIters = RealTySize / TySize;
2706       if (NumIters == 0)
2707         continue;
2708       Value *Cnt = nullptr;
2709       Value *CntAddr = nullptr;
2710       BasicBlock *PrecondBB = nullptr;
2711       BasicBlock *ExitBB = nullptr;
2712       if (NumIters > 1) {
2713         CodeGenIP = Builder.saveIP();
2714         Builder.restoreIP(AllocaIP);
2715         CntAddr =
2716             Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, ".cnt.addr");
2717 
2718         CntAddr = Builder.CreateAddrSpaceCast(CntAddr, Builder.getPtrTy(),
2719                                               CntAddr->getName() + ".ascast");
2720         Builder.restoreIP(CodeGenIP);
2721         Builder.CreateStore(Constant::getNullValue(Builder.getInt32Ty()),
2722                             CntAddr,
2723                             /*Volatile=*/false);
2724         PrecondBB = BasicBlock::Create(Ctx, "precond");
2725         ExitBB = BasicBlock::Create(Ctx, "exit");
2726         BasicBlock *BodyBB = BasicBlock::Create(Ctx, "body");
2727         emitBlock(PrecondBB, Builder.GetInsertBlock()->getParent());
2728         Cnt = Builder.CreateLoad(Builder.getInt32Ty(), CntAddr,
2729                                  /*Volatile=*/false);
2730         Value *Cmp = Builder.CreateICmpULT(
2731             Cnt, ConstantInt::get(Builder.getInt32Ty(), NumIters));
2732         Builder.CreateCondBr(Cmp, BodyBB, ExitBB);
2733         emitBlock(BodyBB, Builder.GetInsertBlock()->getParent());
2734       }
2735 
2736       // kmpc_barrier.
2737       InsertPointOrErrorTy BarrierIP1 =
2738           createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
2739                         omp::Directive::OMPD_unknown,
2740                         /* ForceSimpleCall */ false,
2741                         /* CheckCancelFlag */ true);
2742       if (!BarrierIP1)
2743         return BarrierIP1.takeError();
2744       BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then");
2745       BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else");
2746       BasicBlock *MergeBB = BasicBlock::Create(Ctx, "ifcont");
2747 
2748       // if (lane_id  == 0)
2749       Value *IsWarpMaster = Builder.CreateIsNull(LaneID, "warp_master");
2750       Builder.CreateCondBr(IsWarpMaster, ThenBB, ElseBB);
2751       emitBlock(ThenBB, Builder.GetInsertBlock()->getParent());
2752 
2753       // Reduce element = LocalReduceList[i]
2754       auto *RedListArrayTy =
2755           ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
2756       Type *IndexTy = Builder.getIndexTy(
2757           M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
2758       Value *ElemPtrPtr =
2759           Builder.CreateInBoundsGEP(RedListArrayTy, ReduceList,
2760                                     {ConstantInt::get(IndexTy, 0),
2761                                      ConstantInt::get(IndexTy, En.index())});
2762       // elemptr = ((CopyType*)(elemptrptr)) + I
2763       Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
2764       if (NumIters > 1)
2765         ElemPtr = Builder.CreateGEP(Builder.getInt32Ty(), ElemPtr, Cnt);
2766 
2767       // Get pointer to location in transfer medium.
2768       // MediumPtr = &medium[warp_id]
2769       Value *MediumPtr = Builder.CreateInBoundsGEP(
2770           ArrayTy, TransferMedium, {Builder.getInt64(0), WarpID});
2771       // elem = *elemptr
2772       //*MediumPtr = elem
2773       Value *Elem = Builder.CreateLoad(CType, ElemPtr);
2774       // Store the source element value to the dest element address.
2775       Builder.CreateStore(Elem, MediumPtr,
2776                           /*IsVolatile*/ true);
2777       Builder.CreateBr(MergeBB);
2778 
2779       // else
2780       emitBlock(ElseBB, Builder.GetInsertBlock()->getParent());
2781       Builder.CreateBr(MergeBB);
2782 
2783       // endif
2784       emitBlock(MergeBB, Builder.GetInsertBlock()->getParent());
2785       InsertPointOrErrorTy BarrierIP2 =
2786           createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
2787                         omp::Directive::OMPD_unknown,
2788                         /* ForceSimpleCall */ false,
2789                         /* CheckCancelFlag */ true);
2790       if (!BarrierIP2)
2791         return BarrierIP2.takeError();
2792 
2793       // Warp 0 copies reduce element from transfer medium
2794       BasicBlock *W0ThenBB = BasicBlock::Create(Ctx, "then");
2795       BasicBlock *W0ElseBB = BasicBlock::Create(Ctx, "else");
2796       BasicBlock *W0MergeBB = BasicBlock::Create(Ctx, "ifcont");
2797 
2798       Value *NumWarpsVal =
2799           Builder.CreateLoad(Builder.getInt32Ty(), NumWarpsAddrCast);
2800       // Up to 32 threads in warp 0 are active.
2801       Value *IsActiveThread =
2802           Builder.CreateICmpULT(GPUThreadID, NumWarpsVal, "is_active_thread");
2803       Builder.CreateCondBr(IsActiveThread, W0ThenBB, W0ElseBB);
2804 
2805       emitBlock(W0ThenBB, Builder.GetInsertBlock()->getParent());
2806 
2807       // SecMediumPtr = &medium[tid]
2808       // SrcMediumVal = *SrcMediumPtr
2809       Value *SrcMediumPtrVal = Builder.CreateInBoundsGEP(
2810           ArrayTy, TransferMedium, {Builder.getInt64(0), GPUThreadID});
2811       // TargetElemPtr = (CopyType*)(SrcDataAddr[i]) + I
2812       Value *TargetElemPtrPtr =
2813           Builder.CreateInBoundsGEP(RedListArrayTy, ReduceList,
2814                                     {ConstantInt::get(IndexTy, 0),
2815                                      ConstantInt::get(IndexTy, En.index())});
2816       Value *TargetElemPtrVal =
2817           Builder.CreateLoad(Builder.getPtrTy(), TargetElemPtrPtr);
2818       Value *TargetElemPtr = TargetElemPtrVal;
2819       if (NumIters > 1)
2820         TargetElemPtr =
2821             Builder.CreateGEP(Builder.getInt32Ty(), TargetElemPtr, Cnt);
2822 
2823       // *TargetElemPtr = SrcMediumVal;
2824       Value *SrcMediumValue =
2825           Builder.CreateLoad(CType, SrcMediumPtrVal, /*IsVolatile*/ true);
2826       Builder.CreateStore(SrcMediumValue, TargetElemPtr);
2827       Builder.CreateBr(W0MergeBB);
2828 
2829       emitBlock(W0ElseBB, Builder.GetInsertBlock()->getParent());
2830       Builder.CreateBr(W0MergeBB);
2831 
2832       emitBlock(W0MergeBB, Builder.GetInsertBlock()->getParent());
2833 
2834       if (NumIters > 1) {
2835         Cnt = Builder.CreateNSWAdd(
2836             Cnt, ConstantInt::get(Builder.getInt32Ty(), /*V=*/1));
2837         Builder.CreateStore(Cnt, CntAddr, /*Volatile=*/false);
2838 
2839         auto *CurFn = Builder.GetInsertBlock()->getParent();
2840         emitBranch(PrecondBB);
2841         emitBlock(ExitBB, CurFn);
2842       }
2843       RealTySize %= TySize;
2844     }
2845   }
2846 
2847   Builder.CreateRetVoid();
2848 
2849   return WcFunc;
2850 }
2851 
emitShuffleAndReduceFunction(ArrayRef<ReductionInfo> ReductionInfos,Function * ReduceFn,AttributeList FuncAttrs)2852 Function *OpenMPIRBuilder::emitShuffleAndReduceFunction(
2853     ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
2854     AttributeList FuncAttrs) {
2855   LLVMContext &Ctx = M.getContext();
2856   IRBuilder<>::InsertPointGuard IPG(Builder);
2857   FunctionType *FuncTy =
2858       FunctionType::get(Builder.getVoidTy(),
2859                         {Builder.getPtrTy(), Builder.getInt16Ty(),
2860                          Builder.getInt16Ty(), Builder.getInt16Ty()},
2861                         /* IsVarArg */ false);
2862   Function *SarFunc =
2863       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
2864                        "_omp_reduction_shuffle_and_reduce_func", &M);
2865   SarFunc->setAttributes(FuncAttrs);
2866   SarFunc->addParamAttr(0, Attribute::NoUndef);
2867   SarFunc->addParamAttr(1, Attribute::NoUndef);
2868   SarFunc->addParamAttr(2, Attribute::NoUndef);
2869   SarFunc->addParamAttr(3, Attribute::NoUndef);
2870   SarFunc->addParamAttr(1, Attribute::SExt);
2871   SarFunc->addParamAttr(2, Attribute::SExt);
2872   SarFunc->addParamAttr(3, Attribute::SExt);
2873   BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", SarFunc);
2874   Builder.SetInsertPoint(EntryBB);
2875   Builder.SetCurrentDebugLocation(llvm::DebugLoc());
2876 
2877   // Thread local Reduce list used to host the values of data to be reduced.
2878   Argument *ReduceListArg = SarFunc->getArg(0);
2879   // Current lane id; could be logical.
2880   Argument *LaneIDArg = SarFunc->getArg(1);
2881   // Offset of the remote source lane relative to the current lane.
2882   Argument *RemoteLaneOffsetArg = SarFunc->getArg(2);
2883   // Algorithm version.  This is expected to be known at compile time.
2884   Argument *AlgoVerArg = SarFunc->getArg(3);
2885 
2886   Type *ReduceListArgType = ReduceListArg->getType();
2887   Type *LaneIDArgType = LaneIDArg->getType();
2888   Type *LaneIDArgPtrType = Builder.getPtrTy(0);
2889   Value *ReduceListAlloca = Builder.CreateAlloca(
2890       ReduceListArgType, nullptr, ReduceListArg->getName() + ".addr");
2891   Value *LaneIdAlloca = Builder.CreateAlloca(LaneIDArgType, nullptr,
2892                                              LaneIDArg->getName() + ".addr");
2893   Value *RemoteLaneOffsetAlloca = Builder.CreateAlloca(
2894       LaneIDArgType, nullptr, RemoteLaneOffsetArg->getName() + ".addr");
2895   Value *AlgoVerAlloca = Builder.CreateAlloca(LaneIDArgType, nullptr,
2896                                               AlgoVerArg->getName() + ".addr");
2897   ArrayType *RedListArrayTy =
2898       ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
2899 
2900   // Create a local thread-private variable to host the Reduce list
2901   // from a remote lane.
2902   Instruction *RemoteReductionListAlloca = Builder.CreateAlloca(
2903       RedListArrayTy, nullptr, ".omp.reduction.remote_reduce_list");
2904 
2905   Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2906       ReduceListAlloca, ReduceListArgType,
2907       ReduceListAlloca->getName() + ".ascast");
2908   Value *LaneIdAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2909       LaneIdAlloca, LaneIDArgPtrType, LaneIdAlloca->getName() + ".ascast");
2910   Value *RemoteLaneOffsetAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2911       RemoteLaneOffsetAlloca, LaneIDArgPtrType,
2912       RemoteLaneOffsetAlloca->getName() + ".ascast");
2913   Value *AlgoVerAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2914       AlgoVerAlloca, LaneIDArgPtrType, AlgoVerAlloca->getName() + ".ascast");
2915   Value *RemoteListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2916       RemoteReductionListAlloca, Builder.getPtrTy(),
2917       RemoteReductionListAlloca->getName() + ".ascast");
2918 
2919   Builder.CreateStore(ReduceListArg, ReduceListAddrCast);
2920   Builder.CreateStore(LaneIDArg, LaneIdAddrCast);
2921   Builder.CreateStore(RemoteLaneOffsetArg, RemoteLaneOffsetAddrCast);
2922   Builder.CreateStore(AlgoVerArg, AlgoVerAddrCast);
2923 
2924   Value *ReduceList = Builder.CreateLoad(ReduceListArgType, ReduceListAddrCast);
2925   Value *LaneId = Builder.CreateLoad(LaneIDArgType, LaneIdAddrCast);
2926   Value *RemoteLaneOffset =
2927       Builder.CreateLoad(LaneIDArgType, RemoteLaneOffsetAddrCast);
2928   Value *AlgoVer = Builder.CreateLoad(LaneIDArgType, AlgoVerAddrCast);
2929 
2930   InsertPointTy AllocaIP = getInsertPointAfterInstr(RemoteReductionListAlloca);
2931 
2932   // This loop iterates through the list of reduce elements and copies,
2933   // element by element, from a remote lane in the warp to RemoteReduceList,
2934   // hosted on the thread's stack.
2935   emitReductionListCopy(
2936       AllocaIP, CopyAction::RemoteLaneToThread, RedListArrayTy, ReductionInfos,
2937       ReduceList, RemoteListAddrCast, {RemoteLaneOffset, nullptr, nullptr});
2938 
2939   // The actions to be performed on the Remote Reduce list is dependent
2940   // on the algorithm version.
2941   //
2942   //  if (AlgoVer==0) || (AlgoVer==1 && (LaneId < Offset)) || (AlgoVer==2 &&
2943   //  LaneId % 2 == 0 && Offset > 0):
2944   //    do the reduction value aggregation
2945   //
2946   //  The thread local variable Reduce list is mutated in place to host the
2947   //  reduced data, which is the aggregated value produced from local and
2948   //  remote lanes.
2949   //
2950   //  Note that AlgoVer is expected to be a constant integer known at compile
2951   //  time.
2952   //  When AlgoVer==0, the first conjunction evaluates to true, making
2953   //    the entire predicate true during compile time.
2954   //  When AlgoVer==1, the second conjunction has only the second part to be
2955   //    evaluated during runtime.  Other conjunctions evaluates to false
2956   //    during compile time.
2957   //  When AlgoVer==2, the third conjunction has only the second part to be
2958   //    evaluated during runtime.  Other conjunctions evaluates to false
2959   //    during compile time.
2960   Value *CondAlgo0 = Builder.CreateIsNull(AlgoVer);
2961   Value *Algo1 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(1));
2962   Value *LaneComp = Builder.CreateICmpULT(LaneId, RemoteLaneOffset);
2963   Value *CondAlgo1 = Builder.CreateAnd(Algo1, LaneComp);
2964   Value *Algo2 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(2));
2965   Value *LaneIdAnd1 = Builder.CreateAnd(LaneId, Builder.getInt16(1));
2966   Value *LaneIdComp = Builder.CreateIsNull(LaneIdAnd1);
2967   Value *Algo2AndLaneIdComp = Builder.CreateAnd(Algo2, LaneIdComp);
2968   Value *RemoteOffsetComp =
2969       Builder.CreateICmpSGT(RemoteLaneOffset, Builder.getInt16(0));
2970   Value *CondAlgo2 = Builder.CreateAnd(Algo2AndLaneIdComp, RemoteOffsetComp);
2971   Value *CA0OrCA1 = Builder.CreateOr(CondAlgo0, CondAlgo1);
2972   Value *CondReduce = Builder.CreateOr(CA0OrCA1, CondAlgo2);
2973 
2974   BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then");
2975   BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else");
2976   BasicBlock *MergeBB = BasicBlock::Create(Ctx, "ifcont");
2977 
2978   Builder.CreateCondBr(CondReduce, ThenBB, ElseBB);
2979   emitBlock(ThenBB, Builder.GetInsertBlock()->getParent());
2980   Value *LocalReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2981       ReduceList, Builder.getPtrTy());
2982   Value *RemoteReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2983       RemoteListAddrCast, Builder.getPtrTy());
2984   Builder.CreateCall(ReduceFn, {LocalReduceListPtr, RemoteReduceListPtr})
2985       ->addFnAttr(Attribute::NoUnwind);
2986   Builder.CreateBr(MergeBB);
2987 
2988   emitBlock(ElseBB, Builder.GetInsertBlock()->getParent());
2989   Builder.CreateBr(MergeBB);
2990 
2991   emitBlock(MergeBB, Builder.GetInsertBlock()->getParent());
2992 
2993   // if (AlgoVer==1 && (LaneId >= Offset)) copy Remote Reduce list to local
2994   // Reduce list.
2995   Algo1 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(1));
2996   Value *LaneIdGtOffset = Builder.CreateICmpUGE(LaneId, RemoteLaneOffset);
2997   Value *CondCopy = Builder.CreateAnd(Algo1, LaneIdGtOffset);
2998 
2999   BasicBlock *CpyThenBB = BasicBlock::Create(Ctx, "then");
3000   BasicBlock *CpyElseBB = BasicBlock::Create(Ctx, "else");
3001   BasicBlock *CpyMergeBB = BasicBlock::Create(Ctx, "ifcont");
3002   Builder.CreateCondBr(CondCopy, CpyThenBB, CpyElseBB);
3003 
3004   emitBlock(CpyThenBB, Builder.GetInsertBlock()->getParent());
3005   emitReductionListCopy(AllocaIP, CopyAction::ThreadCopy, RedListArrayTy,
3006                         ReductionInfos, RemoteListAddrCast, ReduceList);
3007   Builder.CreateBr(CpyMergeBB);
3008 
3009   emitBlock(CpyElseBB, Builder.GetInsertBlock()->getParent());
3010   Builder.CreateBr(CpyMergeBB);
3011 
3012   emitBlock(CpyMergeBB, Builder.GetInsertBlock()->getParent());
3013 
3014   Builder.CreateRetVoid();
3015 
3016   return SarFunc;
3017 }
3018 
emitListToGlobalCopyFunction(ArrayRef<ReductionInfo> ReductionInfos,Type * ReductionsBufferTy,AttributeList FuncAttrs)3019 Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
3020     ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
3021     AttributeList FuncAttrs) {
3022   IRBuilder<>::InsertPointGuard IPG(Builder);
3023   LLVMContext &Ctx = M.getContext();
3024   FunctionType *FuncTy = FunctionType::get(
3025       Builder.getVoidTy(),
3026       {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
3027       /* IsVarArg */ false);
3028   Function *LtGCFunc =
3029       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
3030                        "_omp_reduction_list_to_global_copy_func", &M);
3031   LtGCFunc->setAttributes(FuncAttrs);
3032   LtGCFunc->addParamAttr(0, Attribute::NoUndef);
3033   LtGCFunc->addParamAttr(1, Attribute::NoUndef);
3034   LtGCFunc->addParamAttr(2, Attribute::NoUndef);
3035 
3036   BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
3037   Builder.SetInsertPoint(EntryBlock);
3038   Builder.SetCurrentDebugLocation(llvm::DebugLoc());
3039 
3040   // Buffer: global reduction buffer.
3041   Argument *BufferArg = LtGCFunc->getArg(0);
3042   // Idx: index of the buffer.
3043   Argument *IdxArg = LtGCFunc->getArg(1);
3044   // ReduceList: thread local Reduce list.
3045   Argument *ReduceListArg = LtGCFunc->getArg(2);
3046 
3047   Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
3048                                                 BufferArg->getName() + ".addr");
3049   Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
3050                                              IdxArg->getName() + ".addr");
3051   Value *ReduceListArgAlloca = Builder.CreateAlloca(
3052       Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
3053   Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3054       BufferArgAlloca, Builder.getPtrTy(),
3055       BufferArgAlloca->getName() + ".ascast");
3056   Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3057       IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
3058   Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3059       ReduceListArgAlloca, Builder.getPtrTy(),
3060       ReduceListArgAlloca->getName() + ".ascast");
3061 
3062   Builder.CreateStore(BufferArg, BufferArgAddrCast);
3063   Builder.CreateStore(IdxArg, IdxArgAddrCast);
3064   Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
3065 
3066   Value *LocalReduceList =
3067       Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
3068   Value *BufferArgVal =
3069       Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
3070   Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
3071   Type *IndexTy = Builder.getIndexTy(
3072       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3073   for (auto En : enumerate(ReductionInfos)) {
3074     const ReductionInfo &RI = En.value();
3075     auto *RedListArrayTy =
3076         ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
3077     // Reduce element = LocalReduceList[i]
3078     Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
3079         RedListArrayTy, LocalReduceList,
3080         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3081     // elemptr = ((CopyType*)(elemptrptr)) + I
3082     Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
3083 
3084     // Global = Buffer.VD[Idx];
3085     Value *BufferVD =
3086         Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferArgVal, Idxs);
3087     Value *GlobVal = Builder.CreateConstInBoundsGEP2_32(
3088         ReductionsBufferTy, BufferVD, 0, En.index());
3089 
3090     switch (RI.EvaluationKind) {
3091     case EvalKind::Scalar: {
3092       Value *TargetElement = Builder.CreateLoad(RI.ElementType, ElemPtr);
3093       Builder.CreateStore(TargetElement, GlobVal);
3094       break;
3095     }
3096     case EvalKind::Complex: {
3097       Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
3098           RI.ElementType, ElemPtr, 0, 0, ".realp");
3099       Value *SrcReal = Builder.CreateLoad(
3100           RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
3101       Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
3102           RI.ElementType, ElemPtr, 0, 1, ".imagp");
3103       Value *SrcImg = Builder.CreateLoad(
3104           RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");
3105 
3106       Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
3107           RI.ElementType, GlobVal, 0, 0, ".realp");
3108       Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
3109           RI.ElementType, GlobVal, 0, 1, ".imagp");
3110       Builder.CreateStore(SrcReal, DestRealPtr);
3111       Builder.CreateStore(SrcImg, DestImgPtr);
3112       break;
3113     }
3114     case EvalKind::Aggregate: {
3115       Value *SizeVal =
3116           Builder.getInt64(M.getDataLayout().getTypeStoreSize(RI.ElementType));
3117       Builder.CreateMemCpy(
3118           GlobVal, M.getDataLayout().getPrefTypeAlign(RI.ElementType), ElemPtr,
3119           M.getDataLayout().getPrefTypeAlign(RI.ElementType), SizeVal, false);
3120       break;
3121     }
3122     }
3123   }
3124 
3125   Builder.CreateRetVoid();
3126   return LtGCFunc;
3127 }
3128 
emitListToGlobalReduceFunction(ArrayRef<ReductionInfo> ReductionInfos,Function * ReduceFn,Type * ReductionsBufferTy,AttributeList FuncAttrs)3129 Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
3130     ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
3131     Type *ReductionsBufferTy, AttributeList FuncAttrs) {
3132   IRBuilder<>::InsertPointGuard IPG(Builder);
3133   LLVMContext &Ctx = M.getContext();
3134   FunctionType *FuncTy = FunctionType::get(
3135       Builder.getVoidTy(),
3136       {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
3137       /* IsVarArg */ false);
3138   Function *LtGRFunc =
3139       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
3140                        "_omp_reduction_list_to_global_reduce_func", &M);
3141   LtGRFunc->setAttributes(FuncAttrs);
3142   LtGRFunc->addParamAttr(0, Attribute::NoUndef);
3143   LtGRFunc->addParamAttr(1, Attribute::NoUndef);
3144   LtGRFunc->addParamAttr(2, Attribute::NoUndef);
3145 
3146   BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGRFunc);
3147   Builder.SetInsertPoint(EntryBlock);
3148   Builder.SetCurrentDebugLocation(llvm::DebugLoc());
3149 
3150   // Buffer: global reduction buffer.
3151   Argument *BufferArg = LtGRFunc->getArg(0);
3152   // Idx: index of the buffer.
3153   Argument *IdxArg = LtGRFunc->getArg(1);
3154   // ReduceList: thread local Reduce list.
3155   Argument *ReduceListArg = LtGRFunc->getArg(2);
3156 
3157   Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
3158                                                 BufferArg->getName() + ".addr");
3159   Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
3160                                              IdxArg->getName() + ".addr");
3161   Value *ReduceListArgAlloca = Builder.CreateAlloca(
3162       Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
3163   auto *RedListArrayTy =
3164       ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
3165 
3166   // 1. Build a list of reduction variables.
3167   // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
3168   Value *LocalReduceList =
3169       Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");
3170 
3171   Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3172       BufferArgAlloca, Builder.getPtrTy(),
3173       BufferArgAlloca->getName() + ".ascast");
3174   Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3175       IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
3176   Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3177       ReduceListArgAlloca, Builder.getPtrTy(),
3178       ReduceListArgAlloca->getName() + ".ascast");
3179   Value *LocalReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3180       LocalReduceList, Builder.getPtrTy(),
3181       LocalReduceList->getName() + ".ascast");
3182 
3183   Builder.CreateStore(BufferArg, BufferArgAddrCast);
3184   Builder.CreateStore(IdxArg, IdxArgAddrCast);
3185   Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
3186 
3187   Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
3188   Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
3189   Type *IndexTy = Builder.getIndexTy(
3190       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3191   for (auto En : enumerate(ReductionInfos)) {
3192     Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
3193         RedListArrayTy, LocalReduceListAddrCast,
3194         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3195     Value *BufferVD =
3196         Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
3197     // Global = Buffer.VD[Idx];
3198     Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
3199         ReductionsBufferTy, BufferVD, 0, En.index());
3200     Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
3201   }
3202 
3203   // Call reduce_function(GlobalReduceList, ReduceList)
3204   Value *ReduceList =
3205       Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
3206   Builder.CreateCall(ReduceFn, {LocalReduceListAddrCast, ReduceList})
3207       ->addFnAttr(Attribute::NoUnwind);
3208   Builder.CreateRetVoid();
3209   return LtGRFunc;
3210 }
3211 
emitGlobalToListCopyFunction(ArrayRef<ReductionInfo> ReductionInfos,Type * ReductionsBufferTy,AttributeList FuncAttrs)3212 Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
3213     ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
3214     AttributeList FuncAttrs) {
3215   IRBuilder<>::InsertPointGuard IPG(Builder);
3216   LLVMContext &Ctx = M.getContext();
3217   FunctionType *FuncTy = FunctionType::get(
3218       Builder.getVoidTy(),
3219       {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
3220       /* IsVarArg */ false);
3221   Function *LtGCFunc =
3222       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
3223                        "_omp_reduction_global_to_list_copy_func", &M);
3224   LtGCFunc->setAttributes(FuncAttrs);
3225   LtGCFunc->addParamAttr(0, Attribute::NoUndef);
3226   LtGCFunc->addParamAttr(1, Attribute::NoUndef);
3227   LtGCFunc->addParamAttr(2, Attribute::NoUndef);
3228 
3229   BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
3230   Builder.SetInsertPoint(EntryBlock);
3231   Builder.SetCurrentDebugLocation(llvm::DebugLoc());
3232 
3233   // Buffer: global reduction buffer.
3234   Argument *BufferArg = LtGCFunc->getArg(0);
3235   // Idx: index of the buffer.
3236   Argument *IdxArg = LtGCFunc->getArg(1);
3237   // ReduceList: thread local Reduce list.
3238   Argument *ReduceListArg = LtGCFunc->getArg(2);
3239 
3240   Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
3241                                                 BufferArg->getName() + ".addr");
3242   Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
3243                                              IdxArg->getName() + ".addr");
3244   Value *ReduceListArgAlloca = Builder.CreateAlloca(
3245       Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
3246   Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3247       BufferArgAlloca, Builder.getPtrTy(),
3248       BufferArgAlloca->getName() + ".ascast");
3249   Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3250       IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
3251   Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3252       ReduceListArgAlloca, Builder.getPtrTy(),
3253       ReduceListArgAlloca->getName() + ".ascast");
3254   Builder.CreateStore(BufferArg, BufferArgAddrCast);
3255   Builder.CreateStore(IdxArg, IdxArgAddrCast);
3256   Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
3257 
3258   Value *LocalReduceList =
3259       Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
3260   Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
3261   Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
3262   Type *IndexTy = Builder.getIndexTy(
3263       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3264   for (auto En : enumerate(ReductionInfos)) {
3265     const OpenMPIRBuilder::ReductionInfo &RI = En.value();
3266     auto *RedListArrayTy =
3267         ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
3268     // Reduce element = LocalReduceList[i]
3269     Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
3270         RedListArrayTy, LocalReduceList,
3271         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3272     // elemptr = ((CopyType*)(elemptrptr)) + I
3273     Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
3274     // Global = Buffer.VD[Idx];
3275     Value *BufferVD =
3276         Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
3277     Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
3278         ReductionsBufferTy, BufferVD, 0, En.index());
3279 
3280     switch (RI.EvaluationKind) {
3281     case EvalKind::Scalar: {
3282       Value *TargetElement = Builder.CreateLoad(RI.ElementType, GlobValPtr);
3283       Builder.CreateStore(TargetElement, ElemPtr);
3284       break;
3285     }
3286     case EvalKind::Complex: {
3287       Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
3288           RI.ElementType, GlobValPtr, 0, 0, ".realp");
3289       Value *SrcReal = Builder.CreateLoad(
3290           RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
3291       Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
3292           RI.ElementType, GlobValPtr, 0, 1, ".imagp");
3293       Value *SrcImg = Builder.CreateLoad(
3294           RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");
3295 
3296       Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
3297           RI.ElementType, ElemPtr, 0, 0, ".realp");
3298       Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
3299           RI.ElementType, ElemPtr, 0, 1, ".imagp");
3300       Builder.CreateStore(SrcReal, DestRealPtr);
3301       Builder.CreateStore(SrcImg, DestImgPtr);
3302       break;
3303     }
3304     case EvalKind::Aggregate: {
3305       Value *SizeVal =
3306           Builder.getInt64(M.getDataLayout().getTypeStoreSize(RI.ElementType));
3307       Builder.CreateMemCpy(
3308           ElemPtr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
3309           GlobValPtr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
3310           SizeVal, false);
3311       break;
3312     }
3313     }
3314   }
3315 
3316   Builder.CreateRetVoid();
3317   return LtGCFunc;
3318 }
3319 
emitGlobalToListReduceFunction(ArrayRef<ReductionInfo> ReductionInfos,Function * ReduceFn,Type * ReductionsBufferTy,AttributeList FuncAttrs)3320 Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
3321     ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
3322     Type *ReductionsBufferTy, AttributeList FuncAttrs) {
3323   IRBuilder<>::InsertPointGuard IPG(Builder);
3324   LLVMContext &Ctx = M.getContext();
3325   auto *FuncTy = FunctionType::get(
3326       Builder.getVoidTy(),
3327       {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
3328       /* IsVarArg */ false);
3329   Function *LtGRFunc =
3330       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
3331                        "_omp_reduction_global_to_list_reduce_func", &M);
3332   LtGRFunc->setAttributes(FuncAttrs);
3333   LtGRFunc->addParamAttr(0, Attribute::NoUndef);
3334   LtGRFunc->addParamAttr(1, Attribute::NoUndef);
3335   LtGRFunc->addParamAttr(2, Attribute::NoUndef);
3336 
3337   BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGRFunc);
3338   Builder.SetInsertPoint(EntryBlock);
3339   Builder.SetCurrentDebugLocation(llvm::DebugLoc());
3340 
3341   // Buffer: global reduction buffer.
3342   Argument *BufferArg = LtGRFunc->getArg(0);
3343   // Idx: index of the buffer.
3344   Argument *IdxArg = LtGRFunc->getArg(1);
3345   // ReduceList: thread local Reduce list.
3346   Argument *ReduceListArg = LtGRFunc->getArg(2);
3347 
3348   Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
3349                                                 BufferArg->getName() + ".addr");
3350   Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
3351                                              IdxArg->getName() + ".addr");
3352   Value *ReduceListArgAlloca = Builder.CreateAlloca(
3353       Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
3354   ArrayType *RedListArrayTy =
3355       ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
3356 
3357   // 1. Build a list of reduction variables.
3358   // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
3359   Value *LocalReduceList =
3360       Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");
3361 
3362   Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3363       BufferArgAlloca, Builder.getPtrTy(),
3364       BufferArgAlloca->getName() + ".ascast");
3365   Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3366       IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
3367   Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3368       ReduceListArgAlloca, Builder.getPtrTy(),
3369       ReduceListArgAlloca->getName() + ".ascast");
3370   Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
3371       LocalReduceList, Builder.getPtrTy(),
3372       LocalReduceList->getName() + ".ascast");
3373 
3374   Builder.CreateStore(BufferArg, BufferArgAddrCast);
3375   Builder.CreateStore(IdxArg, IdxArgAddrCast);
3376   Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
3377 
3378   Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
3379   Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
3380   Type *IndexTy = Builder.getIndexTy(
3381       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3382   for (auto En : enumerate(ReductionInfos)) {
3383     Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
3384         RedListArrayTy, ReductionList,
3385         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3386     // Global = Buffer.VD[Idx];
3387     Value *BufferVD =
3388         Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
3389     Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
3390         ReductionsBufferTy, BufferVD, 0, En.index());
3391     Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
3392   }
3393 
3394   // Call reduce_function(ReduceList, GlobalReduceList)
3395   Value *ReduceList =
3396       Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
3397   Builder.CreateCall(ReduceFn, {ReduceList, ReductionList})
3398       ->addFnAttr(Attribute::NoUnwind);
3399   Builder.CreateRetVoid();
3400   return LtGRFunc;
3401 }
3402 
getReductionFuncName(StringRef Name) const3403 std::string OpenMPIRBuilder::getReductionFuncName(StringRef Name) const {
3404   std::string Suffix =
3405       createPlatformSpecificName({"omp", "reduction", "reduction_func"});
3406   return (Name + Suffix).str();
3407 }
3408 
createReductionFunction(StringRef ReducerName,ArrayRef<ReductionInfo> ReductionInfos,ReductionGenCBKind ReductionGenCBKind,AttributeList FuncAttrs)3409 Expected<Function *> OpenMPIRBuilder::createReductionFunction(
3410     StringRef ReducerName, ArrayRef<ReductionInfo> ReductionInfos,
3411     ReductionGenCBKind ReductionGenCBKind, AttributeList FuncAttrs) {
3412   IRBuilder<>::InsertPointGuard IPG(Builder);
3413   auto *FuncTy = FunctionType::get(Builder.getVoidTy(),
3414                                    {Builder.getPtrTy(), Builder.getPtrTy()},
3415                                    /* IsVarArg */ false);
3416   std::string Name = getReductionFuncName(ReducerName);
3417   Function *ReductionFunc =
3418       Function::Create(FuncTy, GlobalVariable::InternalLinkage, Name, &M);
3419   ReductionFunc->setAttributes(FuncAttrs);
3420   ReductionFunc->addParamAttr(0, Attribute::NoUndef);
3421   ReductionFunc->addParamAttr(1, Attribute::NoUndef);
3422   BasicBlock *EntryBB =
3423       BasicBlock::Create(M.getContext(), "entry", ReductionFunc);
3424   Builder.SetInsertPoint(EntryBB);
3425   Builder.SetCurrentDebugLocation(llvm::DebugLoc());
3426 
3427   // Need to alloca memory here and deal with the pointers before getting
3428   // LHS/RHS pointers out
3429   Value *LHSArrayPtr = nullptr;
3430   Value *RHSArrayPtr = nullptr;
3431   Argument *Arg0 = ReductionFunc->getArg(0);
3432   Argument *Arg1 = ReductionFunc->getArg(1);
3433   Type *Arg0Type = Arg0->getType();
3434   Type *Arg1Type = Arg1->getType();
3435 
3436   Value *LHSAlloca =
3437       Builder.CreateAlloca(Arg0Type, nullptr, Arg0->getName() + ".addr");
3438   Value *RHSAlloca =
3439       Builder.CreateAlloca(Arg1Type, nullptr, Arg1->getName() + ".addr");
3440   Value *LHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3441       LHSAlloca, Arg0Type, LHSAlloca->getName() + ".ascast");
3442   Value *RHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3443       RHSAlloca, Arg1Type, RHSAlloca->getName() + ".ascast");
3444   Builder.CreateStore(Arg0, LHSAddrCast);
3445   Builder.CreateStore(Arg1, RHSAddrCast);
3446   LHSArrayPtr = Builder.CreateLoad(Arg0Type, LHSAddrCast);
3447   RHSArrayPtr = Builder.CreateLoad(Arg1Type, RHSAddrCast);
3448 
3449   Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
3450   Type *IndexTy = Builder.getIndexTy(
3451       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3452   SmallVector<Value *> LHSPtrs, RHSPtrs;
3453   for (auto En : enumerate(ReductionInfos)) {
3454     const ReductionInfo &RI = En.value();
3455     Value *RHSI8PtrPtr = Builder.CreateInBoundsGEP(
3456         RedArrayTy, RHSArrayPtr,
3457         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3458     Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
3459     Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3460         RHSI8Ptr, RI.PrivateVariable->getType(),
3461         RHSI8Ptr->getName() + ".ascast");
3462 
3463     Value *LHSI8PtrPtr = Builder.CreateInBoundsGEP(
3464         RedArrayTy, LHSArrayPtr,
3465         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3466     Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
3467     Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3468         LHSI8Ptr, RI.Variable->getType(), LHSI8Ptr->getName() + ".ascast");
3469 
3470     if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
3471       LHSPtrs.emplace_back(LHSPtr);
3472       RHSPtrs.emplace_back(RHSPtr);
3473     } else {
3474       Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
3475       Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
3476       Value *Reduced;
3477       InsertPointOrErrorTy AfterIP =
3478           RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
3479       if (!AfterIP)
3480         return AfterIP.takeError();
3481       if (!Builder.GetInsertBlock())
3482         return ReductionFunc;
3483       Builder.CreateStore(Reduced, LHSPtr);
3484     }
3485   }
3486 
3487   if (ReductionGenCBKind == ReductionGenCBKind::Clang)
3488     for (auto En : enumerate(ReductionInfos)) {
3489       unsigned Index = En.index();
3490       const ReductionInfo &RI = En.value();
3491       Value *LHSFixupPtr, *RHSFixupPtr;
3492       Builder.restoreIP(RI.ReductionGenClang(
3493           Builder.saveIP(), Index, &LHSFixupPtr, &RHSFixupPtr, ReductionFunc));
3494 
3495       // Fix the CallBack code genereated to use the correct Values for the LHS
3496       // and RHS
3497       LHSFixupPtr->replaceUsesWithIf(
3498           LHSPtrs[Index], [ReductionFunc](const Use &U) {
3499             return cast<Instruction>(U.getUser())->getParent()->getParent() ==
3500                    ReductionFunc;
3501           });
3502       RHSFixupPtr->replaceUsesWithIf(
3503           RHSPtrs[Index], [ReductionFunc](const Use &U) {
3504             return cast<Instruction>(U.getUser())->getParent()->getParent() ==
3505                    ReductionFunc;
3506           });
3507     }
3508 
3509   Builder.CreateRetVoid();
3510   return ReductionFunc;
3511 }
3512 
3513 static void
checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,bool IsGPU)3514 checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
3515                     bool IsGPU) {
3516   for (const OpenMPIRBuilder::ReductionInfo &RI : ReductionInfos) {
3517     (void)RI;
3518     assert(RI.Variable && "expected non-null variable");
3519     assert(RI.PrivateVariable && "expected non-null private variable");
3520     assert((RI.ReductionGen || RI.ReductionGenClang) &&
3521            "expected non-null reduction generator callback");
3522     if (!IsGPU) {
3523       assert(
3524           RI.Variable->getType() == RI.PrivateVariable->getType() &&
3525           "expected variables and their private equivalents to have the same "
3526           "type");
3527     }
3528     assert(RI.Variable->getType()->isPointerTy() &&
3529            "expected variables to be pointers");
3530   }
3531 }
3532 
createReductionsGPU(const LocationDescription & Loc,InsertPointTy AllocaIP,InsertPointTy CodeGenIP,ArrayRef<ReductionInfo> ReductionInfos,bool IsNoWait,bool IsTeamsReduction,ReductionGenCBKind ReductionGenCBKind,std::optional<omp::GV> GridValue,unsigned ReductionBufNum,Value * SrcLocInfo)3533 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
3534     const LocationDescription &Loc, InsertPointTy AllocaIP,
3535     InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
3536     bool IsNoWait, bool IsTeamsReduction, ReductionGenCBKind ReductionGenCBKind,
3537     std::optional<omp::GV> GridValue, unsigned ReductionBufNum,
3538     Value *SrcLocInfo) {
3539   if (!updateToLocation(Loc))
3540     return InsertPointTy();
3541   Builder.restoreIP(CodeGenIP);
3542   checkReductionInfos(ReductionInfos, /*IsGPU*/ true);
3543   LLVMContext &Ctx = M.getContext();
3544 
3545   // Source location for the ident struct
3546   if (!SrcLocInfo) {
3547     uint32_t SrcLocStrSize;
3548     Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3549     SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3550   }
3551 
3552   if (ReductionInfos.size() == 0)
3553     return Builder.saveIP();
3554 
3555   BasicBlock *ContinuationBlock = nullptr;
3556   if (ReductionGenCBKind != ReductionGenCBKind::Clang) {
3557     // Copied code from createReductions
3558     BasicBlock *InsertBlock = Loc.IP.getBlock();
3559     ContinuationBlock =
3560         InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
3561     InsertBlock->getTerminator()->eraseFromParent();
3562     Builder.SetInsertPoint(InsertBlock, InsertBlock->end());
3563   }
3564 
3565   Function *CurFunc = Builder.GetInsertBlock()->getParent();
3566   AttributeList FuncAttrs;
3567   AttrBuilder AttrBldr(Ctx);
3568   for (auto Attr : CurFunc->getAttributes().getFnAttrs())
3569     AttrBldr.addAttribute(Attr);
3570   AttrBldr.removeAttribute(Attribute::OptimizeNone);
3571   FuncAttrs = FuncAttrs.addFnAttributes(Ctx, AttrBldr);
3572 
3573   CodeGenIP = Builder.saveIP();
3574   Expected<Function *> ReductionResult =
3575       createReductionFunction(Builder.GetInsertBlock()->getParent()->getName(),
3576                               ReductionInfos, ReductionGenCBKind, FuncAttrs);
3577   if (!ReductionResult)
3578     return ReductionResult.takeError();
3579   Function *ReductionFunc = *ReductionResult;
3580   Builder.restoreIP(CodeGenIP);
3581 
3582   // Set the grid value in the config needed for lowering later on
3583   if (GridValue.has_value())
3584     Config.setGridValue(GridValue.value());
3585   else
3586     Config.setGridValue(getGridValue(T, ReductionFunc));
3587 
3588   // Build res = __kmpc_reduce{_nowait}(<gtid>, <n>, sizeof(RedList),
3589   // RedList, shuffle_reduce_func, interwarp_copy_func);
3590   // or
3591   // Build res = __kmpc_reduce_teams_nowait_simple(<loc>, <gtid>, <lck>);
3592   Value *Res;
3593 
3594   // 1. Build a list of reduction variables.
3595   // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
3596   auto Size = ReductionInfos.size();
3597   Type *PtrTy = PointerType::getUnqual(Ctx);
3598   Type *RedArrayTy = ArrayType::get(PtrTy, Size);
3599   CodeGenIP = Builder.saveIP();
3600   Builder.restoreIP(AllocaIP);
3601   Value *ReductionListAlloca =
3602       Builder.CreateAlloca(RedArrayTy, nullptr, ".omp.reduction.red_list");
3603   Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
3604       ReductionListAlloca, PtrTy, ReductionListAlloca->getName() + ".ascast");
3605   Builder.restoreIP(CodeGenIP);
3606   Type *IndexTy = Builder.getIndexTy(
3607       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3608   for (auto En : enumerate(ReductionInfos)) {
3609     const ReductionInfo &RI = En.value();
3610     Value *ElemPtr = Builder.CreateInBoundsGEP(
3611         RedArrayTy, ReductionList,
3612         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3613     Value *CastElem =
3614         Builder.CreatePointerBitCastOrAddrSpaceCast(RI.PrivateVariable, PtrTy);
3615     Builder.CreateStore(CastElem, ElemPtr);
3616   }
3617   CodeGenIP = Builder.saveIP();
3618   Function *SarFunc =
3619       emitShuffleAndReduceFunction(ReductionInfos, ReductionFunc, FuncAttrs);
3620   Expected<Function *> CopyResult =
3621       emitInterWarpCopyFunction(Loc, ReductionInfos, FuncAttrs);
3622   if (!CopyResult)
3623     return CopyResult.takeError();
3624   Function *WcFunc = *CopyResult;
3625   Builder.restoreIP(CodeGenIP);
3626 
3627   Value *RL = Builder.CreatePointerBitCastOrAddrSpaceCast(ReductionList, PtrTy);
3628 
3629   unsigned MaxDataSize = 0;
3630   SmallVector<Type *> ReductionTypeArgs;
3631   for (auto En : enumerate(ReductionInfos)) {
3632     auto Size = M.getDataLayout().getTypeStoreSize(En.value().ElementType);
3633     if (Size > MaxDataSize)
3634       MaxDataSize = Size;
3635     ReductionTypeArgs.emplace_back(En.value().ElementType);
3636   }
3637   Value *ReductionDataSize =
3638       Builder.getInt64(MaxDataSize * ReductionInfos.size());
3639   if (!IsTeamsReduction) {
3640     Value *SarFuncCast =
3641         Builder.CreatePointerBitCastOrAddrSpaceCast(SarFunc, PtrTy);
3642     Value *WcFuncCast =
3643         Builder.CreatePointerBitCastOrAddrSpaceCast(WcFunc, PtrTy);
3644     Value *Args[] = {SrcLocInfo, ReductionDataSize, RL, SarFuncCast,
3645                      WcFuncCast};
3646     Function *Pv2Ptr = getOrCreateRuntimeFunctionPtr(
3647         RuntimeFunction::OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2);
3648     Res = Builder.CreateCall(Pv2Ptr, Args);
3649   } else {
3650     CodeGenIP = Builder.saveIP();
3651     StructType *ReductionsBufferTy = StructType::create(
3652         Ctx, ReductionTypeArgs, "struct._globalized_locals_ty");
3653     Function *RedFixedBuferFn = getOrCreateRuntimeFunctionPtr(
3654         RuntimeFunction::OMPRTL___kmpc_reduction_get_fixed_buffer);
3655     Function *LtGCFunc = emitListToGlobalCopyFunction(
3656         ReductionInfos, ReductionsBufferTy, FuncAttrs);
3657     Function *LtGRFunc = emitListToGlobalReduceFunction(
3658         ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
3659     Function *GtLCFunc = emitGlobalToListCopyFunction(
3660         ReductionInfos, ReductionsBufferTy, FuncAttrs);
3661     Function *GtLRFunc = emitGlobalToListReduceFunction(
3662         ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
3663     Builder.restoreIP(CodeGenIP);
3664 
3665     Value *KernelTeamsReductionPtr = Builder.CreateCall(
3666         RedFixedBuferFn, {}, "_openmp_teams_reductions_buffer_$_$ptr");
3667 
3668     Value *Args3[] = {SrcLocInfo,
3669                       KernelTeamsReductionPtr,
3670                       Builder.getInt32(ReductionBufNum),
3671                       ReductionDataSize,
3672                       RL,
3673                       SarFunc,
3674                       WcFunc,
3675                       LtGCFunc,
3676                       LtGRFunc,
3677                       GtLCFunc,
3678                       GtLRFunc};
3679 
3680     Function *TeamsReduceFn = getOrCreateRuntimeFunctionPtr(
3681         RuntimeFunction::OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2);
3682     Res = Builder.CreateCall(TeamsReduceFn, Args3);
3683   }
3684 
3685   // 5. Build if (res == 1)
3686   BasicBlock *ExitBB = BasicBlock::Create(Ctx, ".omp.reduction.done");
3687   BasicBlock *ThenBB = BasicBlock::Create(Ctx, ".omp.reduction.then");
3688   Value *Cond = Builder.CreateICmpEQ(Res, Builder.getInt32(1));
3689   Builder.CreateCondBr(Cond, ThenBB, ExitBB);
3690 
3691   // 6. Build then branch: where we have reduced values in the master
3692   //    thread in each team.
3693   //    __kmpc_end_reduce{_nowait}(<gtid>);
3694   //    break;
3695   emitBlock(ThenBB, CurFunc);
3696 
3697   // Add emission of __kmpc_end_reduce{_nowait}(<gtid>);
3698   for (auto En : enumerate(ReductionInfos)) {
3699     const ReductionInfo &RI = En.value();
3700     Value *LHS = RI.Variable;
3701     Value *RHS =
3702         Builder.CreatePointerBitCastOrAddrSpaceCast(RI.PrivateVariable, PtrTy);
3703 
3704     if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
3705       Value *LHSPtr, *RHSPtr;
3706       Builder.restoreIP(RI.ReductionGenClang(Builder.saveIP(), En.index(),
3707                                              &LHSPtr, &RHSPtr, CurFunc));
3708 
3709       // Fix the CallBack code genereated to use the correct Values for the LHS
3710       // and RHS
3711       LHSPtr->replaceUsesWithIf(LHS, [ReductionFunc](const Use &U) {
3712         return cast<Instruction>(U.getUser())->getParent()->getParent() ==
3713                ReductionFunc;
3714       });
3715       RHSPtr->replaceUsesWithIf(RHS, [ReductionFunc](const Use &U) {
3716         return cast<Instruction>(U.getUser())->getParent()->getParent() ==
3717                ReductionFunc;
3718       });
3719     } else {
3720       Value *LHSValue = Builder.CreateLoad(RI.ElementType, LHS, "final.lhs");
3721       Value *RHSValue = Builder.CreateLoad(RI.ElementType, RHS, "final.rhs");
3722       Value *Reduced;
3723       InsertPointOrErrorTy AfterIP =
3724           RI.ReductionGen(Builder.saveIP(), RHSValue, LHSValue, Reduced);
3725       if (!AfterIP)
3726         return AfterIP.takeError();
3727       Builder.CreateStore(Reduced, LHS, false);
3728     }
3729   }
3730   emitBlock(ExitBB, CurFunc);
3731   if (ContinuationBlock) {
3732     Builder.CreateBr(ContinuationBlock);
3733     Builder.SetInsertPoint(ContinuationBlock);
3734   }
3735   Config.setEmitLLVMUsed();
3736 
3737   return Builder.saveIP();
3738 }
3739 
getFreshReductionFunc(Module & M)3740 static Function *getFreshReductionFunc(Module &M) {
3741   Type *VoidTy = Type::getVoidTy(M.getContext());
3742   Type *Int8PtrTy = PointerType::getUnqual(M.getContext());
3743   auto *FuncTy =
3744       FunctionType::get(VoidTy, {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ false);
3745   return Function::Create(FuncTy, GlobalVariable::InternalLinkage,
3746                           ".omp.reduction.func", &M);
3747 }
3748 
populateReductionFunction(Function * ReductionFunc,ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,IRBuilder<> & Builder,ArrayRef<bool> IsByRef,bool IsGPU)3749 static Error populateReductionFunction(
3750     Function *ReductionFunc,
3751     ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
3752     IRBuilder<> &Builder, ArrayRef<bool> IsByRef, bool IsGPU) {
3753   IRBuilder<>::InsertPointGuard IPG(Builder);
3754   Module *Module = ReductionFunc->getParent();
3755   BasicBlock *ReductionFuncBlock =
3756       BasicBlock::Create(Module->getContext(), "", ReductionFunc);
3757   Builder.SetInsertPoint(ReductionFuncBlock);
3758   Builder.SetCurrentDebugLocation(llvm::DebugLoc());
3759   Value *LHSArrayPtr = nullptr;
3760   Value *RHSArrayPtr = nullptr;
3761   if (IsGPU) {
3762     // Need to alloca memory here and deal with the pointers before getting
3763     // LHS/RHS pointers out
3764     //
3765     Argument *Arg0 = ReductionFunc->getArg(0);
3766     Argument *Arg1 = ReductionFunc->getArg(1);
3767     Type *Arg0Type = Arg0->getType();
3768     Type *Arg1Type = Arg1->getType();
3769 
3770     Value *LHSAlloca =
3771         Builder.CreateAlloca(Arg0Type, nullptr, Arg0->getName() + ".addr");
3772     Value *RHSAlloca =
3773         Builder.CreateAlloca(Arg1Type, nullptr, Arg1->getName() + ".addr");
3774     Value *LHSAddrCast =
3775         Builder.CreatePointerBitCastOrAddrSpaceCast(LHSAlloca, Arg0Type);
3776     Value *RHSAddrCast =
3777         Builder.CreatePointerBitCastOrAddrSpaceCast(RHSAlloca, Arg1Type);
3778     Builder.CreateStore(Arg0, LHSAddrCast);
3779     Builder.CreateStore(Arg1, RHSAddrCast);
3780     LHSArrayPtr = Builder.CreateLoad(Arg0Type, LHSAddrCast);
3781     RHSArrayPtr = Builder.CreateLoad(Arg1Type, RHSAddrCast);
3782   } else {
3783     LHSArrayPtr = ReductionFunc->getArg(0);
3784     RHSArrayPtr = ReductionFunc->getArg(1);
3785   }
3786 
3787   unsigned NumReductions = ReductionInfos.size();
3788   Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), NumReductions);
3789 
3790   for (auto En : enumerate(ReductionInfos)) {
3791     const OpenMPIRBuilder::ReductionInfo &RI = En.value();
3792     Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
3793         RedArrayTy, LHSArrayPtr, 0, En.index());
3794     Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
3795     Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3796         LHSI8Ptr, RI.Variable->getType());
3797     Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
3798     Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
3799         RedArrayTy, RHSArrayPtr, 0, En.index());
3800     Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
3801     Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3802         RHSI8Ptr, RI.PrivateVariable->getType());
3803     Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
3804     Value *Reduced;
3805     OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
3806         RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
3807     if (!AfterIP)
3808       return AfterIP.takeError();
3809 
3810     Builder.restoreIP(*AfterIP);
3811     // TODO: Consider flagging an error.
3812     if (!Builder.GetInsertBlock())
3813       return Error::success();
3814 
3815     // store is inside of the reduction region when using by-ref
3816     if (!IsByRef[En.index()])
3817       Builder.CreateStore(Reduced, LHSPtr);
3818   }
3819   Builder.CreateRetVoid();
3820   return Error::success();
3821 }
3822 
createReductions(const LocationDescription & Loc,InsertPointTy AllocaIP,ArrayRef<ReductionInfo> ReductionInfos,ArrayRef<bool> IsByRef,bool IsNoWait,bool IsTeamsReduction)3823 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductions(
3824     const LocationDescription &Loc, InsertPointTy AllocaIP,
3825     ArrayRef<ReductionInfo> ReductionInfos, ArrayRef<bool> IsByRef,
3826     bool IsNoWait, bool IsTeamsReduction) {
3827   assert(ReductionInfos.size() == IsByRef.size());
3828   if (Config.isGPU())
3829     return createReductionsGPU(Loc, AllocaIP, Builder.saveIP(), ReductionInfos,
3830                                IsNoWait, IsTeamsReduction);
3831 
3832   checkReductionInfos(ReductionInfos, /*IsGPU*/ false);
3833 
3834   if (!updateToLocation(Loc))
3835     return InsertPointTy();
3836 
3837   if (ReductionInfos.size() == 0)
3838     return Builder.saveIP();
3839 
3840   BasicBlock *InsertBlock = Loc.IP.getBlock();
3841   BasicBlock *ContinuationBlock =
3842       InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
3843   InsertBlock->getTerminator()->eraseFromParent();
3844 
3845   // Create and populate array of type-erased pointers to private reduction
3846   // values.
3847   unsigned NumReductions = ReductionInfos.size();
3848   Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), NumReductions);
3849   Builder.SetInsertPoint(AllocaIP.getBlock()->getTerminator());
3850   Value *RedArray = Builder.CreateAlloca(RedArrayTy, nullptr, "red.array");
3851 
3852   Builder.SetInsertPoint(InsertBlock, InsertBlock->end());
3853 
3854   for (auto En : enumerate(ReductionInfos)) {
3855     unsigned Index = En.index();
3856     const ReductionInfo &RI = En.value();
3857     Value *RedArrayElemPtr = Builder.CreateConstInBoundsGEP2_64(
3858         RedArrayTy, RedArray, 0, Index, "red.array.elem." + Twine(Index));
3859     Builder.CreateStore(RI.PrivateVariable, RedArrayElemPtr);
3860   }
3861 
3862   // Emit a call to the runtime function that orchestrates the reduction.
3863   // Declare the reduction function in the process.
3864   Type *IndexTy = Builder.getIndexTy(
3865       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3866   Function *Func = Builder.GetInsertBlock()->getParent();
3867   Module *Module = Func->getParent();
3868   uint32_t SrcLocStrSize;
3869   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3870   bool CanGenerateAtomic = all_of(ReductionInfos, [](const ReductionInfo &RI) {
3871     return RI.AtomicReductionGen;
3872   });
3873   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize,
3874                                   CanGenerateAtomic
3875                                       ? IdentFlag::OMP_IDENT_FLAG_ATOMIC_REDUCE
3876                                       : IdentFlag(0));
3877   Value *ThreadId = getOrCreateThreadID(Ident);
3878   Constant *NumVariables = Builder.getInt32(NumReductions);
3879   const DataLayout &DL = Module->getDataLayout();
3880   unsigned RedArrayByteSize = DL.getTypeStoreSize(RedArrayTy);
3881   Constant *RedArraySize = ConstantInt::get(IndexTy, RedArrayByteSize);
3882   Function *ReductionFunc = getFreshReductionFunc(*Module);
3883   Value *Lock = getOMPCriticalRegionLock(".reduction");
3884   Function *ReduceFunc = getOrCreateRuntimeFunctionPtr(
3885       IsNoWait ? RuntimeFunction::OMPRTL___kmpc_reduce_nowait
3886                : RuntimeFunction::OMPRTL___kmpc_reduce);
3887   CallInst *ReduceCall =
3888       Builder.CreateCall(ReduceFunc,
3889                          {Ident, ThreadId, NumVariables, RedArraySize, RedArray,
3890                           ReductionFunc, Lock},
3891                          "reduce");
3892 
3893   // Create final reduction entry blocks for the atomic and non-atomic case.
3894   // Emit IR that dispatches control flow to one of the blocks based on the
3895   // reduction supporting the atomic mode.
3896   BasicBlock *NonAtomicRedBlock =
3897       BasicBlock::Create(Module->getContext(), "reduce.switch.nonatomic", Func);
3898   BasicBlock *AtomicRedBlock =
3899       BasicBlock::Create(Module->getContext(), "reduce.switch.atomic", Func);
3900   SwitchInst *Switch =
3901       Builder.CreateSwitch(ReduceCall, ContinuationBlock, /* NumCases */ 2);
3902   Switch->addCase(Builder.getInt32(1), NonAtomicRedBlock);
3903   Switch->addCase(Builder.getInt32(2), AtomicRedBlock);
3904 
3905   // Populate the non-atomic reduction using the elementwise reduction function.
3906   // This loads the elements from the global and private variables and reduces
3907   // them before storing back the result to the global variable.
3908   Builder.SetInsertPoint(NonAtomicRedBlock);
3909   for (auto En : enumerate(ReductionInfos)) {
3910     const ReductionInfo &RI = En.value();
3911     Type *ValueType = RI.ElementType;
3912     // We have one less load for by-ref case because that load is now inside of
3913     // the reduction region
3914     Value *RedValue = RI.Variable;
3915     if (!IsByRef[En.index()]) {
3916       RedValue = Builder.CreateLoad(ValueType, RI.Variable,
3917                                     "red.value." + Twine(En.index()));
3918     }
3919     Value *PrivateRedValue =
3920         Builder.CreateLoad(ValueType, RI.PrivateVariable,
3921                            "red.private.value." + Twine(En.index()));
3922     Value *Reduced;
3923     InsertPointOrErrorTy AfterIP =
3924         RI.ReductionGen(Builder.saveIP(), RedValue, PrivateRedValue, Reduced);
3925     if (!AfterIP)
3926       return AfterIP.takeError();
3927     Builder.restoreIP(*AfterIP);
3928 
3929     if (!Builder.GetInsertBlock())
3930       return InsertPointTy();
3931     // for by-ref case, the load is inside of the reduction region
3932     if (!IsByRef[En.index()])
3933       Builder.CreateStore(Reduced, RI.Variable);
3934   }
3935   Function *EndReduceFunc = getOrCreateRuntimeFunctionPtr(
3936       IsNoWait ? RuntimeFunction::OMPRTL___kmpc_end_reduce_nowait
3937                : RuntimeFunction::OMPRTL___kmpc_end_reduce);
3938   Builder.CreateCall(EndReduceFunc, {Ident, ThreadId, Lock});
3939   Builder.CreateBr(ContinuationBlock);
3940 
3941   // Populate the atomic reduction using the atomic elementwise reduction
3942   // function. There are no loads/stores here because they will be happening
3943   // inside the atomic elementwise reduction.
3944   Builder.SetInsertPoint(AtomicRedBlock);
3945   if (CanGenerateAtomic && llvm::none_of(IsByRef, [](bool P) { return P; })) {
3946     for (const ReductionInfo &RI : ReductionInfos) {
3947       InsertPointOrErrorTy AfterIP = RI.AtomicReductionGen(
3948           Builder.saveIP(), RI.ElementType, RI.Variable, RI.PrivateVariable);
3949       if (!AfterIP)
3950         return AfterIP.takeError();
3951       Builder.restoreIP(*AfterIP);
3952       if (!Builder.GetInsertBlock())
3953         return InsertPointTy();
3954     }
3955     Builder.CreateBr(ContinuationBlock);
3956   } else {
3957     Builder.CreateUnreachable();
3958   }
3959 
3960   // Populate the outlined reduction function using the elementwise reduction
3961   // function. Partial values are extracted from the type-erased array of
3962   // pointers to private variables.
3963   Error Err = populateReductionFunction(ReductionFunc, ReductionInfos, Builder,
3964                                         IsByRef, /*isGPU=*/false);
3965   if (Err)
3966     return Err;
3967 
3968   if (!Builder.GetInsertBlock())
3969     return InsertPointTy();
3970 
3971   Builder.SetInsertPoint(ContinuationBlock);
3972   return Builder.saveIP();
3973 }
3974 
3975 OpenMPIRBuilder::InsertPointOrErrorTy
createMaster(const LocationDescription & Loc,BodyGenCallbackTy BodyGenCB,FinalizeCallbackTy FiniCB)3976 OpenMPIRBuilder::createMaster(const LocationDescription &Loc,
3977                               BodyGenCallbackTy BodyGenCB,
3978                               FinalizeCallbackTy FiniCB) {
3979   if (!updateToLocation(Loc))
3980     return Loc.IP;
3981 
3982   Directive OMPD = Directive::OMPD_master;
3983   uint32_t SrcLocStrSize;
3984   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3985   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3986   Value *ThreadId = getOrCreateThreadID(Ident);
3987   Value *Args[] = {Ident, ThreadId};
3988 
3989   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_master);
3990   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
3991 
3992   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_master);
3993   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
3994 
3995   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
3996                               /*Conditional*/ true, /*hasFinalize*/ true);
3997 }
3998 
3999 OpenMPIRBuilder::InsertPointOrErrorTy
createMasked(const LocationDescription & Loc,BodyGenCallbackTy BodyGenCB,FinalizeCallbackTy FiniCB,Value * Filter)4000 OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
4001                               BodyGenCallbackTy BodyGenCB,
4002                               FinalizeCallbackTy FiniCB, Value *Filter) {
4003   if (!updateToLocation(Loc))
4004     return Loc.IP;
4005 
4006   Directive OMPD = Directive::OMPD_masked;
4007   uint32_t SrcLocStrSize;
4008   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4009   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4010   Value *ThreadId = getOrCreateThreadID(Ident);
4011   Value *Args[] = {Ident, ThreadId, Filter};
4012   Value *ArgsEnd[] = {Ident, ThreadId};
4013 
4014   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_masked);
4015   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
4016 
4017   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_masked);
4018   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, ArgsEnd);
4019 
4020   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
4021                               /*Conditional*/ true, /*hasFinalize*/ true);
4022 }
4023 
createLoopSkeleton(DebugLoc DL,Value * TripCount,Function * F,BasicBlock * PreInsertBefore,BasicBlock * PostInsertBefore,const Twine & Name)4024 CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
4025     DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
4026     BasicBlock *PostInsertBefore, const Twine &Name) {
4027   Module *M = F->getParent();
4028   LLVMContext &Ctx = M->getContext();
4029   Type *IndVarTy = TripCount->getType();
4030 
4031   // Create the basic block structure.
4032   BasicBlock *Preheader =
4033       BasicBlock::Create(Ctx, "omp_" + Name + ".preheader", F, PreInsertBefore);
4034   BasicBlock *Header =
4035       BasicBlock::Create(Ctx, "omp_" + Name + ".header", F, PreInsertBefore);
4036   BasicBlock *Cond =
4037       BasicBlock::Create(Ctx, "omp_" + Name + ".cond", F, PreInsertBefore);
4038   BasicBlock *Body =
4039       BasicBlock::Create(Ctx, "omp_" + Name + ".body", F, PreInsertBefore);
4040   BasicBlock *Latch =
4041       BasicBlock::Create(Ctx, "omp_" + Name + ".inc", F, PostInsertBefore);
4042   BasicBlock *Exit =
4043       BasicBlock::Create(Ctx, "omp_" + Name + ".exit", F, PostInsertBefore);
4044   BasicBlock *After =
4045       BasicBlock::Create(Ctx, "omp_" + Name + ".after", F, PostInsertBefore);
4046 
4047   // Use specified DebugLoc for new instructions.
4048   Builder.SetCurrentDebugLocation(DL);
4049 
4050   Builder.SetInsertPoint(Preheader);
4051   Builder.CreateBr(Header);
4052 
4053   Builder.SetInsertPoint(Header);
4054   PHINode *IndVarPHI = Builder.CreatePHI(IndVarTy, 2, "omp_" + Name + ".iv");
4055   IndVarPHI->addIncoming(ConstantInt::get(IndVarTy, 0), Preheader);
4056   Builder.CreateBr(Cond);
4057 
4058   Builder.SetInsertPoint(Cond);
4059   Value *Cmp =
4060       Builder.CreateICmpULT(IndVarPHI, TripCount, "omp_" + Name + ".cmp");
4061   Builder.CreateCondBr(Cmp, Body, Exit);
4062 
4063   Builder.SetInsertPoint(Body);
4064   Builder.CreateBr(Latch);
4065 
4066   Builder.SetInsertPoint(Latch);
4067   Value *Next = Builder.CreateAdd(IndVarPHI, ConstantInt::get(IndVarTy, 1),
4068                                   "omp_" + Name + ".next", /*HasNUW=*/true);
4069   Builder.CreateBr(Header);
4070   IndVarPHI->addIncoming(Next, Latch);
4071 
4072   Builder.SetInsertPoint(Exit);
4073   Builder.CreateBr(After);
4074 
4075   // Remember and return the canonical control flow.
4076   LoopInfos.emplace_front();
4077   CanonicalLoopInfo *CL = &LoopInfos.front();
4078 
4079   CL->Header = Header;
4080   CL->Cond = Cond;
4081   CL->Latch = Latch;
4082   CL->Exit = Exit;
4083 
4084 #ifndef NDEBUG
4085   CL->assertOK();
4086 #endif
4087   return CL;
4088 }
4089 
4090 Expected<CanonicalLoopInfo *>
createCanonicalLoop(const LocationDescription & Loc,LoopBodyGenCallbackTy BodyGenCB,Value * TripCount,const Twine & Name)4091 OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
4092                                      LoopBodyGenCallbackTy BodyGenCB,
4093                                      Value *TripCount, const Twine &Name) {
4094   BasicBlock *BB = Loc.IP.getBlock();
4095   BasicBlock *NextBB = BB->getNextNode();
4096 
4097   CanonicalLoopInfo *CL = createLoopSkeleton(Loc.DL, TripCount, BB->getParent(),
4098                                              NextBB, NextBB, Name);
4099   BasicBlock *After = CL->getAfter();
4100 
4101   // If location is not set, don't connect the loop.
4102   if (updateToLocation(Loc)) {
4103     // Split the loop at the insertion point: Branch to the preheader and move
4104     // every following instruction to after the loop (the After BB). Also, the
4105     // new successor is the loop's after block.
4106     spliceBB(Builder, After, /*CreateBranch=*/false);
4107     Builder.CreateBr(CL->getPreheader());
4108   }
4109 
4110   // Emit the body content. We do it after connecting the loop to the CFG to
4111   // avoid that the callback encounters degenerate BBs.
4112   if (Error Err = BodyGenCB(CL->getBodyIP(), CL->getIndVar()))
4113     return Err;
4114 
4115 #ifndef NDEBUG
4116   CL->assertOK();
4117 #endif
4118   return CL;
4119 }
4120 
calculateCanonicalLoopTripCount(const LocationDescription & Loc,Value * Start,Value * Stop,Value * Step,bool IsSigned,bool InclusiveStop,const Twine & Name)4121 Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
4122     const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step,
4123     bool IsSigned, bool InclusiveStop, const Twine &Name) {
4124 
4125   // Consider the following difficulties (assuming 8-bit signed integers):
4126   //  * Adding \p Step to the loop counter which passes \p Stop may overflow:
4127   //      DO I = 1, 100, 50
4128   ///  * A \p Step of INT_MIN cannot not be normalized to a positive direction:
4129   //      DO I = 100, 0, -128
4130 
4131   // Start, Stop and Step must be of the same integer type.
4132   auto *IndVarTy = cast<IntegerType>(Start->getType());
4133   assert(IndVarTy == Stop->getType() && "Stop type mismatch");
4134   assert(IndVarTy == Step->getType() && "Step type mismatch");
4135 
4136   updateToLocation(Loc);
4137 
4138   ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
4139   ConstantInt *One = ConstantInt::get(IndVarTy, 1);
4140 
4141   // Like Step, but always positive.
4142   Value *Incr = Step;
4143 
4144   // Distance between Start and Stop; always positive.
4145   Value *Span;
4146 
4147   // Condition whether there are no iterations are executed at all, e.g. because
4148   // UB < LB.
4149   Value *ZeroCmp;
4150 
4151   if (IsSigned) {
4152     // Ensure that increment is positive. If not, negate and invert LB and UB.
4153     Value *IsNeg = Builder.CreateICmpSLT(Step, Zero);
4154     Incr = Builder.CreateSelect(IsNeg, Builder.CreateNeg(Step), Step);
4155     Value *LB = Builder.CreateSelect(IsNeg, Stop, Start);
4156     Value *UB = Builder.CreateSelect(IsNeg, Start, Stop);
4157     Span = Builder.CreateSub(UB, LB, "", false, true);
4158     ZeroCmp = Builder.CreateICmp(
4159         InclusiveStop ? CmpInst::ICMP_SLT : CmpInst::ICMP_SLE, UB, LB);
4160   } else {
4161     Span = Builder.CreateSub(Stop, Start, "", true);
4162     ZeroCmp = Builder.CreateICmp(
4163         InclusiveStop ? CmpInst::ICMP_ULT : CmpInst::ICMP_ULE, Stop, Start);
4164   }
4165 
4166   Value *CountIfLooping;
4167   if (InclusiveStop) {
4168     CountIfLooping = Builder.CreateAdd(Builder.CreateUDiv(Span, Incr), One);
4169   } else {
4170     // Avoid incrementing past stop since it could overflow.
4171     Value *CountIfTwo = Builder.CreateAdd(
4172         Builder.CreateUDiv(Builder.CreateSub(Span, One), Incr), One);
4173     Value *OneCmp = Builder.CreateICmp(CmpInst::ICMP_ULE, Span, Incr);
4174     CountIfLooping = Builder.CreateSelect(OneCmp, One, CountIfTwo);
4175   }
4176 
4177   return Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
4178                               "omp_" + Name + ".tripcount");
4179 }
4180 
createCanonicalLoop(const LocationDescription & Loc,LoopBodyGenCallbackTy BodyGenCB,Value * Start,Value * Stop,Value * Step,bool IsSigned,bool InclusiveStop,InsertPointTy ComputeIP,const Twine & Name)4181 Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
4182     const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
4183     Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
4184     InsertPointTy ComputeIP, const Twine &Name) {
4185   LocationDescription ComputeLoc =
4186       ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
4187 
4188   Value *TripCount = calculateCanonicalLoopTripCount(
4189       ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);
4190 
4191   auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
4192     Builder.restoreIP(CodeGenIP);
4193     Value *Span = Builder.CreateMul(IV, Step);
4194     Value *IndVar = Builder.CreateAdd(Span, Start);
4195     return BodyGenCB(Builder.saveIP(), IndVar);
4196   };
4197   LocationDescription LoopLoc =
4198       ComputeIP.isSet()
4199           ? Loc
4200           : LocationDescription(Builder.saveIP(),
4201                                 Builder.getCurrentDebugLocation());
4202   return createCanonicalLoop(LoopLoc, BodyGen, TripCount, Name);
4203 }
4204 
4205 // Returns an LLVM function to call for initializing loop bounds using OpenMP
4206 // static scheduling for composite `distribute parallel for` depending on
4207 // `type`. Only i32 and i64 are supported by the runtime. Always interpret
4208 // integers as unsigned similarly to CanonicalLoopInfo.
4209 static FunctionCallee
getKmpcDistForStaticInitForType(Type * Ty,Module & M,OpenMPIRBuilder & OMPBuilder)4210 getKmpcDistForStaticInitForType(Type *Ty, Module &M,
4211                                 OpenMPIRBuilder &OMPBuilder) {
4212   unsigned Bitwidth = Ty->getIntegerBitWidth();
4213   if (Bitwidth == 32)
4214     return OMPBuilder.getOrCreateRuntimeFunction(
4215         M, omp::RuntimeFunction::OMPRTL___kmpc_dist_for_static_init_4u);
4216   if (Bitwidth == 64)
4217     return OMPBuilder.getOrCreateRuntimeFunction(
4218         M, omp::RuntimeFunction::OMPRTL___kmpc_dist_for_static_init_8u);
4219   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
4220 }
4221 
4222 // Returns an LLVM function to call for initializing loop bounds using OpenMP
4223 // static scheduling depending on `type`. Only i32 and i64 are supported by the
4224 // runtime. Always interpret integers as unsigned similarly to
4225 // CanonicalLoopInfo.
getKmpcForStaticInitForType(Type * Ty,Module & M,OpenMPIRBuilder & OMPBuilder)4226 static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
4227                                                   OpenMPIRBuilder &OMPBuilder) {
4228   unsigned Bitwidth = Ty->getIntegerBitWidth();
4229   if (Bitwidth == 32)
4230     return OMPBuilder.getOrCreateRuntimeFunction(
4231         M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_4u);
4232   if (Bitwidth == 64)
4233     return OMPBuilder.getOrCreateRuntimeFunction(
4234         M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_8u);
4235   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
4236 }
4237 
applyStaticWorkshareLoop(DebugLoc DL,CanonicalLoopInfo * CLI,InsertPointTy AllocaIP,WorksharingLoopType LoopType,bool NeedsBarrier)4238 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
4239     DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
4240     WorksharingLoopType LoopType, bool NeedsBarrier) {
4241   assert(CLI->isValid() && "Requires a valid canonical loop");
4242   assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
4243          "Require dedicated allocate IP");
4244 
4245   // Set up the source location value for OpenMP runtime.
4246   Builder.restoreIP(CLI->getPreheaderIP());
4247   Builder.SetCurrentDebugLocation(DL);
4248 
4249   uint32_t SrcLocStrSize;
4250   Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
4251   Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4252 
4253   // Declare useful OpenMP runtime functions.
4254   Value *IV = CLI->getIndVar();
4255   Type *IVTy = IV->getType();
4256   FunctionCallee StaticInit =
4257       LoopType == WorksharingLoopType::DistributeForStaticLoop
4258           ? getKmpcDistForStaticInitForType(IVTy, M, *this)
4259           : getKmpcForStaticInitForType(IVTy, M, *this);
4260   FunctionCallee StaticFini =
4261       getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);
4262 
4263   // Allocate space for computed loop bounds as expected by the "init" function.
4264   Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());
4265 
4266   Type *I32Type = Type::getInt32Ty(M.getContext());
4267   Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
4268   Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
4269   Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
4270   Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
4271   CLI->setLastIter(PLastIter);
4272 
4273   // At the end of the preheader, prepare for calling the "init" function by
4274   // storing the current loop bounds into the allocated space. A canonical loop
4275   // always iterates from 0 to trip-count with step 1. Note that "init" expects
4276   // and produces an inclusive upper bound.
4277   Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
4278   Constant *Zero = ConstantInt::get(IVTy, 0);
4279   Constant *One = ConstantInt::get(IVTy, 1);
4280   Builder.CreateStore(Zero, PLowerBound);
4281   Value *UpperBound = Builder.CreateSub(CLI->getTripCount(), One);
4282   Builder.CreateStore(UpperBound, PUpperBound);
4283   Builder.CreateStore(One, PStride);
4284 
4285   Value *ThreadNum = getOrCreateThreadID(SrcLoc);
4286 
4287   OMPScheduleType SchedType =
4288       (LoopType == WorksharingLoopType::DistributeStaticLoop)
4289           ? OMPScheduleType::OrderedDistribute
4290           : OMPScheduleType::UnorderedStatic;
4291   Constant *SchedulingType =
4292       ConstantInt::get(I32Type, static_cast<int>(SchedType));
4293 
4294   // Call the "init" function and update the trip count of the loop with the
4295   // value it produced.
4296   SmallVector<Value *, 10> Args(
4297       {SrcLoc, ThreadNum, SchedulingType, PLastIter, PLowerBound, PUpperBound});
4298   if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
4299     Value *PDistUpperBound =
4300         Builder.CreateAlloca(IVTy, nullptr, "p.distupperbound");
4301     Args.push_back(PDistUpperBound);
4302   }
4303   Args.append({PStride, One, Zero});
4304   Builder.CreateCall(StaticInit, Args);
4305   Value *LowerBound = Builder.CreateLoad(IVTy, PLowerBound);
4306   Value *InclusiveUpperBound = Builder.CreateLoad(IVTy, PUpperBound);
4307   Value *TripCountMinusOne = Builder.CreateSub(InclusiveUpperBound, LowerBound);
4308   Value *TripCount = Builder.CreateAdd(TripCountMinusOne, One);
4309   CLI->setTripCount(TripCount);
4310 
4311   // Update all uses of the induction variable except the one in the condition
4312   // block that compares it with the actual upper bound, and the increment in
4313   // the latch block.
4314 
4315   CLI->mapIndVar([&](Instruction *OldIV) -> Value * {
4316     Builder.SetInsertPoint(CLI->getBody(),
4317                            CLI->getBody()->getFirstInsertionPt());
4318     Builder.SetCurrentDebugLocation(DL);
4319     return Builder.CreateAdd(OldIV, LowerBound);
4320   });
4321 
4322   // In the "exit" block, call the "fini" function.
4323   Builder.SetInsertPoint(CLI->getExit(),
4324                          CLI->getExit()->getTerminator()->getIterator());
4325   Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});
4326 
4327   // Add the barrier if requested.
4328   if (NeedsBarrier) {
4329     InsertPointOrErrorTy BarrierIP =
4330         createBarrier(LocationDescription(Builder.saveIP(), DL),
4331                       omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
4332                       /* CheckCancelFlag */ false);
4333     if (!BarrierIP)
4334       return BarrierIP.takeError();
4335   }
4336 
4337   InsertPointTy AfterIP = CLI->getAfterIP();
4338   CLI->invalidate();
4339 
4340   return AfterIP;
4341 }
4342 
4343 OpenMPIRBuilder::InsertPointOrErrorTy
applyStaticChunkedWorkshareLoop(DebugLoc DL,CanonicalLoopInfo * CLI,InsertPointTy AllocaIP,bool NeedsBarrier,Value * ChunkSize)4344 OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
4345                                                  CanonicalLoopInfo *CLI,
4346                                                  InsertPointTy AllocaIP,
4347                                                  bool NeedsBarrier,
4348                                                  Value *ChunkSize) {
4349   assert(CLI->isValid() && "Requires a valid canonical loop");
4350   assert(ChunkSize && "Chunk size is required");
4351 
4352   LLVMContext &Ctx = CLI->getFunction()->getContext();
4353   Value *IV = CLI->getIndVar();
4354   Value *OrigTripCount = CLI->getTripCount();
4355   Type *IVTy = IV->getType();
4356   assert(IVTy->getIntegerBitWidth() <= 64 &&
4357          "Max supported tripcount bitwidth is 64 bits");
4358   Type *InternalIVTy = IVTy->getIntegerBitWidth() <= 32 ? Type::getInt32Ty(Ctx)
4359                                                         : Type::getInt64Ty(Ctx);
4360   Type *I32Type = Type::getInt32Ty(M.getContext());
4361   Constant *Zero = ConstantInt::get(InternalIVTy, 0);
4362   Constant *One = ConstantInt::get(InternalIVTy, 1);
4363 
4364   // Declare useful OpenMP runtime functions.
4365   FunctionCallee StaticInit =
4366       getKmpcForStaticInitForType(InternalIVTy, M, *this);
4367   FunctionCallee StaticFini =
4368       getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);
4369 
4370   // Allocate space for computed loop bounds as expected by the "init" function.
4371   Builder.restoreIP(AllocaIP);
4372   Builder.SetCurrentDebugLocation(DL);
4373   Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
4374   Value *PLowerBound =
4375       Builder.CreateAlloca(InternalIVTy, nullptr, "p.lowerbound");
4376   Value *PUpperBound =
4377       Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
4378   Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");
4379   CLI->setLastIter(PLastIter);
4380 
4381   // Set up the source location value for the OpenMP runtime.
4382   Builder.restoreIP(CLI->getPreheaderIP());
4383   Builder.SetCurrentDebugLocation(DL);
4384 
4385   // TODO: Detect overflow in ubsan or max-out with current tripcount.
4386   Value *CastedChunkSize =
4387       Builder.CreateZExtOrTrunc(ChunkSize, InternalIVTy, "chunksize");
4388   Value *CastedTripCount =
4389       Builder.CreateZExt(OrigTripCount, InternalIVTy, "tripcount");
4390 
4391   Constant *SchedulingType = ConstantInt::get(
4392       I32Type, static_cast<int>(OMPScheduleType::UnorderedStaticChunked));
4393   Builder.CreateStore(Zero, PLowerBound);
4394   Value *OrigUpperBound = Builder.CreateSub(CastedTripCount, One);
4395   Builder.CreateStore(OrigUpperBound, PUpperBound);
4396   Builder.CreateStore(One, PStride);
4397 
4398   // Call the "init" function and update the trip count of the loop with the
4399   // value it produced.
4400   uint32_t SrcLocStrSize;
4401   Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
4402   Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4403   Value *ThreadNum = getOrCreateThreadID(SrcLoc);
4404   Builder.CreateCall(StaticInit,
4405                      {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
4406                       /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
4407                       /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
4408                       /*pstride=*/PStride, /*incr=*/One,
4409                       /*chunk=*/CastedChunkSize});
4410 
4411   // Load values written by the "init" function.
4412   Value *FirstChunkStart =
4413       Builder.CreateLoad(InternalIVTy, PLowerBound, "omp_firstchunk.lb");
4414   Value *FirstChunkStop =
4415       Builder.CreateLoad(InternalIVTy, PUpperBound, "omp_firstchunk.ub");
4416   Value *FirstChunkEnd = Builder.CreateAdd(FirstChunkStop, One);
4417   Value *ChunkRange =
4418       Builder.CreateSub(FirstChunkEnd, FirstChunkStart, "omp_chunk.range");
4419   Value *NextChunkStride =
4420       Builder.CreateLoad(InternalIVTy, PStride, "omp_dispatch.stride");
4421 
4422   // Create outer "dispatch" loop for enumerating the chunks.
4423   BasicBlock *DispatchEnter = splitBB(Builder, true);
4424   Value *DispatchCounter;
4425 
4426   // It is safe to assume this didn't return an error because the callback
4427   // passed into createCanonicalLoop is the only possible error source, and it
4428   // always returns success.
4429   CanonicalLoopInfo *DispatchCLI = cantFail(createCanonicalLoop(
4430       {Builder.saveIP(), DL},
4431       [&](InsertPointTy BodyIP, Value *Counter) {
4432         DispatchCounter = Counter;
4433         return Error::success();
4434       },
4435       FirstChunkStart, CastedTripCount, NextChunkStride,
4436       /*IsSigned=*/false, /*InclusiveStop=*/false, /*ComputeIP=*/{},
4437       "dispatch"));
4438 
4439   // Remember the BasicBlocks of the dispatch loop we need, then invalidate to
4440   // not have to preserve the canonical invariant.
4441   BasicBlock *DispatchBody = DispatchCLI->getBody();
4442   BasicBlock *DispatchLatch = DispatchCLI->getLatch();
4443   BasicBlock *DispatchExit = DispatchCLI->getExit();
4444   BasicBlock *DispatchAfter = DispatchCLI->getAfter();
4445   DispatchCLI->invalidate();
4446 
4447   // Rewire the original loop to become the chunk loop inside the dispatch loop.
4448   redirectTo(DispatchAfter, CLI->getAfter(), DL);
4449   redirectTo(CLI->getExit(), DispatchLatch, DL);
4450   redirectTo(DispatchBody, DispatchEnter, DL);
4451 
4452   // Prepare the prolog of the chunk loop.
4453   Builder.restoreIP(CLI->getPreheaderIP());
4454   Builder.SetCurrentDebugLocation(DL);
4455 
4456   // Compute the number of iterations of the chunk loop.
4457   Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
4458   Value *ChunkEnd = Builder.CreateAdd(DispatchCounter, ChunkRange);
4459   Value *IsLastChunk =
4460       Builder.CreateICmpUGE(ChunkEnd, CastedTripCount, "omp_chunk.is_last");
4461   Value *CountUntilOrigTripCount =
4462       Builder.CreateSub(CastedTripCount, DispatchCounter);
4463   Value *ChunkTripCount = Builder.CreateSelect(
4464       IsLastChunk, CountUntilOrigTripCount, ChunkRange, "omp_chunk.tripcount");
4465   Value *BackcastedChunkTC =
4466       Builder.CreateTrunc(ChunkTripCount, IVTy, "omp_chunk.tripcount.trunc");
4467   CLI->setTripCount(BackcastedChunkTC);
4468 
4469   // Update all uses of the induction variable except the one in the condition
4470   // block that compares it with the actual upper bound, and the increment in
4471   // the latch block.
4472   Value *BackcastedDispatchCounter =
4473       Builder.CreateTrunc(DispatchCounter, IVTy, "omp_dispatch.iv.trunc");
4474   CLI->mapIndVar([&](Instruction *) -> Value * {
4475     Builder.restoreIP(CLI->getBodyIP());
4476     return Builder.CreateAdd(IV, BackcastedDispatchCounter);
4477   });
4478 
4479   // In the "exit" block, call the "fini" function.
4480   Builder.SetInsertPoint(DispatchExit, DispatchExit->getFirstInsertionPt());
4481   Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});
4482 
4483   // Add the barrier if requested.
4484   if (NeedsBarrier) {
4485     InsertPointOrErrorTy AfterIP =
4486         createBarrier(LocationDescription(Builder.saveIP(), DL), OMPD_for,
4487                       /*ForceSimpleCall=*/false, /*CheckCancelFlag=*/false);
4488     if (!AfterIP)
4489       return AfterIP.takeError();
4490   }
4491 
4492 #ifndef NDEBUG
4493   // Even though we currently do not support applying additional methods to it,
4494   // the chunk loop should remain a canonical loop.
4495   CLI->assertOK();
4496 #endif
4497 
4498   return InsertPointTy(DispatchAfter, DispatchAfter->getFirstInsertionPt());
4499 }
4500 
4501 // Returns an LLVM function to call for executing an OpenMP static worksharing
4502 // for loop depending on `type`. Only i32 and i64 are supported by the runtime.
4503 // Always interpret integers as unsigned similarly to CanonicalLoopInfo.
4504 static FunctionCallee
getKmpcForStaticLoopForType(Type * Ty,OpenMPIRBuilder * OMPBuilder,WorksharingLoopType LoopType)4505 getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
4506                             WorksharingLoopType LoopType) {
4507   unsigned Bitwidth = Ty->getIntegerBitWidth();
4508   Module &M = OMPBuilder->M;
4509   switch (LoopType) {
4510   case WorksharingLoopType::ForStaticLoop:
4511     if (Bitwidth == 32)
4512       return OMPBuilder->getOrCreateRuntimeFunction(
4513           M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
4514     if (Bitwidth == 64)
4515       return OMPBuilder->getOrCreateRuntimeFunction(
4516           M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
4517     break;
4518   case WorksharingLoopType::DistributeStaticLoop:
4519     if (Bitwidth == 32)
4520       return OMPBuilder->getOrCreateRuntimeFunction(
4521           M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
4522     if (Bitwidth == 64)
4523       return OMPBuilder->getOrCreateRuntimeFunction(
4524           M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
4525     break;
4526   case WorksharingLoopType::DistributeForStaticLoop:
4527     if (Bitwidth == 32)
4528       return OMPBuilder->getOrCreateRuntimeFunction(
4529           M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
4530     if (Bitwidth == 64)
4531       return OMPBuilder->getOrCreateRuntimeFunction(
4532           M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
4533     break;
4534   }
4535   if (Bitwidth != 32 && Bitwidth != 64) {
4536     llvm_unreachable("Unknown OpenMP loop iterator bitwidth");
4537   }
4538   llvm_unreachable("Unknown type of OpenMP worksharing loop");
4539 }
4540 
4541 // Inserts a call to proper OpenMP Device RTL function which handles
4542 // loop worksharing.
createTargetLoopWorkshareCall(OpenMPIRBuilder * OMPBuilder,WorksharingLoopType LoopType,BasicBlock * InsertBlock,Value * Ident,Value * LoopBodyArg,Value * TripCount,Function & LoopBodyFn)4543 static void createTargetLoopWorkshareCall(OpenMPIRBuilder *OMPBuilder,
4544                                           WorksharingLoopType LoopType,
4545                                           BasicBlock *InsertBlock, Value *Ident,
4546                                           Value *LoopBodyArg, Value *TripCount,
4547                                           Function &LoopBodyFn) {
4548   Type *TripCountTy = TripCount->getType();
4549   Module &M = OMPBuilder->M;
4550   IRBuilder<> &Builder = OMPBuilder->Builder;
4551   FunctionCallee RTLFn =
4552       getKmpcForStaticLoopForType(TripCountTy, OMPBuilder, LoopType);
4553   SmallVector<Value *, 8> RealArgs;
4554   RealArgs.push_back(Ident);
4555   RealArgs.push_back(&LoopBodyFn);
4556   RealArgs.push_back(LoopBodyArg);
4557   RealArgs.push_back(TripCount);
4558   if (LoopType == WorksharingLoopType::DistributeStaticLoop) {
4559     RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
4560     Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
4561     Builder.CreateCall(RTLFn, RealArgs);
4562     return;
4563   }
4564   FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
4565       M, omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
4566   Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
4567   Value *NumThreads = Builder.CreateCall(RTLNumThreads, {});
4568 
4569   RealArgs.push_back(
4570       Builder.CreateZExtOrTrunc(NumThreads, TripCountTy, "num.threads.cast"));
4571   RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
4572   if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
4573     RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
4574   }
4575 
4576   Builder.CreateCall(RTLFn, RealArgs);
4577 }
4578 
workshareLoopTargetCallback(OpenMPIRBuilder * OMPIRBuilder,CanonicalLoopInfo * CLI,Value * Ident,Function & OutlinedFn,const SmallVector<Instruction *,4> & ToBeDeleted,WorksharingLoopType LoopType)4579 static void workshareLoopTargetCallback(
4580     OpenMPIRBuilder *OMPIRBuilder, CanonicalLoopInfo *CLI, Value *Ident,
4581     Function &OutlinedFn, const SmallVector<Instruction *, 4> &ToBeDeleted,
4582     WorksharingLoopType LoopType) {
4583   IRBuilder<> &Builder = OMPIRBuilder->Builder;
4584   BasicBlock *Preheader = CLI->getPreheader();
4585   Value *TripCount = CLI->getTripCount();
4586 
4587   // After loop body outling, the loop body contains only set up
4588   // of loop body argument structure and the call to the outlined
4589   // loop body function. Firstly, we need to move setup of loop body args
4590   // into loop preheader.
4591   Preheader->splice(std::prev(Preheader->end()), CLI->getBody(),
4592                     CLI->getBody()->begin(), std::prev(CLI->getBody()->end()));
4593 
4594   // The next step is to remove the whole loop. We do not it need anymore.
4595   // That's why make an unconditional branch from loop preheader to loop
4596   // exit block
4597   Builder.restoreIP({Preheader, Preheader->end()});
4598   Builder.SetCurrentDebugLocation(Preheader->getTerminator()->getDebugLoc());
4599   Preheader->getTerminator()->eraseFromParent();
4600   Builder.CreateBr(CLI->getExit());
4601 
4602   // Delete dead loop blocks
4603   OpenMPIRBuilder::OutlineInfo CleanUpInfo;
4604   SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
4605   SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
4606   CleanUpInfo.EntryBB = CLI->getHeader();
4607   CleanUpInfo.ExitBB = CLI->getExit();
4608   CleanUpInfo.collectBlocks(RegionBlockSet, BlocksToBeRemoved);
4609   DeleteDeadBlocks(BlocksToBeRemoved);
4610 
4611   // Find the instruction which corresponds to loop body argument structure
4612   // and remove the call to loop body function instruction.
4613   Value *LoopBodyArg;
4614   User *OutlinedFnUser = OutlinedFn.getUniqueUndroppableUser();
4615   assert(OutlinedFnUser &&
4616          "Expected unique undroppable user of outlined function");
4617   CallInst *OutlinedFnCallInstruction = dyn_cast<CallInst>(OutlinedFnUser);
4618   assert(OutlinedFnCallInstruction && "Expected outlined function call");
4619   assert((OutlinedFnCallInstruction->getParent() == Preheader) &&
4620          "Expected outlined function call to be located in loop preheader");
4621   // Check in case no argument structure has been passed.
4622   if (OutlinedFnCallInstruction->arg_size() > 1)
4623     LoopBodyArg = OutlinedFnCallInstruction->getArgOperand(1);
4624   else
4625     LoopBodyArg = Constant::getNullValue(Builder.getPtrTy());
4626   OutlinedFnCallInstruction->eraseFromParent();
4627 
4628   createTargetLoopWorkshareCall(OMPIRBuilder, LoopType, Preheader, Ident,
4629                                 LoopBodyArg, TripCount, OutlinedFn);
4630 
4631   for (auto &ToBeDeletedItem : ToBeDeleted)
4632     ToBeDeletedItem->eraseFromParent();
4633   CLI->invalidate();
4634 }
4635 
4636 OpenMPIRBuilder::InsertPointTy
applyWorkshareLoopTarget(DebugLoc DL,CanonicalLoopInfo * CLI,InsertPointTy AllocaIP,WorksharingLoopType LoopType)4637 OpenMPIRBuilder::applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI,
4638                                           InsertPointTy AllocaIP,
4639                                           WorksharingLoopType LoopType) {
4640   uint32_t SrcLocStrSize;
4641   Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
4642   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4643 
4644   OutlineInfo OI;
4645   OI.OuterAllocaBB = CLI->getPreheader();
4646   Function *OuterFn = CLI->getPreheader()->getParent();
4647 
4648   // Instructions which need to be deleted at the end of code generation
4649   SmallVector<Instruction *, 4> ToBeDeleted;
4650 
4651   OI.OuterAllocaBB = AllocaIP.getBlock();
4652 
4653   // Mark the body loop as region which needs to be extracted
4654   OI.EntryBB = CLI->getBody();
4655   OI.ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(),
4656                                                "omp.prelatch", true);
4657 
4658   // Prepare loop body for extraction
4659   Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()});
4660 
4661   // Insert new loop counter variable which will be used only in loop
4662   // body.
4663   AllocaInst *NewLoopCnt = Builder.CreateAlloca(CLI->getIndVarType(), 0, "");
4664   Instruction *NewLoopCntLoad =
4665       Builder.CreateLoad(CLI->getIndVarType(), NewLoopCnt);
4666   // New loop counter instructions are redundant in the loop preheader when
4667   // code generation for workshare loop is finshed. That's why mark them as
4668   // ready for deletion.
4669   ToBeDeleted.push_back(NewLoopCntLoad);
4670   ToBeDeleted.push_back(NewLoopCnt);
4671 
4672   // Analyse loop body region. Find all input variables which are used inside
4673   // loop body region.
4674   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
4675   SmallVector<BasicBlock *, 32> Blocks;
4676   OI.collectBlocks(ParallelRegionBlockSet, Blocks);
4677 
4678   CodeExtractorAnalysisCache CEAC(*OuterFn);
4679   CodeExtractor Extractor(Blocks,
4680                           /* DominatorTree */ nullptr,
4681                           /* AggregateArgs */ true,
4682                           /* BlockFrequencyInfo */ nullptr,
4683                           /* BranchProbabilityInfo */ nullptr,
4684                           /* AssumptionCache */ nullptr,
4685                           /* AllowVarArgs */ true,
4686                           /* AllowAlloca */ true,
4687                           /* AllocationBlock */ CLI->getPreheader(),
4688                           /* Suffix */ ".omp_wsloop",
4689                           /* AggrArgsIn0AddrSpace */ true);
4690 
4691   BasicBlock *CommonExit = nullptr;
4692   SetVector<Value *> SinkingCands, HoistingCands;
4693 
4694   // Find allocas outside the loop body region which are used inside loop
4695   // body
4696   Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
4697 
4698   // We need to model loop body region as the function f(cnt, loop_arg).
4699   // That's why we replace loop induction variable by the new counter
4700   // which will be one of loop body function argument
4701   SmallVector<User *> Users(CLI->getIndVar()->user_begin(),
4702                             CLI->getIndVar()->user_end());
4703   for (auto Use : Users) {
4704     if (Instruction *Inst = dyn_cast<Instruction>(Use)) {
4705       if (ParallelRegionBlockSet.count(Inst->getParent())) {
4706         Inst->replaceUsesOfWith(CLI->getIndVar(), NewLoopCntLoad);
4707       }
4708     }
4709   }
4710   // Make sure that loop counter variable is not merged into loop body
4711   // function argument structure and it is passed as separate variable
4712   OI.ExcludeArgsFromAggregate.push_back(NewLoopCntLoad);
4713 
4714   // PostOutline CB is invoked when loop body function is outlined and
4715   // loop body is replaced by call to outlined function. We need to add
4716   // call to OpenMP device rtl inside loop preheader. OpenMP device rtl
4717   // function will handle loop control logic.
4718   //
4719   OI.PostOutlineCB = [=, ToBeDeletedVec =
4720                              std::move(ToBeDeleted)](Function &OutlinedFn) {
4721     workshareLoopTargetCallback(this, CLI, Ident, OutlinedFn, ToBeDeletedVec,
4722                                 LoopType);
4723   };
4724   addOutlineInfo(std::move(OI));
4725   return CLI->getAfterIP();
4726 }
4727 
applyWorkshareLoop(DebugLoc DL,CanonicalLoopInfo * CLI,InsertPointTy AllocaIP,bool NeedsBarrier,omp::ScheduleKind SchedKind,Value * ChunkSize,bool HasSimdModifier,bool HasMonotonicModifier,bool HasNonmonotonicModifier,bool HasOrderedClause,WorksharingLoopType LoopType)4728 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyWorkshareLoop(
4729     DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
4730     bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
4731     bool HasSimdModifier, bool HasMonotonicModifier,
4732     bool HasNonmonotonicModifier, bool HasOrderedClause,
4733     WorksharingLoopType LoopType) {
4734   if (Config.isTargetDevice())
4735     return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType);
4736   OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
4737       SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
4738       HasNonmonotonicModifier, HasOrderedClause);
4739 
4740   bool IsOrdered = (EffectiveScheduleType & OMPScheduleType::ModifierOrdered) ==
4741                    OMPScheduleType::ModifierOrdered;
4742   switch (EffectiveScheduleType & ~OMPScheduleType::ModifierMask) {
4743   case OMPScheduleType::BaseStatic:
4744     assert(!ChunkSize && "No chunk size with static-chunked schedule");
4745     if (IsOrdered)
4746       return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
4747                                        NeedsBarrier, ChunkSize);
4748     // FIXME: Monotonicity ignored?
4749     return applyStaticWorkshareLoop(DL, CLI, AllocaIP, LoopType, NeedsBarrier);
4750 
4751   case OMPScheduleType::BaseStaticChunked:
4752     if (IsOrdered)
4753       return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
4754                                        NeedsBarrier, ChunkSize);
4755     // FIXME: Monotonicity ignored?
4756     return applyStaticChunkedWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier,
4757                                            ChunkSize);
4758 
4759   case OMPScheduleType::BaseRuntime:
4760   case OMPScheduleType::BaseAuto:
4761   case OMPScheduleType::BaseGreedy:
4762   case OMPScheduleType::BaseBalanced:
4763   case OMPScheduleType::BaseSteal:
4764   case OMPScheduleType::BaseGuidedSimd:
4765   case OMPScheduleType::BaseRuntimeSimd:
4766     assert(!ChunkSize &&
4767            "schedule type does not support user-defined chunk sizes");
4768     [[fallthrough]];
4769   case OMPScheduleType::BaseDynamicChunked:
4770   case OMPScheduleType::BaseGuidedChunked:
4771   case OMPScheduleType::BaseGuidedIterativeChunked:
4772   case OMPScheduleType::BaseGuidedAnalyticalChunked:
4773   case OMPScheduleType::BaseStaticBalancedChunked:
4774     return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
4775                                      NeedsBarrier, ChunkSize);
4776 
4777   default:
4778     llvm_unreachable("Unknown/unimplemented schedule kind");
4779   }
4780 }
4781 
4782 /// Returns an LLVM function to call for initializing loop bounds using OpenMP
4783 /// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
4784 /// the runtime. Always interpret integers as unsigned similarly to
4785 /// CanonicalLoopInfo.
4786 static FunctionCallee
getKmpcForDynamicInitForType(Type * Ty,Module & M,OpenMPIRBuilder & OMPBuilder)4787 getKmpcForDynamicInitForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
4788   unsigned Bitwidth = Ty->getIntegerBitWidth();
4789   if (Bitwidth == 32)
4790     return OMPBuilder.getOrCreateRuntimeFunction(
4791         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_4u);
4792   if (Bitwidth == 64)
4793     return OMPBuilder.getOrCreateRuntimeFunction(
4794         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_8u);
4795   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
4796 }
4797 
4798 /// Returns an LLVM function to call for updating the next loop using OpenMP
4799 /// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
4800 /// the runtime. Always interpret integers as unsigned similarly to
4801 /// CanonicalLoopInfo.
4802 static FunctionCallee
getKmpcForDynamicNextForType(Type * Ty,Module & M,OpenMPIRBuilder & OMPBuilder)4803 getKmpcForDynamicNextForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
4804   unsigned Bitwidth = Ty->getIntegerBitWidth();
4805   if (Bitwidth == 32)
4806     return OMPBuilder.getOrCreateRuntimeFunction(
4807         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_4u);
4808   if (Bitwidth == 64)
4809     return OMPBuilder.getOrCreateRuntimeFunction(
4810         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_8u);
4811   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
4812 }
4813 
4814 /// Returns an LLVM function to call for finalizing the dynamic loop using
4815 /// depending on `type`. Only i32 and i64 are supported by the runtime. Always
4816 /// interpret integers as unsigned similarly to CanonicalLoopInfo.
4817 static FunctionCallee
getKmpcForDynamicFiniForType(Type * Ty,Module & M,OpenMPIRBuilder & OMPBuilder)4818 getKmpcForDynamicFiniForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
4819   unsigned Bitwidth = Ty->getIntegerBitWidth();
4820   if (Bitwidth == 32)
4821     return OMPBuilder.getOrCreateRuntimeFunction(
4822         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_4u);
4823   if (Bitwidth == 64)
4824     return OMPBuilder.getOrCreateRuntimeFunction(
4825         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_8u);
4826   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
4827 }
4828 
4829 OpenMPIRBuilder::InsertPointOrErrorTy
applyDynamicWorkshareLoop(DebugLoc DL,CanonicalLoopInfo * CLI,InsertPointTy AllocaIP,OMPScheduleType SchedType,bool NeedsBarrier,Value * Chunk)4830 OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
4831                                            InsertPointTy AllocaIP,
4832                                            OMPScheduleType SchedType,
4833                                            bool NeedsBarrier, Value *Chunk) {
4834   assert(CLI->isValid() && "Requires a valid canonical loop");
4835   assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
4836          "Require dedicated allocate IP");
4837   assert(isValidWorkshareLoopScheduleType(SchedType) &&
4838          "Require valid schedule type");
4839 
4840   bool Ordered = (SchedType & OMPScheduleType::ModifierOrdered) ==
4841                  OMPScheduleType::ModifierOrdered;
4842 
4843   // Set up the source location value for OpenMP runtime.
4844   Builder.SetCurrentDebugLocation(DL);
4845 
4846   uint32_t SrcLocStrSize;
4847   Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
4848   Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4849 
4850   // Declare useful OpenMP runtime functions.
4851   Value *IV = CLI->getIndVar();
4852   Type *IVTy = IV->getType();
4853   FunctionCallee DynamicInit = getKmpcForDynamicInitForType(IVTy, M, *this);
4854   FunctionCallee DynamicNext = getKmpcForDynamicNextForType(IVTy, M, *this);
4855 
4856   // Allocate space for computed loop bounds as expected by the "init" function.
4857   Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());
4858   Type *I32Type = Type::getInt32Ty(M.getContext());
4859   Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
4860   Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
4861   Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
4862   Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
4863   CLI->setLastIter(PLastIter);
4864 
4865   // At the end of the preheader, prepare for calling the "init" function by
4866   // storing the current loop bounds into the allocated space. A canonical loop
4867   // always iterates from 0 to trip-count with step 1. Note that "init" expects
4868   // and produces an inclusive upper bound.
4869   BasicBlock *PreHeader = CLI->getPreheader();
4870   Builder.SetInsertPoint(PreHeader->getTerminator());
4871   Constant *One = ConstantInt::get(IVTy, 1);
4872   Builder.CreateStore(One, PLowerBound);
4873   Value *UpperBound = CLI->getTripCount();
4874   Builder.CreateStore(UpperBound, PUpperBound);
4875   Builder.CreateStore(One, PStride);
4876 
4877   BasicBlock *Header = CLI->getHeader();
4878   BasicBlock *Exit = CLI->getExit();
4879   BasicBlock *Cond = CLI->getCond();
4880   BasicBlock *Latch = CLI->getLatch();
4881   InsertPointTy AfterIP = CLI->getAfterIP();
4882 
4883   // The CLI will be "broken" in the code below, as the loop is no longer
4884   // a valid canonical loop.
4885 
4886   if (!Chunk)
4887     Chunk = One;
4888 
4889   Value *ThreadNum = getOrCreateThreadID(SrcLoc);
4890 
4891   Constant *SchedulingType =
4892       ConstantInt::get(I32Type, static_cast<int>(SchedType));
4893 
4894   // Call the "init" function.
4895   Builder.CreateCall(DynamicInit,
4896                      {SrcLoc, ThreadNum, SchedulingType, /* LowerBound */ One,
4897                       UpperBound, /* step */ One, Chunk});
4898 
4899   // An outer loop around the existing one.
4900   BasicBlock *OuterCond = BasicBlock::Create(
4901       PreHeader->getContext(), Twine(PreHeader->getName()) + ".outer.cond",
4902       PreHeader->getParent());
4903   // This needs to be 32-bit always, so can't use the IVTy Zero above.
4904   Builder.SetInsertPoint(OuterCond, OuterCond->getFirstInsertionPt());
4905   Value *Res =
4906       Builder.CreateCall(DynamicNext, {SrcLoc, ThreadNum, PLastIter,
4907                                        PLowerBound, PUpperBound, PStride});
4908   Constant *Zero32 = ConstantInt::get(I32Type, 0);
4909   Value *MoreWork = Builder.CreateCmp(CmpInst::ICMP_NE, Res, Zero32);
4910   Value *LowerBound =
4911       Builder.CreateSub(Builder.CreateLoad(IVTy, PLowerBound), One, "lb");
4912   Builder.CreateCondBr(MoreWork, Header, Exit);
4913 
4914   // Change PHI-node in loop header to use outer cond rather than preheader,
4915   // and set IV to the LowerBound.
4916   Instruction *Phi = &Header->front();
4917   auto *PI = cast<PHINode>(Phi);
4918   PI->setIncomingBlock(0, OuterCond);
4919   PI->setIncomingValue(0, LowerBound);
4920 
4921   // Then set the pre-header to jump to the OuterCond
4922   Instruction *Term = PreHeader->getTerminator();
4923   auto *Br = cast<BranchInst>(Term);
4924   Br->setSuccessor(0, OuterCond);
4925 
4926   // Modify the inner condition:
4927   // * Use the UpperBound returned from the DynamicNext call.
4928   // * jump to the loop outer loop when done with one of the inner loops.
4929   Builder.SetInsertPoint(Cond, Cond->getFirstInsertionPt());
4930   UpperBound = Builder.CreateLoad(IVTy, PUpperBound, "ub");
4931   Instruction *Comp = &*Builder.GetInsertPoint();
4932   auto *CI = cast<CmpInst>(Comp);
4933   CI->setOperand(1, UpperBound);
4934   // Redirect the inner exit to branch to outer condition.
4935   Instruction *Branch = &Cond->back();
4936   auto *BI = cast<BranchInst>(Branch);
4937   assert(BI->getSuccessor(1) == Exit);
4938   BI->setSuccessor(1, OuterCond);
4939 
4940   // Call the "fini" function if "ordered" is present in wsloop directive.
4941   if (Ordered) {
4942     Builder.SetInsertPoint(&Latch->back());
4943     FunctionCallee DynamicFini = getKmpcForDynamicFiniForType(IVTy, M, *this);
4944     Builder.CreateCall(DynamicFini, {SrcLoc, ThreadNum});
4945   }
4946 
4947   // Add the barrier if requested.
4948   if (NeedsBarrier) {
4949     Builder.SetInsertPoint(&Exit->back());
4950     InsertPointOrErrorTy BarrierIP =
4951         createBarrier(LocationDescription(Builder.saveIP(), DL),
4952                       omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
4953                       /* CheckCancelFlag */ false);
4954     if (!BarrierIP)
4955       return BarrierIP.takeError();
4956   }
4957 
4958   CLI->invalidate();
4959   return AfterIP;
4960 }
4961 
4962 /// Redirect all edges that branch to \p OldTarget to \p NewTarget. That is,
4963 /// after this \p OldTarget will be orphaned.
redirectAllPredecessorsTo(BasicBlock * OldTarget,BasicBlock * NewTarget,DebugLoc DL)4964 static void redirectAllPredecessorsTo(BasicBlock *OldTarget,
4965                                       BasicBlock *NewTarget, DebugLoc DL) {
4966   for (BasicBlock *Pred : make_early_inc_range(predecessors(OldTarget)))
4967     redirectTo(Pred, NewTarget, DL);
4968 }
4969 
4970 /// Determine which blocks in \p BBs are reachable from outside and remove the
4971 /// ones that are not reachable from the function.
removeUnusedBlocksFromParent(ArrayRef<BasicBlock * > BBs)4972 static void removeUnusedBlocksFromParent(ArrayRef<BasicBlock *> BBs) {
4973   SmallPtrSet<BasicBlock *, 6> BBsToErase(llvm::from_range, BBs);
4974   auto HasRemainingUses = [&BBsToErase](BasicBlock *BB) {
4975     for (Use &U : BB->uses()) {
4976       auto *UseInst = dyn_cast<Instruction>(U.getUser());
4977       if (!UseInst)
4978         continue;
4979       if (BBsToErase.count(UseInst->getParent()))
4980         continue;
4981       return true;
4982     }
4983     return false;
4984   };
4985 
4986   while (BBsToErase.remove_if(HasRemainingUses)) {
4987     // Try again if anything was removed.
4988   }
4989 
4990   SmallVector<BasicBlock *, 7> BBVec(BBsToErase.begin(), BBsToErase.end());
4991   DeleteDeadBlocks(BBVec);
4992 }
4993 
4994 CanonicalLoopInfo *
collapseLoops(DebugLoc DL,ArrayRef<CanonicalLoopInfo * > Loops,InsertPointTy ComputeIP)4995 OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
4996                                InsertPointTy ComputeIP) {
4997   assert(Loops.size() >= 1 && "At least one loop required");
4998   size_t NumLoops = Loops.size();
4999 
5000   // Nothing to do if there is already just one loop.
5001   if (NumLoops == 1)
5002     return Loops.front();
5003 
5004   CanonicalLoopInfo *Outermost = Loops.front();
5005   CanonicalLoopInfo *Innermost = Loops.back();
5006   BasicBlock *OrigPreheader = Outermost->getPreheader();
5007   BasicBlock *OrigAfter = Outermost->getAfter();
5008   Function *F = OrigPreheader->getParent();
5009 
5010   // Loop control blocks that may become orphaned later.
5011   SmallVector<BasicBlock *, 12> OldControlBBs;
5012   OldControlBBs.reserve(6 * Loops.size());
5013   for (CanonicalLoopInfo *Loop : Loops)
5014     Loop->collectControlBlocks(OldControlBBs);
5015 
5016   // Setup the IRBuilder for inserting the trip count computation.
5017   Builder.SetCurrentDebugLocation(DL);
5018   if (ComputeIP.isSet())
5019     Builder.restoreIP(ComputeIP);
5020   else
5021     Builder.restoreIP(Outermost->getPreheaderIP());
5022 
5023   // Derive the collapsed' loop trip count.
5024   // TODO: Find common/largest indvar type.
5025   Value *CollapsedTripCount = nullptr;
5026   for (CanonicalLoopInfo *L : Loops) {
5027     assert(L->isValid() &&
5028            "All loops to collapse must be valid canonical loops");
5029     Value *OrigTripCount = L->getTripCount();
5030     if (!CollapsedTripCount) {
5031       CollapsedTripCount = OrigTripCount;
5032       continue;
5033     }
5034 
5035     // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
5036     CollapsedTripCount = Builder.CreateMul(CollapsedTripCount, OrigTripCount,
5037                                            {}, /*HasNUW=*/true);
5038   }
5039 
5040   // Create the collapsed loop control flow.
5041   CanonicalLoopInfo *Result =
5042       createLoopSkeleton(DL, CollapsedTripCount, F,
5043                          OrigPreheader->getNextNode(), OrigAfter, "collapsed");
5044 
5045   // Build the collapsed loop body code.
5046   // Start with deriving the input loop induction variables from the collapsed
5047   // one, using a divmod scheme. To preserve the original loops' order, the
5048   // innermost loop use the least significant bits.
5049   Builder.restoreIP(Result->getBodyIP());
5050 
5051   Value *Leftover = Result->getIndVar();
5052   SmallVector<Value *> NewIndVars;
5053   NewIndVars.resize(NumLoops);
5054   for (int i = NumLoops - 1; i >= 1; --i) {
5055     Value *OrigTripCount = Loops[i]->getTripCount();
5056 
5057     Value *NewIndVar = Builder.CreateURem(Leftover, OrigTripCount);
5058     NewIndVars[i] = NewIndVar;
5059 
5060     Leftover = Builder.CreateUDiv(Leftover, OrigTripCount);
5061   }
5062   // Outermost loop gets all the remaining bits.
5063   NewIndVars[0] = Leftover;
5064 
5065   // Construct the loop body control flow.
5066   // We progressively construct the branch structure following in direction of
5067   // the control flow, from the leading in-between code, the loop nest body, the
5068   // trailing in-between code, and rejoining the collapsed loop's latch.
5069   // ContinueBlock and ContinuePred keep track of the source(s) of next edge. If
5070   // the ContinueBlock is set, continue with that block. If ContinuePred, use
5071   // its predecessors as sources.
5072   BasicBlock *ContinueBlock = Result->getBody();
5073   BasicBlock *ContinuePred = nullptr;
5074   auto ContinueWith = [&ContinueBlock, &ContinuePred, DL](BasicBlock *Dest,
5075                                                           BasicBlock *NextSrc) {
5076     if (ContinueBlock)
5077       redirectTo(ContinueBlock, Dest, DL);
5078     else
5079       redirectAllPredecessorsTo(ContinuePred, Dest, DL);
5080 
5081     ContinueBlock = nullptr;
5082     ContinuePred = NextSrc;
5083   };
5084 
5085   // The code before the nested loop of each level.
5086   // Because we are sinking it into the nest, it will be executed more often
5087   // that the original loop. More sophisticated schemes could keep track of what
5088   // the in-between code is and instantiate it only once per thread.
5089   for (size_t i = 0; i < NumLoops - 1; ++i)
5090     ContinueWith(Loops[i]->getBody(), Loops[i + 1]->getHeader());
5091 
5092   // Connect the loop nest body.
5093   ContinueWith(Innermost->getBody(), Innermost->getLatch());
5094 
5095   // The code after the nested loop at each level.
5096   for (size_t i = NumLoops - 1; i > 0; --i)
5097     ContinueWith(Loops[i]->getAfter(), Loops[i - 1]->getLatch());
5098 
5099   // Connect the finished loop to the collapsed loop latch.
5100   ContinueWith(Result->getLatch(), nullptr);
5101 
5102   // Replace the input loops with the new collapsed loop.
5103   redirectTo(Outermost->getPreheader(), Result->getPreheader(), DL);
5104   redirectTo(Result->getAfter(), Outermost->getAfter(), DL);
5105 
5106   // Replace the input loop indvars with the derived ones.
5107   for (size_t i = 0; i < NumLoops; ++i)
5108     Loops[i]->getIndVar()->replaceAllUsesWith(NewIndVars[i]);
5109 
5110   // Remove unused parts of the input loops.
5111   removeUnusedBlocksFromParent(OldControlBBs);
5112 
5113   for (CanonicalLoopInfo *L : Loops)
5114     L->invalidate();
5115 
5116 #ifndef NDEBUG
5117   Result->assertOK();
5118 #endif
5119   return Result;
5120 }
5121 
5122 std::vector<CanonicalLoopInfo *>
tileLoops(DebugLoc DL,ArrayRef<CanonicalLoopInfo * > Loops,ArrayRef<Value * > TileSizes)5123 OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
5124                            ArrayRef<Value *> TileSizes) {
5125   assert(TileSizes.size() == Loops.size() &&
5126          "Must pass as many tile sizes as there are loops");
5127   int NumLoops = Loops.size();
5128   assert(NumLoops >= 1 && "At least one loop to tile required");
5129 
5130   CanonicalLoopInfo *OutermostLoop = Loops.front();
5131   CanonicalLoopInfo *InnermostLoop = Loops.back();
5132   Function *F = OutermostLoop->getBody()->getParent();
5133   BasicBlock *InnerEnter = InnermostLoop->getBody();
5134   BasicBlock *InnerLatch = InnermostLoop->getLatch();
5135 
5136   // Loop control blocks that may become orphaned later.
5137   SmallVector<BasicBlock *, 12> OldControlBBs;
5138   OldControlBBs.reserve(6 * Loops.size());
5139   for (CanonicalLoopInfo *Loop : Loops)
5140     Loop->collectControlBlocks(OldControlBBs);
5141 
5142   // Collect original trip counts and induction variable to be accessible by
5143   // index. Also, the structure of the original loops is not preserved during
5144   // the construction of the tiled loops, so do it before we scavenge the BBs of
5145   // any original CanonicalLoopInfo.
5146   SmallVector<Value *, 4> OrigTripCounts, OrigIndVars;
5147   for (CanonicalLoopInfo *L : Loops) {
5148     assert(L->isValid() && "All input loops must be valid canonical loops");
5149     OrigTripCounts.push_back(L->getTripCount());
5150     OrigIndVars.push_back(L->getIndVar());
5151   }
5152 
5153   // Collect the code between loop headers. These may contain SSA definitions
5154   // that are used in the loop nest body. To be usable with in the innermost
5155   // body, these BasicBlocks will be sunk into the loop nest body. That is,
5156   // these instructions may be executed more often than before the tiling.
5157   // TODO: It would be sufficient to only sink them into body of the
5158   // corresponding tile loop.
5159   SmallVector<std::pair<BasicBlock *, BasicBlock *>, 4> InbetweenCode;
5160   for (int i = 0; i < NumLoops - 1; ++i) {
5161     CanonicalLoopInfo *Surrounding = Loops[i];
5162     CanonicalLoopInfo *Nested = Loops[i + 1];
5163 
5164     BasicBlock *EnterBB = Surrounding->getBody();
5165     BasicBlock *ExitBB = Nested->getHeader();
5166     InbetweenCode.emplace_back(EnterBB, ExitBB);
5167   }
5168 
5169   // Compute the trip counts of the floor loops.
5170   Builder.SetCurrentDebugLocation(DL);
5171   Builder.restoreIP(OutermostLoop->getPreheaderIP());
5172   SmallVector<Value *, 4> FloorCount, FloorRems;
5173   for (int i = 0; i < NumLoops; ++i) {
5174     Value *TileSize = TileSizes[i];
5175     Value *OrigTripCount = OrigTripCounts[i];
5176     Type *IVType = OrigTripCount->getType();
5177 
5178     Value *FloorTripCount = Builder.CreateUDiv(OrigTripCount, TileSize);
5179     Value *FloorTripRem = Builder.CreateURem(OrigTripCount, TileSize);
5180 
5181     // 0 if tripcount divides the tilesize, 1 otherwise.
5182     // 1 means we need an additional iteration for a partial tile.
5183     //
5184     // Unfortunately we cannot just use the roundup-formula
5185     //   (tripcount + tilesize - 1)/tilesize
5186     // because the summation might overflow. We do not want introduce undefined
5187     // behavior when the untiled loop nest did not.
5188     Value *FloorTripOverflow =
5189         Builder.CreateICmpNE(FloorTripRem, ConstantInt::get(IVType, 0));
5190 
5191     FloorTripOverflow = Builder.CreateZExt(FloorTripOverflow, IVType);
5192     FloorTripCount =
5193         Builder.CreateAdd(FloorTripCount, FloorTripOverflow,
5194                           "omp_floor" + Twine(i) + ".tripcount", true);
5195 
5196     // Remember some values for later use.
5197     FloorCount.push_back(FloorTripCount);
5198     FloorRems.push_back(FloorTripRem);
5199   }
5200 
5201   // Generate the new loop nest, from the outermost to the innermost.
5202   std::vector<CanonicalLoopInfo *> Result;
5203   Result.reserve(NumLoops * 2);
5204 
5205   // The basic block of the surrounding loop that enters the nest generated
5206   // loop.
5207   BasicBlock *Enter = OutermostLoop->getPreheader();
5208 
5209   // The basic block of the surrounding loop where the inner code should
5210   // continue.
5211   BasicBlock *Continue = OutermostLoop->getAfter();
5212 
5213   // Where the next loop basic block should be inserted.
5214   BasicBlock *OutroInsertBefore = InnermostLoop->getExit();
5215 
5216   auto EmbeddNewLoop =
5217       [this, DL, F, InnerEnter, &Enter, &Continue, &OutroInsertBefore](
5218           Value *TripCount, const Twine &Name) -> CanonicalLoopInfo * {
5219     CanonicalLoopInfo *EmbeddedLoop = createLoopSkeleton(
5220         DL, TripCount, F, InnerEnter, OutroInsertBefore, Name);
5221     redirectTo(Enter, EmbeddedLoop->getPreheader(), DL);
5222     redirectTo(EmbeddedLoop->getAfter(), Continue, DL);
5223 
5224     // Setup the position where the next embedded loop connects to this loop.
5225     Enter = EmbeddedLoop->getBody();
5226     Continue = EmbeddedLoop->getLatch();
5227     OutroInsertBefore = EmbeddedLoop->getLatch();
5228     return EmbeddedLoop;
5229   };
5230 
5231   auto EmbeddNewLoops = [&Result, &EmbeddNewLoop](ArrayRef<Value *> TripCounts,
5232                                                   const Twine &NameBase) {
5233     for (auto P : enumerate(TripCounts)) {
5234       CanonicalLoopInfo *EmbeddedLoop =
5235           EmbeddNewLoop(P.value(), NameBase + Twine(P.index()));
5236       Result.push_back(EmbeddedLoop);
5237     }
5238   };
5239 
5240   EmbeddNewLoops(FloorCount, "floor");
5241 
5242   // Within the innermost floor loop, emit the code that computes the tile
5243   // sizes.
5244   Builder.SetInsertPoint(Enter->getTerminator());
5245   SmallVector<Value *, 4> TileCounts;
5246   for (int i = 0; i < NumLoops; ++i) {
5247     CanonicalLoopInfo *FloorLoop = Result[i];
5248     Value *TileSize = TileSizes[i];
5249 
5250     Value *FloorIsEpilogue =
5251         Builder.CreateICmpEQ(FloorLoop->getIndVar(), FloorCount[i]);
5252     Value *TileTripCount =
5253         Builder.CreateSelect(FloorIsEpilogue, FloorRems[i], TileSize);
5254 
5255     TileCounts.push_back(TileTripCount);
5256   }
5257 
5258   // Create the tile loops.
5259   EmbeddNewLoops(TileCounts, "tile");
5260 
5261   // Insert the inbetween code into the body.
5262   BasicBlock *BodyEnter = Enter;
5263   BasicBlock *BodyEntered = nullptr;
5264   for (std::pair<BasicBlock *, BasicBlock *> P : InbetweenCode) {
5265     BasicBlock *EnterBB = P.first;
5266     BasicBlock *ExitBB = P.second;
5267 
5268     if (BodyEnter)
5269       redirectTo(BodyEnter, EnterBB, DL);
5270     else
5271       redirectAllPredecessorsTo(BodyEntered, EnterBB, DL);
5272 
5273     BodyEnter = nullptr;
5274     BodyEntered = ExitBB;
5275   }
5276 
5277   // Append the original loop nest body into the generated loop nest body.
5278   if (BodyEnter)
5279     redirectTo(BodyEnter, InnerEnter, DL);
5280   else
5281     redirectAllPredecessorsTo(BodyEntered, InnerEnter, DL);
5282   redirectAllPredecessorsTo(InnerLatch, Continue, DL);
5283 
5284   // Replace the original induction variable with an induction variable computed
5285   // from the tile and floor induction variables.
5286   Builder.restoreIP(Result.back()->getBodyIP());
5287   for (int i = 0; i < NumLoops; ++i) {
5288     CanonicalLoopInfo *FloorLoop = Result[i];
5289     CanonicalLoopInfo *TileLoop = Result[NumLoops + i];
5290     Value *OrigIndVar = OrigIndVars[i];
5291     Value *Size = TileSizes[i];
5292 
5293     Value *Scale =
5294         Builder.CreateMul(Size, FloorLoop->getIndVar(), {}, /*HasNUW=*/true);
5295     Value *Shift =
5296         Builder.CreateAdd(Scale, TileLoop->getIndVar(), {}, /*HasNUW=*/true);
5297     OrigIndVar->replaceAllUsesWith(Shift);
5298   }
5299 
5300   // Remove unused parts of the original loops.
5301   removeUnusedBlocksFromParent(OldControlBBs);
5302 
5303   for (CanonicalLoopInfo *L : Loops)
5304     L->invalidate();
5305 
5306 #ifndef NDEBUG
5307   for (CanonicalLoopInfo *GenL : Result)
5308     GenL->assertOK();
5309 #endif
5310   return Result;
5311 }
5312 
5313 /// Attach metadata \p Properties to the basic block described by \p BB. If the
5314 /// basic block already has metadata, the basic block properties are appended.
addBasicBlockMetadata(BasicBlock * BB,ArrayRef<Metadata * > Properties)5315 static void addBasicBlockMetadata(BasicBlock *BB,
5316                                   ArrayRef<Metadata *> Properties) {
5317   // Nothing to do if no property to attach.
5318   if (Properties.empty())
5319     return;
5320 
5321   LLVMContext &Ctx = BB->getContext();
5322   SmallVector<Metadata *> NewProperties;
5323   NewProperties.push_back(nullptr);
5324 
5325   // If the basic block already has metadata, prepend it to the new metadata.
5326   MDNode *Existing = BB->getTerminator()->getMetadata(LLVMContext::MD_loop);
5327   if (Existing)
5328     append_range(NewProperties, drop_begin(Existing->operands(), 1));
5329 
5330   append_range(NewProperties, Properties);
5331   MDNode *BasicBlockID = MDNode::getDistinct(Ctx, NewProperties);
5332   BasicBlockID->replaceOperandWith(0, BasicBlockID);
5333 
5334   BB->getTerminator()->setMetadata(LLVMContext::MD_loop, BasicBlockID);
5335 }
5336 
5337 /// Attach loop metadata \p Properties to the loop described by \p Loop. If the
5338 /// loop already has metadata, the loop properties are appended.
addLoopMetadata(CanonicalLoopInfo * Loop,ArrayRef<Metadata * > Properties)5339 static void addLoopMetadata(CanonicalLoopInfo *Loop,
5340                             ArrayRef<Metadata *> Properties) {
5341   assert(Loop->isValid() && "Expecting a valid CanonicalLoopInfo");
5342 
5343   // Attach metadata to the loop's latch
5344   BasicBlock *Latch = Loop->getLatch();
5345   assert(Latch && "A valid CanonicalLoopInfo must have a unique latch");
5346   addBasicBlockMetadata(Latch, Properties);
5347 }
5348 
5349 /// Attach llvm.access.group metadata to the memref instructions of \p Block
addSimdMetadata(BasicBlock * Block,MDNode * AccessGroup,LoopInfo & LI)5350 static void addSimdMetadata(BasicBlock *Block, MDNode *AccessGroup,
5351                             LoopInfo &LI) {
5352   for (Instruction &I : *Block) {
5353     if (I.mayReadOrWriteMemory()) {
5354       // TODO: This instruction may already have access group from
5355       // other pragmas e.g. #pragma clang loop vectorize.  Append
5356       // so that the existing metadata is not overwritten.
5357       I.setMetadata(LLVMContext::MD_access_group, AccessGroup);
5358     }
5359   }
5360 }
5361 
unrollLoopFull(DebugLoc,CanonicalLoopInfo * Loop)5362 void OpenMPIRBuilder::unrollLoopFull(DebugLoc, CanonicalLoopInfo *Loop) {
5363   LLVMContext &Ctx = Builder.getContext();
5364   addLoopMetadata(
5365       Loop, {MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
5366              MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.full"))});
5367 }
5368 
unrollLoopHeuristic(DebugLoc,CanonicalLoopInfo * Loop)5369 void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) {
5370   LLVMContext &Ctx = Builder.getContext();
5371   addLoopMetadata(
5372       Loop, {
5373                 MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
5374             });
5375 }
5376 
createIfVersion(CanonicalLoopInfo * CanonicalLoop,Value * IfCond,ValueToValueMapTy & VMap,LoopAnalysis & LIA,LoopInfo & LI,Loop * L,const Twine & NamePrefix)5377 void OpenMPIRBuilder::createIfVersion(CanonicalLoopInfo *CanonicalLoop,
5378                                       Value *IfCond, ValueToValueMapTy &VMap,
5379                                       LoopAnalysis &LIA, LoopInfo &LI, Loop *L,
5380                                       const Twine &NamePrefix) {
5381   Function *F = CanonicalLoop->getFunction();
5382 
5383   // We can't do
5384   // if (cond) {
5385   //   simd_loop;
5386   // } else {
5387   //   non_simd_loop;
5388   // }
5389   // because then the CanonicalLoopInfo would only point to one of the loops:
5390   // leading to other constructs operating on the same loop to malfunction.
5391   // Instead generate
5392   // while (...) {
5393   //   if (cond) {
5394   //     simd_body;
5395   //   } else {
5396   //     not_simd_body;
5397   //   }
5398   // }
5399   // At least for simple loops, LLVM seems able to hoist the if out of the loop
5400   // body at -O3
5401 
5402   // Define where if branch should be inserted
5403   auto SplitBeforeIt = CanonicalLoop->getBody()->getFirstNonPHIIt();
5404 
5405   // Create additional blocks for the if statement
5406   BasicBlock *Cond = SplitBeforeIt->getParent();
5407   llvm::LLVMContext &C = Cond->getContext();
5408   llvm::BasicBlock *ThenBlock = llvm::BasicBlock::Create(
5409       C, NamePrefix + ".if.then", Cond->getParent(), Cond->getNextNode());
5410   llvm::BasicBlock *ElseBlock = llvm::BasicBlock::Create(
5411       C, NamePrefix + ".if.else", Cond->getParent(), CanonicalLoop->getExit());
5412 
5413   // Create if condition branch.
5414   Builder.SetInsertPoint(SplitBeforeIt);
5415   Instruction *BrInstr =
5416       Builder.CreateCondBr(IfCond, ThenBlock, /*ifFalse*/ ElseBlock);
5417   InsertPointTy IP{BrInstr->getParent(), ++BrInstr->getIterator()};
5418   // Then block contains branch to omp loop body which needs to be vectorized
5419   spliceBB(IP, ThenBlock, false, Builder.getCurrentDebugLocation());
5420   ThenBlock->replaceSuccessorsPhiUsesWith(Cond, ThenBlock);
5421 
5422   Builder.SetInsertPoint(ElseBlock);
5423 
5424   // Clone loop for the else branch
5425   SmallVector<BasicBlock *, 8> NewBlocks;
5426 
5427   SmallVector<BasicBlock *, 8> ExistingBlocks;
5428   ExistingBlocks.reserve(L->getNumBlocks() + 1);
5429   ExistingBlocks.push_back(ThenBlock);
5430   ExistingBlocks.append(L->block_begin(), L->block_end());
5431   // Cond is the block that has the if clause condition
5432   // LoopCond is omp_loop.cond
5433   // LoopHeader is omp_loop.header
5434   BasicBlock *LoopCond = Cond->getUniquePredecessor();
5435   BasicBlock *LoopHeader = LoopCond->getUniquePredecessor();
5436   assert(LoopCond && LoopHeader && "Invalid loop structure");
5437   for (BasicBlock *Block : ExistingBlocks) {
5438     if (Block == L->getLoopPreheader() || Block == L->getLoopLatch() ||
5439         Block == LoopHeader || Block == LoopCond || Block == Cond) {
5440       continue;
5441     }
5442     BasicBlock *NewBB = CloneBasicBlock(Block, VMap, "", F);
5443 
5444     // fix name not to be omp.if.then
5445     if (Block == ThenBlock)
5446       NewBB->setName(NamePrefix + ".if.else");
5447 
5448     NewBB->moveBefore(CanonicalLoop->getExit());
5449     VMap[Block] = NewBB;
5450     NewBlocks.push_back(NewBB);
5451   }
5452   remapInstructionsInBlocks(NewBlocks, VMap);
5453   Builder.CreateBr(NewBlocks.front());
5454 
5455   // The loop latch must have only one predecessor. Currently it is branched to
5456   // from both the 'then' and 'else' branches.
5457   L->getLoopLatch()->splitBasicBlock(
5458       L->getLoopLatch()->begin(), NamePrefix + ".pre_latch", /*Before=*/true);
5459 
5460   // Ensure that the then block is added to the loop so we add the attributes in
5461   // the next step
5462   L->addBasicBlockToLoop(ThenBlock, LI);
5463 }
5464 
5465 unsigned
getOpenMPDefaultSimdAlign(const Triple & TargetTriple,const StringMap<bool> & Features)5466 OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
5467                                            const StringMap<bool> &Features) {
5468   if (TargetTriple.isX86()) {
5469     if (Features.lookup("avx512f"))
5470       return 512;
5471     else if (Features.lookup("avx"))
5472       return 256;
5473     return 128;
5474   }
5475   if (TargetTriple.isPPC())
5476     return 128;
5477   if (TargetTriple.isWasm())
5478     return 128;
5479   return 0;
5480 }
5481 
applySimd(CanonicalLoopInfo * CanonicalLoop,MapVector<Value *,Value * > AlignedVars,Value * IfCond,OrderKind Order,ConstantInt * Simdlen,ConstantInt * Safelen)5482 void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
5483                                 MapVector<Value *, Value *> AlignedVars,
5484                                 Value *IfCond, OrderKind Order,
5485                                 ConstantInt *Simdlen, ConstantInt *Safelen) {
5486   LLVMContext &Ctx = Builder.getContext();
5487 
5488   Function *F = CanonicalLoop->getFunction();
5489 
5490   // TODO: We should not rely on pass manager. Currently we use pass manager
5491   // only for getting llvm::Loop which corresponds to given CanonicalLoopInfo
5492   // object. We should have a method  which returns all blocks between
5493   // CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter()
5494   FunctionAnalysisManager FAM;
5495   FAM.registerPass([]() { return DominatorTreeAnalysis(); });
5496   FAM.registerPass([]() { return LoopAnalysis(); });
5497   FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
5498 
5499   LoopAnalysis LIA;
5500   LoopInfo &&LI = LIA.run(*F, FAM);
5501 
5502   Loop *L = LI.getLoopFor(CanonicalLoop->getHeader());
5503   if (AlignedVars.size()) {
5504     InsertPointTy IP = Builder.saveIP();
5505     for (auto &AlignedItem : AlignedVars) {
5506       Value *AlignedPtr = AlignedItem.first;
5507       Value *Alignment = AlignedItem.second;
5508       Instruction *loadInst = dyn_cast<Instruction>(AlignedPtr);
5509       Builder.SetInsertPoint(loadInst->getNextNode());
5510       Builder.CreateAlignmentAssumption(F->getDataLayout(), AlignedPtr,
5511                                         Alignment);
5512     }
5513     Builder.restoreIP(IP);
5514   }
5515 
5516   if (IfCond) {
5517     ValueToValueMapTy VMap;
5518     createIfVersion(CanonicalLoop, IfCond, VMap, LIA, LI, L, "simd");
5519   }
5520 
5521   SmallSet<BasicBlock *, 8> Reachable;
5522 
5523   // Get the basic blocks from the loop in which memref instructions
5524   // can be found.
5525   // TODO: Generalize getting all blocks inside a CanonicalizeLoopInfo,
5526   // preferably without running any passes.
5527   for (BasicBlock *Block : L->getBlocks()) {
5528     if (Block == CanonicalLoop->getCond() ||
5529         Block == CanonicalLoop->getHeader())
5530       continue;
5531     Reachable.insert(Block);
5532   }
5533 
5534   SmallVector<Metadata *> LoopMDList;
5535 
5536   // In presence of finite 'safelen', it may be unsafe to mark all
5537   // the memory instructions parallel, because loop-carried
5538   // dependences of 'safelen' iterations are possible.
5539   // If clause order(concurrent) is specified then the memory instructions
5540   // are marked parallel even if 'safelen' is finite.
5541   if ((Safelen == nullptr) || (Order == OrderKind::OMP_ORDER_concurrent)) {
5542     // Add access group metadata to memory-access instructions.
5543     MDNode *AccessGroup = MDNode::getDistinct(Ctx, {});
5544     for (BasicBlock *BB : Reachable)
5545       addSimdMetadata(BB, AccessGroup, LI);
5546     // TODO:  If the loop has existing parallel access metadata, have
5547     // to combine two lists.
5548     LoopMDList.push_back(MDNode::get(
5549         Ctx, {MDString::get(Ctx, "llvm.loop.parallel_accesses"), AccessGroup}));
5550   }
5551 
5552   // FIXME: the IF clause shares a loop backedge for the SIMD and non-SIMD
5553   // versions so we can't add the loop attributes in that case.
5554   if (IfCond) {
5555     // we can still add llvm.loop.parallel_access
5556     addLoopMetadata(CanonicalLoop, LoopMDList);
5557     return;
5558   }
5559 
5560   // Use the above access group metadata to create loop level
5561   // metadata, which should be distinct for each loop.
5562   ConstantAsMetadata *BoolConst =
5563       ConstantAsMetadata::get(ConstantInt::getTrue(Type::getInt1Ty(Ctx)));
5564   LoopMDList.push_back(MDNode::get(
5565       Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.enable"), BoolConst}));
5566 
5567   if (Simdlen || Safelen) {
5568     // If both simdlen and safelen clauses are specified, the value of the
5569     // simdlen parameter must be less than or equal to the value of the safelen
5570     // parameter. Therefore, use safelen only in the absence of simdlen.
5571     ConstantInt *VectorizeWidth = Simdlen == nullptr ? Safelen : Simdlen;
5572     LoopMDList.push_back(
5573         MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.width"),
5574                           ConstantAsMetadata::get(VectorizeWidth)}));
5575   }
5576 
5577   addLoopMetadata(CanonicalLoop, LoopMDList);
5578 }
5579 
5580 /// Create the TargetMachine object to query the backend for optimization
5581 /// preferences.
5582 ///
5583 /// Ideally, this would be passed from the front-end to the OpenMPBuilder, but
5584 /// e.g. Clang does not pass it to its CodeGen layer and creates it only when
5585 /// needed for the LLVM pass pipline. We use some default options to avoid
5586 /// having to pass too many settings from the frontend that probably do not
5587 /// matter.
5588 ///
5589 /// Currently, TargetMachine is only used sometimes by the unrollLoopPartial
5590 /// method. If we are going to use TargetMachine for more purposes, especially
5591 /// those that are sensitive to TargetOptions, RelocModel and CodeModel, it
5592 /// might become be worth requiring front-ends to pass on their TargetMachine,
5593 /// or at least cache it between methods. Note that while fontends such as Clang
5594 /// have just a single main TargetMachine per translation unit, "target-cpu" and
5595 /// "target-features" that determine the TargetMachine are per-function and can
5596 /// be overrided using __attribute__((target("OPTIONS"))).
5597 static std::unique_ptr<TargetMachine>
createTargetMachine(Function * F,CodeGenOptLevel OptLevel)5598 createTargetMachine(Function *F, CodeGenOptLevel OptLevel) {
5599   Module *M = F->getParent();
5600 
5601   StringRef CPU = F->getFnAttribute("target-cpu").getValueAsString();
5602   StringRef Features = F->getFnAttribute("target-features").getValueAsString();
5603   const llvm::Triple &Triple = M->getTargetTriple();
5604 
5605   std::string Error;
5606   const llvm::Target *TheTarget = TargetRegistry::lookupTarget(Triple, Error);
5607   if (!TheTarget)
5608     return {};
5609 
5610   llvm::TargetOptions Options;
5611   return std::unique_ptr<TargetMachine>(TheTarget->createTargetMachine(
5612       Triple, CPU, Features, Options, /*RelocModel=*/std::nullopt,
5613       /*CodeModel=*/std::nullopt, OptLevel));
5614 }
5615 
5616 /// Heuristically determine the best-performant unroll factor for \p CLI. This
5617 /// depends on the target processor. We are re-using the same heuristics as the
5618 /// LoopUnrollPass.
computeHeuristicUnrollFactor(CanonicalLoopInfo * CLI)5619 static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
5620   Function *F = CLI->getFunction();
5621 
5622   // Assume the user requests the most aggressive unrolling, even if the rest of
5623   // the code is optimized using a lower setting.
5624   CodeGenOptLevel OptLevel = CodeGenOptLevel::Aggressive;
5625   std::unique_ptr<TargetMachine> TM = createTargetMachine(F, OptLevel);
5626 
5627   FunctionAnalysisManager FAM;
5628   FAM.registerPass([]() { return TargetLibraryAnalysis(); });
5629   FAM.registerPass([]() { return AssumptionAnalysis(); });
5630   FAM.registerPass([]() { return DominatorTreeAnalysis(); });
5631   FAM.registerPass([]() { return LoopAnalysis(); });
5632   FAM.registerPass([]() { return ScalarEvolutionAnalysis(); });
5633   FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
5634   TargetIRAnalysis TIRA;
5635   if (TM)
5636     TIRA = TargetIRAnalysis(
5637         [&](const Function &F) { return TM->getTargetTransformInfo(F); });
5638   FAM.registerPass([&]() { return TIRA; });
5639 
5640   TargetIRAnalysis::Result &&TTI = TIRA.run(*F, FAM);
5641   ScalarEvolutionAnalysis SEA;
5642   ScalarEvolution &&SE = SEA.run(*F, FAM);
5643   DominatorTreeAnalysis DTA;
5644   DominatorTree &&DT = DTA.run(*F, FAM);
5645   LoopAnalysis LIA;
5646   LoopInfo &&LI = LIA.run(*F, FAM);
5647   AssumptionAnalysis ACT;
5648   AssumptionCache &&AC = ACT.run(*F, FAM);
5649   OptimizationRemarkEmitter ORE{F};
5650 
5651   Loop *L = LI.getLoopFor(CLI->getHeader());
5652   assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop");
5653 
5654   TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
5655       L, SE, TTI,
5656       /*BlockFrequencyInfo=*/nullptr,
5657       /*ProfileSummaryInfo=*/nullptr, ORE, static_cast<int>(OptLevel),
5658       /*UserThreshold=*/std::nullopt,
5659       /*UserCount=*/std::nullopt,
5660       /*UserAllowPartial=*/true,
5661       /*UserAllowRuntime=*/true,
5662       /*UserUpperBound=*/std::nullopt,
5663       /*UserFullUnrollMaxCount=*/std::nullopt);
5664 
5665   UP.Force = true;
5666 
5667   // Account for additional optimizations taking place before the LoopUnrollPass
5668   // would unroll the loop.
5669   UP.Threshold *= UnrollThresholdFactor;
5670   UP.PartialThreshold *= UnrollThresholdFactor;
5671 
5672   // Use normal unroll factors even if the rest of the code is optimized for
5673   // size.
5674   UP.OptSizeThreshold = UP.Threshold;
5675   UP.PartialOptSizeThreshold = UP.PartialThreshold;
5676 
5677   LLVM_DEBUG(dbgs() << "Unroll heuristic thresholds:\n"
5678                     << "  Threshold=" << UP.Threshold << "\n"
5679                     << "  PartialThreshold=" << UP.PartialThreshold << "\n"
5680                     << "  OptSizeThreshold=" << UP.OptSizeThreshold << "\n"
5681                     << "  PartialOptSizeThreshold="
5682                     << UP.PartialOptSizeThreshold << "\n");
5683 
5684   // Disable peeling.
5685   TargetTransformInfo::PeelingPreferences PP =
5686       gatherPeelingPreferences(L, SE, TTI,
5687                                /*UserAllowPeeling=*/false,
5688                                /*UserAllowProfileBasedPeeling=*/false,
5689                                /*UnrollingSpecficValues=*/false);
5690 
5691   SmallPtrSet<const Value *, 32> EphValues;
5692   CodeMetrics::collectEphemeralValues(L, &AC, EphValues);
5693 
5694   // Assume that reads and writes to stack variables can be eliminated by
5695   // Mem2Reg, SROA or LICM. That is, don't count them towards the loop body's
5696   // size.
5697   for (BasicBlock *BB : L->blocks()) {
5698     for (Instruction &I : *BB) {
5699       Value *Ptr;
5700       if (auto *Load = dyn_cast<LoadInst>(&I)) {
5701         Ptr = Load->getPointerOperand();
5702       } else if (auto *Store = dyn_cast<StoreInst>(&I)) {
5703         Ptr = Store->getPointerOperand();
5704       } else
5705         continue;
5706 
5707       Ptr = Ptr->stripPointerCasts();
5708 
5709       if (auto *Alloca = dyn_cast<AllocaInst>(Ptr)) {
5710         if (Alloca->getParent() == &F->getEntryBlock())
5711           EphValues.insert(&I);
5712       }
5713     }
5714   }
5715 
5716   UnrollCostEstimator UCE(L, TTI, EphValues, UP.BEInsns);
5717 
5718   // Loop is not unrollable if the loop contains certain instructions.
5719   if (!UCE.canUnroll()) {
5720     LLVM_DEBUG(dbgs() << "Loop not considered unrollable\n");
5721     return 1;
5722   }
5723 
5724   LLVM_DEBUG(dbgs() << "Estimated loop size is " << UCE.getRolledLoopSize()
5725                     << "\n");
5726 
5727   // TODO: Determine trip count of \p CLI if constant, computeUnrollCount might
5728   // be able to use it.
5729   int TripCount = 0;
5730   int MaxTripCount = 0;
5731   bool MaxOrZero = false;
5732   unsigned TripMultiple = 0;
5733 
5734   bool UseUpperBound = false;
5735   computeUnrollCount(L, TTI, DT, &LI, &AC, SE, EphValues, &ORE, TripCount,
5736                      MaxTripCount, MaxOrZero, TripMultiple, UCE, UP, PP,
5737                      UseUpperBound);
5738   unsigned Factor = UP.Count;
5739   LLVM_DEBUG(dbgs() << "Suggesting unroll factor of " << Factor << "\n");
5740 
5741   // This function returns 1 to signal to not unroll a loop.
5742   if (Factor == 0)
5743     return 1;
5744   return Factor;
5745 }
5746 
unrollLoopPartial(DebugLoc DL,CanonicalLoopInfo * Loop,int32_t Factor,CanonicalLoopInfo ** UnrolledCLI)5747 void OpenMPIRBuilder::unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop,
5748                                         int32_t Factor,
5749                                         CanonicalLoopInfo **UnrolledCLI) {
5750   assert(Factor >= 0 && "Unroll factor must not be negative");
5751 
5752   Function *F = Loop->getFunction();
5753   LLVMContext &Ctx = F->getContext();
5754 
5755   // If the unrolled loop is not used for another loop-associated directive, it
5756   // is sufficient to add metadata for the LoopUnrollPass.
5757   if (!UnrolledCLI) {
5758     SmallVector<Metadata *, 2> LoopMetadata;
5759     LoopMetadata.push_back(
5760         MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")));
5761 
5762     if (Factor >= 1) {
5763       ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
5764           ConstantInt::get(Type::getInt32Ty(Ctx), APInt(32, Factor)));
5765       LoopMetadata.push_back(MDNode::get(
5766           Ctx, {MDString::get(Ctx, "llvm.loop.unroll.count"), FactorConst}));
5767     }
5768 
5769     addLoopMetadata(Loop, LoopMetadata);
5770     return;
5771   }
5772 
5773   // Heuristically determine the unroll factor.
5774   if (Factor == 0)
5775     Factor = computeHeuristicUnrollFactor(Loop);
5776 
5777   // No change required with unroll factor 1.
5778   if (Factor == 1) {
5779     *UnrolledCLI = Loop;
5780     return;
5781   }
5782 
5783   assert(Factor >= 2 &&
5784          "unrolling only makes sense with a factor of 2 or larger");
5785 
5786   Type *IndVarTy = Loop->getIndVarType();
5787 
5788   // Apply partial unrolling by tiling the loop by the unroll-factor, then fully
5789   // unroll the inner loop.
5790   Value *FactorVal =
5791       ConstantInt::get(IndVarTy, APInt(IndVarTy->getIntegerBitWidth(), Factor,
5792                                        /*isSigned=*/false));
5793   std::vector<CanonicalLoopInfo *> LoopNest =
5794       tileLoops(DL, {Loop}, {FactorVal});
5795   assert(LoopNest.size() == 2 && "Expect 2 loops after tiling");
5796   *UnrolledCLI = LoopNest[0];
5797   CanonicalLoopInfo *InnerLoop = LoopNest[1];
5798 
5799   // LoopUnrollPass can only fully unroll loops with constant trip count.
5800   // Unroll by the unroll factor with a fallback epilog for the remainder
5801   // iterations if necessary.
5802   ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
5803       ConstantInt::get(Type::getInt32Ty(Ctx), APInt(32, Factor)));
5804   addLoopMetadata(
5805       InnerLoop,
5806       {MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
5807        MDNode::get(
5808            Ctx, {MDString::get(Ctx, "llvm.loop.unroll.count"), FactorConst})});
5809 
5810 #ifndef NDEBUG
5811   (*UnrolledCLI)->assertOK();
5812 #endif
5813 }
5814 
5815 OpenMPIRBuilder::InsertPointTy
createCopyPrivate(const LocationDescription & Loc,llvm::Value * BufSize,llvm::Value * CpyBuf,llvm::Value * CpyFn,llvm::Value * DidIt)5816 OpenMPIRBuilder::createCopyPrivate(const LocationDescription &Loc,
5817                                    llvm::Value *BufSize, llvm::Value *CpyBuf,
5818                                    llvm::Value *CpyFn, llvm::Value *DidIt) {
5819   if (!updateToLocation(Loc))
5820     return Loc.IP;
5821 
5822   uint32_t SrcLocStrSize;
5823   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5824   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5825   Value *ThreadId = getOrCreateThreadID(Ident);
5826 
5827   llvm::Value *DidItLD = Builder.CreateLoad(Builder.getInt32Ty(), DidIt);
5828 
5829   Value *Args[] = {Ident, ThreadId, BufSize, CpyBuf, CpyFn, DidItLD};
5830 
5831   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_copyprivate);
5832   Builder.CreateCall(Fn, Args);
5833 
5834   return Builder.saveIP();
5835 }
5836 
createSingle(const LocationDescription & Loc,BodyGenCallbackTy BodyGenCB,FinalizeCallbackTy FiniCB,bool IsNowait,ArrayRef<llvm::Value * > CPVars,ArrayRef<llvm::Function * > CPFuncs)5837 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createSingle(
5838     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
5839     FinalizeCallbackTy FiniCB, bool IsNowait, ArrayRef<llvm::Value *> CPVars,
5840     ArrayRef<llvm::Function *> CPFuncs) {
5841 
5842   if (!updateToLocation(Loc))
5843     return Loc.IP;
5844 
5845   // If needed allocate and initialize `DidIt` with 0.
5846   // DidIt: flag variable: 1=single thread; 0=not single thread.
5847   llvm::Value *DidIt = nullptr;
5848   if (!CPVars.empty()) {
5849     DidIt = Builder.CreateAlloca(llvm::Type::getInt32Ty(Builder.getContext()));
5850     Builder.CreateStore(Builder.getInt32(0), DidIt);
5851   }
5852 
5853   Directive OMPD = Directive::OMPD_single;
5854   uint32_t SrcLocStrSize;
5855   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5856   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5857   Value *ThreadId = getOrCreateThreadID(Ident);
5858   Value *Args[] = {Ident, ThreadId};
5859 
5860   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_single);
5861   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
5862 
5863   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_single);
5864   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
5865 
5866   auto FiniCBWrapper = [&](InsertPointTy IP) -> Error {
5867     if (Error Err = FiniCB(IP))
5868       return Err;
5869 
5870     // The thread that executes the single region must set `DidIt` to 1.
5871     // This is used by __kmpc_copyprivate, to know if the caller is the
5872     // single thread or not.
5873     if (DidIt)
5874       Builder.CreateStore(Builder.getInt32(1), DidIt);
5875 
5876     return Error::success();
5877   };
5878 
5879   // generates the following:
5880   // if (__kmpc_single()) {
5881   //		.... single region ...
5882   // 		__kmpc_end_single
5883   // }
5884   // __kmpc_copyprivate
5885   // __kmpc_barrier
5886 
5887   InsertPointOrErrorTy AfterIP =
5888       EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCBWrapper,
5889                            /*Conditional*/ true,
5890                            /*hasFinalize*/ true);
5891   if (!AfterIP)
5892     return AfterIP.takeError();
5893 
5894   if (DidIt) {
5895     for (size_t I = 0, E = CPVars.size(); I < E; ++I)
5896       // NOTE BufSize is currently unused, so just pass 0.
5897       createCopyPrivate(LocationDescription(Builder.saveIP(), Loc.DL),
5898                         /*BufSize=*/ConstantInt::get(Int64, 0), CPVars[I],
5899                         CPFuncs[I], DidIt);
5900     // NOTE __kmpc_copyprivate already inserts a barrier
5901   } else if (!IsNowait) {
5902     InsertPointOrErrorTy AfterIP =
5903         createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
5904                       omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
5905                       /* CheckCancelFlag */ false);
5906     if (!AfterIP)
5907       return AfterIP.takeError();
5908   }
5909   return Builder.saveIP();
5910 }
5911 
createCritical(const LocationDescription & Loc,BodyGenCallbackTy BodyGenCB,FinalizeCallbackTy FiniCB,StringRef CriticalName,Value * HintInst)5912 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createCritical(
5913     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
5914     FinalizeCallbackTy FiniCB, StringRef CriticalName, Value *HintInst) {
5915 
5916   if (!updateToLocation(Loc))
5917     return Loc.IP;
5918 
5919   Directive OMPD = Directive::OMPD_critical;
5920   uint32_t SrcLocStrSize;
5921   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5922   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5923   Value *ThreadId = getOrCreateThreadID(Ident);
5924   Value *LockVar = getOMPCriticalRegionLock(CriticalName);
5925   Value *Args[] = {Ident, ThreadId, LockVar};
5926 
5927   SmallVector<llvm::Value *, 4> EnterArgs(std::begin(Args), std::end(Args));
5928   Function *RTFn = nullptr;
5929   if (HintInst) {
5930     // Add Hint to entry Args and create call
5931     EnterArgs.push_back(HintInst);
5932     RTFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_critical_with_hint);
5933   } else {
5934     RTFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_critical);
5935   }
5936   Instruction *EntryCall = Builder.CreateCall(RTFn, EnterArgs);
5937 
5938   Function *ExitRTLFn =
5939       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_critical);
5940   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
5941 
5942   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
5943                               /*Conditional*/ false, /*hasFinalize*/ true);
5944 }
5945 
5946 OpenMPIRBuilder::InsertPointTy
createOrderedDepend(const LocationDescription & Loc,InsertPointTy AllocaIP,unsigned NumLoops,ArrayRef<llvm::Value * > StoreValues,const Twine & Name,bool IsDependSource)5947 OpenMPIRBuilder::createOrderedDepend(const LocationDescription &Loc,
5948                                      InsertPointTy AllocaIP, unsigned NumLoops,
5949                                      ArrayRef<llvm::Value *> StoreValues,
5950                                      const Twine &Name, bool IsDependSource) {
5951   assert(
5952       llvm::all_of(StoreValues,
5953                    [](Value *SV) { return SV->getType()->isIntegerTy(64); }) &&
5954       "OpenMP runtime requires depend vec with i64 type");
5955 
5956   if (!updateToLocation(Loc))
5957     return Loc.IP;
5958 
5959   // Allocate space for vector and generate alloc instruction.
5960   auto *ArrI64Ty = ArrayType::get(Int64, NumLoops);
5961   Builder.restoreIP(AllocaIP);
5962   AllocaInst *ArgsBase = Builder.CreateAlloca(ArrI64Ty, nullptr, Name);
5963   ArgsBase->setAlignment(Align(8));
5964   Builder.restoreIP(Loc.IP);
5965 
5966   // Store the index value with offset in depend vector.
5967   for (unsigned I = 0; I < NumLoops; ++I) {
5968     Value *DependAddrGEPIter = Builder.CreateInBoundsGEP(
5969         ArrI64Ty, ArgsBase, {Builder.getInt64(0), Builder.getInt64(I)});
5970     StoreInst *STInst = Builder.CreateStore(StoreValues[I], DependAddrGEPIter);
5971     STInst->setAlignment(Align(8));
5972   }
5973 
5974   Value *DependBaseAddrGEP = Builder.CreateInBoundsGEP(
5975       ArrI64Ty, ArgsBase, {Builder.getInt64(0), Builder.getInt64(0)});
5976 
5977   uint32_t SrcLocStrSize;
5978   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5979   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5980   Value *ThreadId = getOrCreateThreadID(Ident);
5981   Value *Args[] = {Ident, ThreadId, DependBaseAddrGEP};
5982 
5983   Function *RTLFn = nullptr;
5984   if (IsDependSource)
5985     RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_doacross_post);
5986   else
5987     RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_doacross_wait);
5988   Builder.CreateCall(RTLFn, Args);
5989 
5990   return Builder.saveIP();
5991 }
5992 
createOrderedThreadsSimd(const LocationDescription & Loc,BodyGenCallbackTy BodyGenCB,FinalizeCallbackTy FiniCB,bool IsThreads)5993 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createOrderedThreadsSimd(
5994     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
5995     FinalizeCallbackTy FiniCB, bool IsThreads) {
5996   if (!updateToLocation(Loc))
5997     return Loc.IP;
5998 
5999   Directive OMPD = Directive::OMPD_ordered;
6000   Instruction *EntryCall = nullptr;
6001   Instruction *ExitCall = nullptr;
6002 
6003   if (IsThreads) {
6004     uint32_t SrcLocStrSize;
6005     Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6006     Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6007     Value *ThreadId = getOrCreateThreadID(Ident);
6008     Value *Args[] = {Ident, ThreadId};
6009 
6010     Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_ordered);
6011     EntryCall = Builder.CreateCall(EntryRTLFn, Args);
6012 
6013     Function *ExitRTLFn =
6014         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_ordered);
6015     ExitCall = Builder.CreateCall(ExitRTLFn, Args);
6016   }
6017 
6018   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
6019                               /*Conditional*/ false, /*hasFinalize*/ true);
6020 }
6021 
EmitOMPInlinedRegion(Directive OMPD,Instruction * EntryCall,Instruction * ExitCall,BodyGenCallbackTy BodyGenCB,FinalizeCallbackTy FiniCB,bool Conditional,bool HasFinalize,bool IsCancellable)6022 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::EmitOMPInlinedRegion(
6023     Directive OMPD, Instruction *EntryCall, Instruction *ExitCall,
6024     BodyGenCallbackTy BodyGenCB, FinalizeCallbackTy FiniCB, bool Conditional,
6025     bool HasFinalize, bool IsCancellable) {
6026 
6027   if (HasFinalize)
6028     FinalizationStack.push_back({FiniCB, OMPD, IsCancellable});
6029 
6030   // Create inlined region's entry and body blocks, in preparation
6031   // for conditional creation
6032   BasicBlock *EntryBB = Builder.GetInsertBlock();
6033   Instruction *SplitPos = EntryBB->getTerminator();
6034   if (!isa_and_nonnull<BranchInst>(SplitPos))
6035     SplitPos = new UnreachableInst(Builder.getContext(), EntryBB);
6036   BasicBlock *ExitBB = EntryBB->splitBasicBlock(SplitPos, "omp_region.end");
6037   BasicBlock *FiniBB =
6038       EntryBB->splitBasicBlock(EntryBB->getTerminator(), "omp_region.finalize");
6039 
6040   Builder.SetInsertPoint(EntryBB->getTerminator());
6041   emitCommonDirectiveEntry(OMPD, EntryCall, ExitBB, Conditional);
6042 
6043   // generate body
6044   if (Error Err = BodyGenCB(/* AllocaIP */ InsertPointTy(),
6045                             /* CodeGenIP */ Builder.saveIP()))
6046     return Err;
6047 
6048   // emit exit call and do any needed finalization.
6049   auto FinIP = InsertPointTy(FiniBB, FiniBB->getFirstInsertionPt());
6050   assert(FiniBB->getTerminator()->getNumSuccessors() == 1 &&
6051          FiniBB->getTerminator()->getSuccessor(0) == ExitBB &&
6052          "Unexpected control flow graph state!!");
6053   InsertPointOrErrorTy AfterIP =
6054       emitCommonDirectiveExit(OMPD, FinIP, ExitCall, HasFinalize);
6055   if (!AfterIP)
6056     return AfterIP.takeError();
6057   assert(FiniBB->getUniquePredecessor()->getUniqueSuccessor() == FiniBB &&
6058          "Unexpected Control Flow State!");
6059   MergeBlockIntoPredecessor(FiniBB);
6060 
6061   // If we are skipping the region of a non conditional, remove the exit
6062   // block, and clear the builder's insertion point.
6063   assert(SplitPos->getParent() == ExitBB &&
6064          "Unexpected Insertion point location!");
6065   auto merged = MergeBlockIntoPredecessor(ExitBB);
6066   BasicBlock *ExitPredBB = SplitPos->getParent();
6067   auto InsertBB = merged ? ExitPredBB : ExitBB;
6068   if (!isa_and_nonnull<BranchInst>(SplitPos))
6069     SplitPos->eraseFromParent();
6070   Builder.SetInsertPoint(InsertBB);
6071 
6072   return Builder.saveIP();
6073 }
6074 
emitCommonDirectiveEntry(Directive OMPD,Value * EntryCall,BasicBlock * ExitBB,bool Conditional)6075 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveEntry(
6076     Directive OMPD, Value *EntryCall, BasicBlock *ExitBB, bool Conditional) {
6077   // if nothing to do, Return current insertion point.
6078   if (!Conditional || !EntryCall)
6079     return Builder.saveIP();
6080 
6081   BasicBlock *EntryBB = Builder.GetInsertBlock();
6082   Value *CallBool = Builder.CreateIsNotNull(EntryCall);
6083   auto *ThenBB = BasicBlock::Create(M.getContext(), "omp_region.body");
6084   auto *UI = new UnreachableInst(Builder.getContext(), ThenBB);
6085 
6086   // Emit thenBB and set the Builder's insertion point there for
6087   // body generation next. Place the block after the current block.
6088   Function *CurFn = EntryBB->getParent();
6089   CurFn->insert(std::next(EntryBB->getIterator()), ThenBB);
6090 
6091   // Move Entry branch to end of ThenBB, and replace with conditional
6092   // branch (If-stmt)
6093   Instruction *EntryBBTI = EntryBB->getTerminator();
6094   Builder.CreateCondBr(CallBool, ThenBB, ExitBB);
6095   EntryBBTI->removeFromParent();
6096   Builder.SetInsertPoint(UI);
6097   Builder.Insert(EntryBBTI);
6098   UI->eraseFromParent();
6099   Builder.SetInsertPoint(ThenBB->getTerminator());
6100 
6101   // return an insertion point to ExitBB.
6102   return IRBuilder<>::InsertPoint(ExitBB, ExitBB->getFirstInsertionPt());
6103 }
6104 
emitCommonDirectiveExit(omp::Directive OMPD,InsertPointTy FinIP,Instruction * ExitCall,bool HasFinalize)6105 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitCommonDirectiveExit(
6106     omp::Directive OMPD, InsertPointTy FinIP, Instruction *ExitCall,
6107     bool HasFinalize) {
6108 
6109   Builder.restoreIP(FinIP);
6110 
6111   // If there is finalization to do, emit it before the exit call
6112   if (HasFinalize) {
6113     assert(!FinalizationStack.empty() &&
6114            "Unexpected finalization stack state!");
6115 
6116     FinalizationInfo Fi = FinalizationStack.pop_back_val();
6117     assert(Fi.DK == OMPD && "Unexpected Directive for Finalization call!");
6118 
6119     if (Error Err = Fi.FiniCB(FinIP))
6120       return Err;
6121 
6122     BasicBlock *FiniBB = FinIP.getBlock();
6123     Instruction *FiniBBTI = FiniBB->getTerminator();
6124 
6125     // set Builder IP for call creation
6126     Builder.SetInsertPoint(FiniBBTI);
6127   }
6128 
6129   if (!ExitCall)
6130     return Builder.saveIP();
6131 
6132   // place the Exitcall as last instruction before Finalization block terminator
6133   ExitCall->removeFromParent();
6134   Builder.Insert(ExitCall);
6135 
6136   return IRBuilder<>::InsertPoint(ExitCall->getParent(),
6137                                   ExitCall->getIterator());
6138 }
6139 
createCopyinClauseBlocks(InsertPointTy IP,Value * MasterAddr,Value * PrivateAddr,llvm::IntegerType * IntPtrTy,bool BranchtoEnd)6140 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCopyinClauseBlocks(
6141     InsertPointTy IP, Value *MasterAddr, Value *PrivateAddr,
6142     llvm::IntegerType *IntPtrTy, bool BranchtoEnd) {
6143   if (!IP.isSet())
6144     return IP;
6145 
6146   IRBuilder<>::InsertPointGuard IPG(Builder);
6147 
6148   // creates the following CFG structure
6149   //	   OMP_Entry : (MasterAddr != PrivateAddr)?
6150   //       F     T
6151   //       |      \
6152   //       |     copin.not.master
6153   //       |      /
6154   //       v     /
6155   //   copyin.not.master.end
6156   //		     |
6157   //         v
6158   //   OMP.Entry.Next
6159 
6160   BasicBlock *OMP_Entry = IP.getBlock();
6161   Function *CurFn = OMP_Entry->getParent();
6162   BasicBlock *CopyBegin =
6163       BasicBlock::Create(M.getContext(), "copyin.not.master", CurFn);
6164   BasicBlock *CopyEnd = nullptr;
6165 
6166   // If entry block is terminated, split to preserve the branch to following
6167   // basic block (i.e. OMP.Entry.Next), otherwise, leave everything as is.
6168   if (isa_and_nonnull<BranchInst>(OMP_Entry->getTerminator())) {
6169     CopyEnd = OMP_Entry->splitBasicBlock(OMP_Entry->getTerminator(),
6170                                          "copyin.not.master.end");
6171     OMP_Entry->getTerminator()->eraseFromParent();
6172   } else {
6173     CopyEnd =
6174         BasicBlock::Create(M.getContext(), "copyin.not.master.end", CurFn);
6175   }
6176 
6177   Builder.SetInsertPoint(OMP_Entry);
6178   Value *MasterPtr = Builder.CreatePtrToInt(MasterAddr, IntPtrTy);
6179   Value *PrivatePtr = Builder.CreatePtrToInt(PrivateAddr, IntPtrTy);
6180   Value *cmp = Builder.CreateICmpNE(MasterPtr, PrivatePtr);
6181   Builder.CreateCondBr(cmp, CopyBegin, CopyEnd);
6182 
6183   Builder.SetInsertPoint(CopyBegin);
6184   if (BranchtoEnd)
6185     Builder.SetInsertPoint(Builder.CreateBr(CopyEnd));
6186 
6187   return Builder.saveIP();
6188 }
6189 
createOMPAlloc(const LocationDescription & Loc,Value * Size,Value * Allocator,std::string Name)6190 CallInst *OpenMPIRBuilder::createOMPAlloc(const LocationDescription &Loc,
6191                                           Value *Size, Value *Allocator,
6192                                           std::string Name) {
6193   IRBuilder<>::InsertPointGuard IPG(Builder);
6194   updateToLocation(Loc);
6195 
6196   uint32_t SrcLocStrSize;
6197   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6198   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6199   Value *ThreadId = getOrCreateThreadID(Ident);
6200   Value *Args[] = {ThreadId, Size, Allocator};
6201 
6202   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_alloc);
6203 
6204   return Builder.CreateCall(Fn, Args, Name);
6205 }
6206 
createOMPFree(const LocationDescription & Loc,Value * Addr,Value * Allocator,std::string Name)6207 CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
6208                                          Value *Addr, Value *Allocator,
6209                                          std::string Name) {
6210   IRBuilder<>::InsertPointGuard IPG(Builder);
6211   updateToLocation(Loc);
6212 
6213   uint32_t SrcLocStrSize;
6214   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6215   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6216   Value *ThreadId = getOrCreateThreadID(Ident);
6217   Value *Args[] = {ThreadId, Addr, Allocator};
6218   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_free);
6219   return Builder.CreateCall(Fn, Args, Name);
6220 }
6221 
createOMPInteropInit(const LocationDescription & Loc,Value * InteropVar,omp::OMPInteropType InteropType,Value * Device,Value * NumDependences,Value * DependenceAddress,bool HaveNowaitClause)6222 CallInst *OpenMPIRBuilder::createOMPInteropInit(
6223     const LocationDescription &Loc, Value *InteropVar,
6224     omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
6225     Value *DependenceAddress, bool HaveNowaitClause) {
6226   IRBuilder<>::InsertPointGuard IPG(Builder);
6227   updateToLocation(Loc);
6228 
6229   uint32_t SrcLocStrSize;
6230   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6231   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6232   Value *ThreadId = getOrCreateThreadID(Ident);
6233   if (Device == nullptr)
6234     Device = Constant::getAllOnesValue(Int32);
6235   Constant *InteropTypeVal = ConstantInt::get(Int32, (int)InteropType);
6236   if (NumDependences == nullptr) {
6237     NumDependences = ConstantInt::get(Int32, 0);
6238     PointerType *PointerTypeVar = PointerType::getUnqual(M.getContext());
6239     DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
6240   }
6241   Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
6242   Value *Args[] = {
6243       Ident,  ThreadId,       InteropVar,        InteropTypeVal,
6244       Device, NumDependences, DependenceAddress, HaveNowaitClauseVal};
6245 
6246   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_init);
6247 
6248   return Builder.CreateCall(Fn, Args);
6249 }
6250 
createOMPInteropDestroy(const LocationDescription & Loc,Value * InteropVar,Value * Device,Value * NumDependences,Value * DependenceAddress,bool HaveNowaitClause)6251 CallInst *OpenMPIRBuilder::createOMPInteropDestroy(
6252     const LocationDescription &Loc, Value *InteropVar, Value *Device,
6253     Value *NumDependences, Value *DependenceAddress, bool HaveNowaitClause) {
6254   IRBuilder<>::InsertPointGuard IPG(Builder);
6255   updateToLocation(Loc);
6256 
6257   uint32_t SrcLocStrSize;
6258   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6259   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6260   Value *ThreadId = getOrCreateThreadID(Ident);
6261   if (Device == nullptr)
6262     Device = Constant::getAllOnesValue(Int32);
6263   if (NumDependences == nullptr) {
6264     NumDependences = ConstantInt::get(Int32, 0);
6265     PointerType *PointerTypeVar = PointerType::getUnqual(M.getContext());
6266     DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
6267   }
6268   Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
6269   Value *Args[] = {
6270       Ident,          ThreadId,          InteropVar,         Device,
6271       NumDependences, DependenceAddress, HaveNowaitClauseVal};
6272 
6273   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_destroy);
6274 
6275   return Builder.CreateCall(Fn, Args);
6276 }
6277 
createOMPInteropUse(const LocationDescription & Loc,Value * InteropVar,Value * Device,Value * NumDependences,Value * DependenceAddress,bool HaveNowaitClause)6278 CallInst *OpenMPIRBuilder::createOMPInteropUse(const LocationDescription &Loc,
6279                                                Value *InteropVar, Value *Device,
6280                                                Value *NumDependences,
6281                                                Value *DependenceAddress,
6282                                                bool HaveNowaitClause) {
6283   IRBuilder<>::InsertPointGuard IPG(Builder);
6284   updateToLocation(Loc);
6285   uint32_t SrcLocStrSize;
6286   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6287   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6288   Value *ThreadId = getOrCreateThreadID(Ident);
6289   if (Device == nullptr)
6290     Device = Constant::getAllOnesValue(Int32);
6291   if (NumDependences == nullptr) {
6292     NumDependences = ConstantInt::get(Int32, 0);
6293     PointerType *PointerTypeVar = PointerType::getUnqual(M.getContext());
6294     DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
6295   }
6296   Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
6297   Value *Args[] = {
6298       Ident,          ThreadId,          InteropVar,         Device,
6299       NumDependences, DependenceAddress, HaveNowaitClauseVal};
6300 
6301   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_use);
6302 
6303   return Builder.CreateCall(Fn, Args);
6304 }
6305 
createCachedThreadPrivate(const LocationDescription & Loc,llvm::Value * Pointer,llvm::ConstantInt * Size,const llvm::Twine & Name)6306 CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
6307     const LocationDescription &Loc, llvm::Value *Pointer,
6308     llvm::ConstantInt *Size, const llvm::Twine &Name) {
6309   IRBuilder<>::InsertPointGuard IPG(Builder);
6310   updateToLocation(Loc);
6311 
6312   uint32_t SrcLocStrSize;
6313   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6314   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6315   Value *ThreadId = getOrCreateThreadID(Ident);
6316   Constant *ThreadPrivateCache =
6317       getOrCreateInternalVariable(Int8PtrPtr, Name.str());
6318   llvm::Value *Args[] = {Ident, ThreadId, Pointer, Size, ThreadPrivateCache};
6319 
6320   Function *Fn =
6321       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_threadprivate_cached);
6322 
6323   return Builder.CreateCall(Fn, Args);
6324 }
6325 
createTargetInit(const LocationDescription & Loc,const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs & Attrs)6326 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
6327     const LocationDescription &Loc,
6328     const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
6329   assert(!Attrs.MaxThreads.empty() && !Attrs.MaxTeams.empty() &&
6330          "expected num_threads and num_teams to be specified");
6331 
6332   if (!updateToLocation(Loc))
6333     return Loc.IP;
6334 
6335   uint32_t SrcLocStrSize;
6336   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6337   Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6338   Constant *IsSPMDVal = ConstantInt::getSigned(Int8, Attrs.ExecFlags);
6339   Constant *UseGenericStateMachineVal = ConstantInt::getSigned(
6340       Int8, Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD);
6341   Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Int8, true);
6342   Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Int16, 0);
6343 
6344   Function *DebugKernelWrapper = Builder.GetInsertBlock()->getParent();
6345   Function *Kernel = DebugKernelWrapper;
6346 
6347   // We need to strip the debug prefix to get the correct kernel name.
6348   StringRef KernelName = Kernel->getName();
6349   const std::string DebugPrefix = "_debug__";
6350   if (KernelName.ends_with(DebugPrefix)) {
6351     KernelName = KernelName.drop_back(DebugPrefix.length());
6352     Kernel = M.getFunction(KernelName);
6353     assert(Kernel && "Expected the real kernel to exist");
6354   }
6355 
6356   // Manifest the launch configuration in the metadata matching the kernel
6357   // environment.
6358   if (Attrs.MinTeams > 1 || Attrs.MaxTeams.front() > 0)
6359     writeTeamsForKernel(T, *Kernel, Attrs.MinTeams, Attrs.MaxTeams.front());
6360 
6361   // If MaxThreads not set, select the maximum between the default workgroup
6362   // size and the MinThreads value.
6363   int32_t MaxThreadsVal = Attrs.MaxThreads.front();
6364   if (MaxThreadsVal < 0)
6365     MaxThreadsVal = std::max(
6366         int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), Attrs.MinThreads);
6367 
6368   if (MaxThreadsVal > 0)
6369     writeThreadBoundsForKernel(T, *Kernel, Attrs.MinThreads, MaxThreadsVal);
6370 
6371   Constant *MinThreads = ConstantInt::getSigned(Int32, Attrs.MinThreads);
6372   Constant *MaxThreads = ConstantInt::getSigned(Int32, MaxThreadsVal);
6373   Constant *MinTeams = ConstantInt::getSigned(Int32, Attrs.MinTeams);
6374   Constant *MaxTeams = ConstantInt::getSigned(Int32, Attrs.MaxTeams.front());
6375   Constant *ReductionDataSize =
6376       ConstantInt::getSigned(Int32, Attrs.ReductionDataSize);
6377   Constant *ReductionBufferLength =
6378       ConstantInt::getSigned(Int32, Attrs.ReductionBufferLength);
6379 
6380   Function *Fn = getOrCreateRuntimeFunctionPtr(
6381       omp::RuntimeFunction::OMPRTL___kmpc_target_init);
6382   const DataLayout &DL = Fn->getDataLayout();
6383 
6384   Twine DynamicEnvironmentName = KernelName + "_dynamic_environment";
6385   Constant *DynamicEnvironmentInitializer =
6386       ConstantStruct::get(DynamicEnvironment, {DebugIndentionLevelVal});
6387   GlobalVariable *DynamicEnvironmentGV = new GlobalVariable(
6388       M, DynamicEnvironment, /*IsConstant=*/false, GlobalValue::WeakODRLinkage,
6389       DynamicEnvironmentInitializer, DynamicEnvironmentName,
6390       /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
6391       DL.getDefaultGlobalsAddressSpace());
6392   DynamicEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);
6393 
6394   Constant *DynamicEnvironment =
6395       DynamicEnvironmentGV->getType() == DynamicEnvironmentPtr
6396           ? DynamicEnvironmentGV
6397           : ConstantExpr::getAddrSpaceCast(DynamicEnvironmentGV,
6398                                            DynamicEnvironmentPtr);
6399 
6400   Constant *ConfigurationEnvironmentInitializer = ConstantStruct::get(
6401       ConfigurationEnvironment, {
6402                                     UseGenericStateMachineVal,
6403                                     MayUseNestedParallelismVal,
6404                                     IsSPMDVal,
6405                                     MinThreads,
6406                                     MaxThreads,
6407                                     MinTeams,
6408                                     MaxTeams,
6409                                     ReductionDataSize,
6410                                     ReductionBufferLength,
6411                                 });
6412   Constant *KernelEnvironmentInitializer = ConstantStruct::get(
6413       KernelEnvironment, {
6414                              ConfigurationEnvironmentInitializer,
6415                              Ident,
6416                              DynamicEnvironment,
6417                          });
6418   std::string KernelEnvironmentName =
6419       (KernelName + "_kernel_environment").str();
6420   GlobalVariable *KernelEnvironmentGV = new GlobalVariable(
6421       M, KernelEnvironment, /*IsConstant=*/true, GlobalValue::WeakODRLinkage,
6422       KernelEnvironmentInitializer, KernelEnvironmentName,
6423       /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
6424       DL.getDefaultGlobalsAddressSpace());
6425   KernelEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);
6426 
6427   Constant *KernelEnvironment =
6428       KernelEnvironmentGV->getType() == KernelEnvironmentPtr
6429           ? KernelEnvironmentGV
6430           : ConstantExpr::getAddrSpaceCast(KernelEnvironmentGV,
6431                                            KernelEnvironmentPtr);
6432   Value *KernelLaunchEnvironment = DebugKernelWrapper->getArg(0);
6433   Type *KernelLaunchEnvParamTy = Fn->getFunctionType()->getParamType(1);
6434   KernelLaunchEnvironment =
6435       KernelLaunchEnvironment->getType() == KernelLaunchEnvParamTy
6436           ? KernelLaunchEnvironment
6437           : Builder.CreateAddrSpaceCast(KernelLaunchEnvironment,
6438                                         KernelLaunchEnvParamTy);
6439   CallInst *ThreadKind =
6440       Builder.CreateCall(Fn, {KernelEnvironment, KernelLaunchEnvironment});
6441 
6442   Value *ExecUserCode = Builder.CreateICmpEQ(
6443       ThreadKind, Constant::getAllOnesValue(ThreadKind->getType()),
6444       "exec_user_code");
6445 
6446   // ThreadKind = __kmpc_target_init(...)
6447   // if (ThreadKind == -1)
6448   //   user_code
6449   // else
6450   //   return;
6451 
6452   auto *UI = Builder.CreateUnreachable();
6453   BasicBlock *CheckBB = UI->getParent();
6454   BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(UI, "user_code.entry");
6455 
6456   BasicBlock *WorkerExitBB = BasicBlock::Create(
6457       CheckBB->getContext(), "worker.exit", CheckBB->getParent());
6458   Builder.SetInsertPoint(WorkerExitBB);
6459   Builder.CreateRetVoid();
6460 
6461   auto *CheckBBTI = CheckBB->getTerminator();
6462   Builder.SetInsertPoint(CheckBBTI);
6463   Builder.CreateCondBr(ExecUserCode, UI->getParent(), WorkerExitBB);
6464 
6465   CheckBBTI->eraseFromParent();
6466   UI->eraseFromParent();
6467 
6468   // Continue in the "user_code" block, see diagram above and in
6469   // openmp/libomptarget/deviceRTLs/common/include/target.h .
6470   return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt());
6471 }
6472 
createTargetDeinit(const LocationDescription & Loc,int32_t TeamsReductionDataSize,int32_t TeamsReductionBufferLength)6473 void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc,
6474                                          int32_t TeamsReductionDataSize,
6475                                          int32_t TeamsReductionBufferLength) {
6476   if (!updateToLocation(Loc))
6477     return;
6478 
6479   Function *Fn = getOrCreateRuntimeFunctionPtr(
6480       omp::RuntimeFunction::OMPRTL___kmpc_target_deinit);
6481 
6482   Builder.CreateCall(Fn, {});
6483 
6484   if (!TeamsReductionBufferLength || !TeamsReductionDataSize)
6485     return;
6486 
6487   Function *Kernel = Builder.GetInsertBlock()->getParent();
6488   // We need to strip the debug prefix to get the correct kernel name.
6489   StringRef KernelName = Kernel->getName();
6490   const std::string DebugPrefix = "_debug__";
6491   if (KernelName.ends_with(DebugPrefix))
6492     KernelName = KernelName.drop_back(DebugPrefix.length());
6493   auto *KernelEnvironmentGV =
6494       M.getNamedGlobal((KernelName + "_kernel_environment").str());
6495   assert(KernelEnvironmentGV && "Expected kernel environment global\n");
6496   auto *KernelEnvironmentInitializer = KernelEnvironmentGV->getInitializer();
6497   auto *NewInitializer = ConstantFoldInsertValueInstruction(
6498       KernelEnvironmentInitializer,
6499       ConstantInt::get(Int32, TeamsReductionDataSize), {0, 7});
6500   NewInitializer = ConstantFoldInsertValueInstruction(
6501       NewInitializer, ConstantInt::get(Int32, TeamsReductionBufferLength),
6502       {0, 8});
6503   KernelEnvironmentGV->setInitializer(NewInitializer);
6504 }
6505 
updateNVPTXAttr(Function & Kernel,StringRef Name,int32_t Value,bool Min)6506 static void updateNVPTXAttr(Function &Kernel, StringRef Name, int32_t Value,
6507                             bool Min) {
6508   if (Kernel.hasFnAttribute(Name)) {
6509     int32_t OldLimit = Kernel.getFnAttributeAsParsedInteger(Name);
6510     Value = Min ? std::min(OldLimit, Value) : std::max(OldLimit, Value);
6511   }
6512   Kernel.addFnAttr(Name, llvm::utostr(Value));
6513 }
6514 
6515 std::pair<int32_t, int32_t>
readThreadBoundsForKernel(const Triple & T,Function & Kernel)6516 OpenMPIRBuilder::readThreadBoundsForKernel(const Triple &T, Function &Kernel) {
6517   int32_t ThreadLimit =
6518       Kernel.getFnAttributeAsParsedInteger("omp_target_thread_limit");
6519 
6520   if (T.isAMDGPU()) {
6521     const auto &Attr = Kernel.getFnAttribute("amdgpu-flat-work-group-size");
6522     if (!Attr.isValid() || !Attr.isStringAttribute())
6523       return {0, ThreadLimit};
6524     auto [LBStr, UBStr] = Attr.getValueAsString().split(',');
6525     int32_t LB, UB;
6526     if (!llvm::to_integer(UBStr, UB, 10))
6527       return {0, ThreadLimit};
6528     UB = ThreadLimit ? std::min(ThreadLimit, UB) : UB;
6529     if (!llvm::to_integer(LBStr, LB, 10))
6530       return {0, UB};
6531     return {LB, UB};
6532   }
6533 
6534   if (Kernel.hasFnAttribute("nvvm.maxntid")) {
6535     int32_t UB = Kernel.getFnAttributeAsParsedInteger("nvvm.maxntid");
6536     return {0, ThreadLimit ? std::min(ThreadLimit, UB) : UB};
6537   }
6538   return {0, ThreadLimit};
6539 }
6540 
writeThreadBoundsForKernel(const Triple & T,Function & Kernel,int32_t LB,int32_t UB)6541 void OpenMPIRBuilder::writeThreadBoundsForKernel(const Triple &T,
6542                                                  Function &Kernel, int32_t LB,
6543                                                  int32_t UB) {
6544   Kernel.addFnAttr("omp_target_thread_limit", std::to_string(UB));
6545 
6546   if (T.isAMDGPU()) {
6547     Kernel.addFnAttr("amdgpu-flat-work-group-size",
6548                      llvm::utostr(LB) + "," + llvm::utostr(UB));
6549     return;
6550   }
6551 
6552   updateNVPTXAttr(Kernel, "nvvm.maxntid", UB, true);
6553 }
6554 
6555 std::pair<int32_t, int32_t>
readTeamBoundsForKernel(const Triple &,Function & Kernel)6556 OpenMPIRBuilder::readTeamBoundsForKernel(const Triple &, Function &Kernel) {
6557   // TODO: Read from backend annotations if available.
6558   return {0, Kernel.getFnAttributeAsParsedInteger("omp_target_num_teams")};
6559 }
6560 
writeTeamsForKernel(const Triple & T,Function & Kernel,int32_t LB,int32_t UB)6561 void OpenMPIRBuilder::writeTeamsForKernel(const Triple &T, Function &Kernel,
6562                                           int32_t LB, int32_t UB) {
6563   if (T.isNVPTX())
6564     if (UB > 0)
6565       Kernel.addFnAttr("nvvm.maxclusterrank", llvm::utostr(UB));
6566   if (T.isAMDGPU())
6567     Kernel.addFnAttr("amdgpu-max-num-workgroups", llvm::utostr(LB) + ",1,1");
6568 
6569   Kernel.addFnAttr("omp_target_num_teams", std::to_string(LB));
6570 }
6571 
setOutlinedTargetRegionFunctionAttributes(Function * OutlinedFn)6572 void OpenMPIRBuilder::setOutlinedTargetRegionFunctionAttributes(
6573     Function *OutlinedFn) {
6574   if (Config.isTargetDevice()) {
6575     OutlinedFn->setLinkage(GlobalValue::WeakODRLinkage);
6576     // TODO: Determine if DSO local can be set to true.
6577     OutlinedFn->setDSOLocal(false);
6578     OutlinedFn->setVisibility(GlobalValue::ProtectedVisibility);
6579     if (T.isAMDGCN())
6580       OutlinedFn->setCallingConv(CallingConv::AMDGPU_KERNEL);
6581     else if (T.isNVPTX())
6582       OutlinedFn->setCallingConv(CallingConv::PTX_Kernel);
6583     else if (T.isSPIRV())
6584       OutlinedFn->setCallingConv(CallingConv::SPIR_KERNEL);
6585   }
6586 }
6587 
createOutlinedFunctionID(Function * OutlinedFn,StringRef EntryFnIDName)6588 Constant *OpenMPIRBuilder::createOutlinedFunctionID(Function *OutlinedFn,
6589                                                     StringRef EntryFnIDName) {
6590   if (Config.isTargetDevice()) {
6591     assert(OutlinedFn && "The outlined function must exist if embedded");
6592     return OutlinedFn;
6593   }
6594 
6595   return new GlobalVariable(
6596       M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
6597       Constant::getNullValue(Builder.getInt8Ty()), EntryFnIDName);
6598 }
6599 
createTargetRegionEntryAddr(Function * OutlinedFn,StringRef EntryFnName)6600 Constant *OpenMPIRBuilder::createTargetRegionEntryAddr(Function *OutlinedFn,
6601                                                        StringRef EntryFnName) {
6602   if (OutlinedFn)
6603     return OutlinedFn;
6604 
6605   assert(!M.getGlobalVariable(EntryFnName, true) &&
6606          "Named kernel already exists?");
6607   return new GlobalVariable(
6608       M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::InternalLinkage,
6609       Constant::getNullValue(Builder.getInt8Ty()), EntryFnName);
6610 }
6611 
emitTargetRegionFunction(TargetRegionEntryInfo & EntryInfo,FunctionGenCallback & GenerateFunctionCallback,bool IsOffloadEntry,Function * & OutlinedFn,Constant * & OutlinedFnID)6612 Error OpenMPIRBuilder::emitTargetRegionFunction(
6613     TargetRegionEntryInfo &EntryInfo,
6614     FunctionGenCallback &GenerateFunctionCallback, bool IsOffloadEntry,
6615     Function *&OutlinedFn, Constant *&OutlinedFnID) {
6616 
6617   SmallString<64> EntryFnName;
6618   OffloadInfoManager.getTargetRegionEntryFnName(EntryFnName, EntryInfo);
6619 
6620   if (Config.isTargetDevice() || !Config.openMPOffloadMandatory()) {
6621     Expected<Function *> CBResult = GenerateFunctionCallback(EntryFnName);
6622     if (!CBResult)
6623       return CBResult.takeError();
6624     OutlinedFn = *CBResult;
6625   } else {
6626     OutlinedFn = nullptr;
6627   }
6628 
6629   // If this target outline function is not an offload entry, we don't need to
6630   // register it. This may be in the case of a false if clause, or if there are
6631   // no OpenMP targets.
6632   if (!IsOffloadEntry)
6633     return Error::success();
6634 
6635   std::string EntryFnIDName =
6636       Config.isTargetDevice()
6637           ? std::string(EntryFnName)
6638           : createPlatformSpecificName({EntryFnName, "region_id"});
6639 
6640   OutlinedFnID = registerTargetRegionFunction(EntryInfo, OutlinedFn,
6641                                               EntryFnName, EntryFnIDName);
6642   return Error::success();
6643 }
6644 
registerTargetRegionFunction(TargetRegionEntryInfo & EntryInfo,Function * OutlinedFn,StringRef EntryFnName,StringRef EntryFnIDName)6645 Constant *OpenMPIRBuilder::registerTargetRegionFunction(
6646     TargetRegionEntryInfo &EntryInfo, Function *OutlinedFn,
6647     StringRef EntryFnName, StringRef EntryFnIDName) {
6648   if (OutlinedFn)
6649     setOutlinedTargetRegionFunctionAttributes(OutlinedFn);
6650   auto OutlinedFnID = createOutlinedFunctionID(OutlinedFn, EntryFnIDName);
6651   auto EntryAddr = createTargetRegionEntryAddr(OutlinedFn, EntryFnName);
6652   OffloadInfoManager.registerTargetRegionEntryInfo(
6653       EntryInfo, EntryAddr, OutlinedFnID,
6654       OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion);
6655   return OutlinedFnID;
6656 }
6657 
/// Emit the code for a target data region, or for a standalone data-mapping
/// runtime call when \p BodyGenCB is null.
///
/// Behaviour, as implemented below:
///  - If updateToLocation(\p Loc) fails, a default InsertPointTy is returned
///    and nothing is emitted.
///  - When Config.IsTargetDevice is set to true, no runtime calls are emitted;
///    only \p BodyGenCB (if any) is invoked with BodyGenTy::NoPriv.
///  - With a body: the offloading arrays are emitted, then
///    __tgt_target_data_begin_mapper is called, the body is generated with
///    BodyGenTy::Priv, then again with BodyGenTy::NoPriv, and finally
///    __tgt_target_data_end_mapper is called.
///  - Without a body: a single call to the runtime function \p *MapperFunc is
///    emitted, routed through emitTargetTask when Info.HasNoWait is set.
///  - When \p IfCond is non-null the runtime calls are wrapped by
///    emitIfClause; for a region with a body the else-arm of the opening
///    generates the body with BodyGenTy::DupNoPriv.
///
/// \param SrcLocInfo ident passed to the runtime calls; when null it is
///        created from \p Loc on first use.
/// \returns the insertion point after the emitted code, or the error produced
///          by \p BodyGenCB, emitOffloadingArrays or emitIfClause.
createTargetData(const LocationDescription & Loc,InsertPointTy AllocaIP,InsertPointTy CodeGenIP,Value * DeviceID,Value * IfCond,TargetDataInfo & Info,GenMapInfoCallbackTy GenMapInfoCB,CustomMapperCallbackTy CustomMapperCB,omp::RuntimeFunction * MapperFunc,function_ref<InsertPointOrErrorTy (InsertPointTy CodeGenIP,BodyGenTy BodyGenType)> BodyGenCB,function_ref<void (unsigned int,Value *)> DeviceAddrCB,Value * SrcLocInfo)6658 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData(
6659     const LocationDescription &Loc, InsertPointTy AllocaIP,
6660     InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
6661     TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
6662     CustomMapperCallbackTy CustomMapperCB, omp::RuntimeFunction *MapperFunc,
6663     function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
6664                                       BodyGenTy BodyGenType)>
6665         BodyGenCB,
6666     function_ref<void(unsigned int, Value *)> DeviceAddrCB, Value *SrcLocInfo) {
6667   if (!updateToLocation(Loc))
6668     return InsertPointTy();
6669 
6670   Builder.restoreIP(CodeGenIP);
6671   // Disable TargetData CodeGen on Device pass.
6672   if (Config.IsTargetDevice.value_or(false)) {
6673     if (BodyGenCB) {
6674       InsertPointOrErrorTy AfterIP =
6675           BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv);
6676       if (!AfterIP)
6677         return AfterIP.takeError();
6678       Builder.restoreIP(*AfterIP);
6679     }
6680     return Builder.saveIP();
6681   }
6682 
       // Without a body callback this is a standalone construct: a single
       // runtime call (*MapperFunc) and no region body.
6683   bool IsStandAlone = !BodyGenCB;
       // Assigned by BeginThenGen and read again by EndThenGen.
6684   MapInfosTy *MapInfo;
6685   // Generate the code for the opening of the data environment. Capture all the
6686   // arguments of the runtime call by reference because they are used in the
6687   // closing of the region.
6688   auto BeginThenGen = [&](InsertPointTy AllocaIP,
6689                           InsertPointTy CodeGenIP) -> Error {
6690     MapInfo = &GenMapInfoCB(Builder.saveIP());
6691     if (Error Err = emitOffloadingArrays(
6692             AllocaIP, Builder.saveIP(), *MapInfo, Info, CustomMapperCB,
6693             /*IsNonContiguous=*/true, DeviceAddrCB))
6694       return Err;
6695 
6696     TargetDataRTArgs RTArgs;
6697     emitOffloadingArraysArgument(Builder, RTArgs, Info);
6698 
6699     // Emit the number of elements in the offloading arrays.
6700     Value *PointerNum = Builder.getInt32(Info.NumberOfPtrs);
6701 
6702     // Source location for the ident struct
6703     if (!SrcLocInfo) {
6704       uint32_t SrcLocStrSize;
6705       Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6706       SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6707     }
6708 
         // Common argument list of the mapper runtime calls; the nowait
         // variant appends four more entries below.
6709     SmallVector<llvm::Value *, 13> OffloadingArgs = {
6710         SrcLocInfo,           DeviceID,
6711         PointerNum,           RTArgs.BasePointersArray,
6712         RTArgs.PointersArray, RTArgs.SizesArray,
6713         RTArgs.MapTypesArray, RTArgs.MapNamesArray,
6714         RTArgs.MappersArray};
6715 
6716     if (IsStandAlone) {
6717       assert(MapperFunc && "MapperFunc missing for standalone target data");
6718 
           // Emits the actual runtime call. The three parameters are unused;
           // they only exist so the lambda can be handed to emitTargetTask.
6719       auto TaskBodyCB = [&](Value *, Value *,
6720                             IRBuilderBase::InsertPoint) -> Error {
6721         if (Info.HasNoWait) {
               // Four extra trailing arguments, all null here.
               // NOTE(review): presumably the dependence count/list pairs of
               // the nowait runtime entry points - confirm against the
               // runtime function declarations.
6722           OffloadingArgs.append({llvm::Constant::getNullValue(Int32),
6723                                  llvm::Constant::getNullValue(VoidPtr),
6724                                  llvm::Constant::getNullValue(Int32),
6725                                  llvm::Constant::getNullValue(VoidPtr)});
6726         }
6727 
6728         Builder.CreateCall(getOrCreateRuntimeFunctionPtr(*MapperFunc),
6729                            OffloadingArgs);
6730 
6731         if (Info.HasNoWait) {
               // Continue code generation in a fresh "omp_offload.cont" block
               // placed after the call.
6732           BasicBlock *OffloadContBlock =
6733               BasicBlock::Create(Builder.getContext(), "omp_offload.cont");
6734           Function *CurFn = Builder.GetInsertBlock()->getParent();
6735           emitBlock(OffloadContBlock, CurFn, /*IsFinished=*/true);
6736           Builder.restoreIP(Builder.saveIP());
6737         }
6738         return Error::success();
6739       };
6740 
           // Without nowait the call is emitted inline at the current
           // position; with nowait it is emitted through emitTargetTask.
6741       bool RequiresOuterTargetTask = Info.HasNoWait;
6742       if (!RequiresOuterTargetTask)
6743         cantFail(TaskBodyCB(/*DeviceID=*/nullptr, /*RTLoc=*/nullptr,
6744                             /*TargetTaskAllocaIP=*/{}));
6745       else
6746         cantFail(emitTargetTask(TaskBodyCB, DeviceID, SrcLocInfo, AllocaIP,
6747                                 /*Dependencies=*/{}, RTArgs, Info.HasNoWait));
6748     } else {
6749       Function *BeginMapperFunc = getOrCreateRuntimeFunctionPtr(
6750           omp::OMPRTL___tgt_target_data_begin_mapper);
6751 
6752       Builder.CreateCall(BeginMapperFunc, OffloadingArgs);
6753 
           // For every entry whose second slot is an alloca, load the pointer
           // held in the first slot and store it into that alloca.
6754       for (auto DeviceMap : Info.DevicePtrInfoMap) {
6755         if (isa<AllocaInst>(DeviceMap.second.second)) {
6756           auto *LI =
6757               Builder.CreateLoad(Builder.getPtrTy(), DeviceMap.second.first);
6758           Builder.CreateStore(LI, DeviceMap.second.second);
6759         }
6760       }
6761 
6762       // If device pointer privatization is required, emit the body of the
6763       // region here. It will have to be duplicated: with and without
6764       // privatization.
6765       InsertPointOrErrorTy AfterIP =
6766           BodyGenCB(Builder.saveIP(), BodyGenTy::Priv);
6767       if (!AfterIP)
6768         return AfterIP.takeError();
6769       Builder.restoreIP(*AfterIP);
6770     }
6771     return Error::success();
6772   };
6773 
6774   // If we need device pointer privatization, we need to emit the body of the
6775   // region with no privatization in the 'else' branch of the conditional.
6776   // Otherwise, we don't have to do anything.
6777   auto BeginElseGen = [&](InsertPointTy AllocaIP,
6778                           InsertPointTy CodeGenIP) -> Error {
6779     InsertPointOrErrorTy AfterIP =
6780         BodyGenCB(Builder.saveIP(), BodyGenTy::DupNoPriv);
6781     if (!AfterIP)
6782       return AfterIP.takeError();
6783     Builder.restoreIP(*AfterIP);
6784     return Error::success();
6785   };
6786 
6787   // Generate code for the closing of the data region.
6788   auto EndThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6789     TargetDataRTArgs RTArgs;
         // Relies on MapInfo having been set by BeginThenGen.
6790     Info.EmitDebug = !MapInfo->Names.empty();
6791     emitOffloadingArraysArgument(Builder, RTArgs, Info, /*ForEndCall=*/true);
6792 
6793     // Emit the number of elements in the offloading arrays.
6794     Value *PointerNum = Builder.getInt32(Info.NumberOfPtrs);
6795 
6796     // Source location for the ident struct
6797     if (!SrcLocInfo) {
6798       uint32_t SrcLocStrSize;
6799       Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6800       SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6801     }
6802 
6803     Value *OffloadingArgs[] = {SrcLocInfo,           DeviceID,
6804                                PointerNum,           RTArgs.BasePointersArray,
6805                                RTArgs.PointersArray, RTArgs.SizesArray,
6806                                RTArgs.MapTypesArray, RTArgs.MapNamesArray,
6807                                RTArgs.MappersArray};
6808     Function *EndMapperFunc =
6809         getOrCreateRuntimeFunctionPtr(omp::OMPRTL___tgt_target_data_end_mapper);
6810 
6811     Builder.CreateCall(EndMapperFunc, OffloadingArgs);
6812     return Error::success();
6813   };
6814 
6815   // We don't have to do anything to close the region if the if clause evaluates
6816   // to false.
6817   auto EndElseGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6818     return Error::success();
6819   };
6820 
       // Drive the generators above: opening (optionally guarded by IfCond),
       // then the non-privatized body, then the closing (optionally guarded).
       // The standalone form only emits the opening.
6821   Error Err = [&]() -> Error {
6822     if (BodyGenCB) {
6823       Error Err = [&]() {
6824         if (IfCond)
6825           return emitIfClause(IfCond, BeginThenGen, BeginElseGen, AllocaIP);
6826         return BeginThenGen(AllocaIP, Builder.saveIP());
6827       }();
6828 
6829       if (Err)
6830         return Err;
6831 
6832       // If we don't require privatization of device pointers, we emit the body
6833       // in between the runtime calls. This avoids duplicating the body code.
6834       InsertPointOrErrorTy AfterIP =
6835           BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv);
6836       if (!AfterIP)
6837         return AfterIP.takeError();
6838       Builder.restoreIP(*AfterIP);
6839 
6840       if (IfCond)
6841         return emitIfClause(IfCond, EndThenGen, EndElseGen, AllocaIP);
6842       return EndThenGen(AllocaIP, Builder.saveIP());
6843     }
         // Standalone: EndElseGen doubles as the no-op else-arm.
6844     if (IfCond)
6845       return emitIfClause(IfCond, BeginThenGen, EndElseGen, AllocaIP);
6846     return BeginThenGen(AllocaIP, Builder.saveIP());
6847   }();
6848 
6849   if (Err)
6850     return Err;
6851 
6852   return Builder.saveIP();
6853 }
6854 
6855 FunctionCallee
createForStaticInitFunction(unsigned IVSize,bool IVSigned,bool IsGPUDistribute)6856 OpenMPIRBuilder::createForStaticInitFunction(unsigned IVSize, bool IVSigned,
6857                                              bool IsGPUDistribute) {
6858   assert((IVSize == 32 || IVSize == 64) &&
6859          "IV size is not compatible with the omp runtime");
6860   RuntimeFunction Name;
6861   if (IsGPUDistribute)
6862     Name = IVSize == 32
6863                ? (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_4
6864                            : omp::OMPRTL___kmpc_distribute_static_init_4u)
6865                : (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_8
6866                            : omp::OMPRTL___kmpc_distribute_static_init_8u);
6867   else
6868     Name = IVSize == 32 ? (IVSigned ? omp::OMPRTL___kmpc_for_static_init_4
6869                                     : omp::OMPRTL___kmpc_for_static_init_4u)
6870                         : (IVSigned ? omp::OMPRTL___kmpc_for_static_init_8
6871                                     : omp::OMPRTL___kmpc_for_static_init_8u);
6872 
6873   return getOrCreateRuntimeFunction(M, Name);
6874 }
6875 
createDispatchInitFunction(unsigned IVSize,bool IVSigned)6876 FunctionCallee OpenMPIRBuilder::createDispatchInitFunction(unsigned IVSize,
6877                                                            bool IVSigned) {
6878   assert((IVSize == 32 || IVSize == 64) &&
6879          "IV size is not compatible with the omp runtime");
6880   RuntimeFunction Name = IVSize == 32
6881                              ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_4
6882                                          : omp::OMPRTL___kmpc_dispatch_init_4u)
6883                              : (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_8
6884                                          : omp::OMPRTL___kmpc_dispatch_init_8u);
6885 
6886   return getOrCreateRuntimeFunction(M, Name);
6887 }
6888 
createDispatchNextFunction(unsigned IVSize,bool IVSigned)6889 FunctionCallee OpenMPIRBuilder::createDispatchNextFunction(unsigned IVSize,
6890                                                            bool IVSigned) {
6891   assert((IVSize == 32 || IVSize == 64) &&
6892          "IV size is not compatible with the omp runtime");
6893   RuntimeFunction Name = IVSize == 32
6894                              ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_4
6895                                          : omp::OMPRTL___kmpc_dispatch_next_4u)
6896                              : (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_8
6897                                          : omp::OMPRTL___kmpc_dispatch_next_8u);
6898 
6899   return getOrCreateRuntimeFunction(M, Name);
6900 }
6901 
createDispatchFiniFunction(unsigned IVSize,bool IVSigned)6902 FunctionCallee OpenMPIRBuilder::createDispatchFiniFunction(unsigned IVSize,
6903                                                            bool IVSigned) {
6904   assert((IVSize == 32 || IVSize == 64) &&
6905          "IV size is not compatible with the omp runtime");
6906   RuntimeFunction Name = IVSize == 32
6907                              ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_4
6908                                          : omp::OMPRTL___kmpc_dispatch_fini_4u)
6909                              : (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_8
6910                                          : omp::OMPRTL___kmpc_dispatch_fini_8u);
6911 
6912   return getOrCreateRuntimeFunction(M, Name);
6913 }
6914 
createDispatchDeinitFunction()6915 FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
6916   return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit);
6917 }
6918 
/// Repair the variable debug info of an outlined target function.
///
/// Does nothing when \p Func has no DISubprogram. Otherwise every debug
/// variable intrinsic and debug variable record in \p Func is visited; each
/// location operand that is a key of \p ValueReplacementMap is replaced by the
/// mapped value, and the record's variable is swapped for a copy whose
/// argument number is the mapped index plus one. On a target device an
/// artificial "dyn_ptr" parameter variable is additionally declared for the
/// function's first argument.
///
/// \param ValueReplacementMap maps an original value to a tuple of
///        (replacement value, zero-based argument index).
FixupDebugInfoForOutlinedFunction(OpenMPIRBuilder & OMPBuilder,IRBuilderBase & Builder,Function * Func,DenseMap<Value *,std::tuple<Value *,unsigned>> & ValueReplacementMap)6919 static void FixupDebugInfoForOutlinedFunction(
6920     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, Function *Func,
6921     DenseMap<Value *, std::tuple<Value *, unsigned>> &ValueReplacementMap) {
6922 
6923   DISubprogram *NewSP = Func->getSubprogram();
6924   if (!NewSP)
6925     return;
6926 
       // Cache of the most recently created replacement for each original
       // variable.
6927   SmallDenseMap<DILocalVariable *, DILocalVariable *> RemappedVariables;
6928 
       // Returns a DILocalVariable that copies OldVar's scope, name, file,
       // line, type, flags, alignment and annotations but carries argument
       // number `arg`.
6929   auto GetUpdatedDIVariable = [&](DILocalVariable *OldVar, unsigned arg) {
6930     DILocalVariable *&NewVar = RemappedVariables[OldVar];
6931     // Only use cached variable if the arg number matches. This is important
6932     // so that DIVariable created for privatized variables are not discarded.
6933     if (NewVar && (arg == NewVar->getArg()))
6934       return NewVar;
6935 
6936     NewVar = llvm::DILocalVariable::get(
6937         Builder.getContext(), OldVar->getScope(), OldVar->getName(),
6938         OldVar->getFile(), OldVar->getLine(), OldVar->getType(), arg,
6939         OldVar->getFlags(), OldVar->getAlignInBits(), OldVar->getAnnotations());
6940     return NewVar;
6941   };
6942 
       // Generic so that it accepts both debug intrinsics and debug records.
       // ArgNo stays 0 when no operand was remapped; if several operands
       // match, the argument index of the last one is used.
6943   auto UpdateDebugRecord = [&](auto *DR) {
6944     DILocalVariable *OldVar = DR->getVariable();
6945     unsigned ArgNo = 0;
6946     for (auto Loc : DR->location_ops()) {
6947       auto Iter = ValueReplacementMap.find(Loc);
6948       if (Iter != ValueReplacementMap.end()) {
6949         DR->replaceVariableLocationOp(Loc, std::get<0>(Iter->second));
             // The map stores a zero-based index; the variable's argument
             // number is set one higher.
6950         ArgNo = std::get<1>(Iter->second) + 1;
6951       }
6952     }
6953     if (ArgNo != 0)
6954       DR->setVariable(GetUpdatedDIVariable(OldVar, ArgNo));
6955   };
6956 
6957   // The location and scope of variable intrinsics and records still point to
6958   // the parent function of the target region. Update them.
6959   for (Instruction &I : instructions(Func)) {
6960     if (auto *DDI = dyn_cast<llvm::DbgVariableIntrinsic>(&I))
6961       UpdateDebugRecord(DDI);
6962 
6963     for (DbgVariableRecord &DVR : filterDbgVars(I.getDbgRecordRange()))
6964       UpdateDebugRecord(&DVR);
6965   }
6966   // An extra argument is passed to the device. Create the debug data for it.
6967   if (OMPBuilder.Config.isTargetDevice()) {
6968     DICompileUnit *CU = NewSP->getUnit();
6969     Module *M = Func->getParent();
6970     DIBuilder DB(*M, true, CU);
6971     DIType *VoidPtrTy =
6972         DB.createQualifiedType(dwarf::DW_TAG_pointer_type, nullptr);
6973     DILocalVariable *Var = DB.createParameterVariable(
6974         NewSP, "dyn_ptr", /*ArgNo*/ 1, NewSP->getFile(), /*LineNo=*/0,
6975         VoidPtrTy, /*AlwaysPreserve=*/false, DINode::DIFlags::FlagArtificial);
6976     auto Loc = DILocation::get(Func->getContext(), 0, 0, NewSP, 0);
         // Declare the variable on the first argument, in the function's
         // first basic block.
6977     DB.insertDeclare(&(*Func->arg_begin()), Var, DB.createExpression(), Loc,
6978                      &(*Func->begin()));
6979   }
6980 }
6981 
/// Create the outlined function for a target region and rewrite the uses of
/// its inputs inside that function.
///
/// The function is created with internal linkage, a void return type and the
/// name \p FuncName in the module of the builder's current insert block. On a
/// target device it takes one extra leading pointer parameter, and every
/// non-pointer input is passed as i64. The body is produced by \p CBFunc;
/// on a target device it is bracketed by createTargetInit and
/// createTargetDeinit. For each (input, argument) pair \p ArgAccessorFuncCB
/// produces the value that replaces the input inside the new function.
///
/// The builder's insertion point is saved on entry and restored on return.
///
/// \returns the new function, or the error reported by \p CBFunc or
///          \p ArgAccessorFuncCB.
createOutlinedFunction(OpenMPIRBuilder & OMPBuilder,IRBuilderBase & Builder,const OpenMPIRBuilder::TargetKernelDefaultAttrs & DefaultAttrs,StringRef FuncName,SmallVectorImpl<Value * > & Inputs,OpenMPIRBuilder::TargetBodyGenCallbackTy & CBFunc,OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy & ArgAccessorFuncCB)6982 static Expected<Function *> createOutlinedFunction(
6983     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6984     const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
6985     StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
6986     OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
6987     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
6988   SmallVector<Type *> ParameterTypes;
6989   if (OMPBuilder.Config.isTargetDevice()) {
6990     // Add the "implicit" runtime argument we use to provide launch specific
6991     // information for target devices.
6992     auto *Int8PtrTy = PointerType::getUnqual(Builder.getContext());
6993     ParameterTypes.push_back(Int8PtrTy);
6994 
6995     // All parameters to target devices are passed as pointers
6996     // or i64. This assumes 64-bit address spaces/pointers.
6997     for (auto &Arg : Inputs)
6998       ParameterTypes.push_back(Arg->getType()->isPointerTy()
6999                                    ? Arg->getType()
7000                                    : Type::getInt64Ty(Builder.getContext()));
7001   } else {
         // On the host the parameters keep the types of the inputs.
7002     for (auto &Arg : Inputs)
7003       ParameterTypes.push_back(Arg->getType());
7004   }
7005 
7006   auto BB = Builder.GetInsertBlock();
7007   auto M = BB->getModule();
7008   auto FuncType = FunctionType::get(Builder.getVoidTy(), ParameterTypes,
7009                                     /*isVarArg*/ false);
7010   auto Func =
7011       Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
7012 
7013   // Forward target-cpu and target-features function attributes from the
7014   // original function to the new outlined function.
7015   Function *ParentFn = Builder.GetInsertBlock()->getParent();
7016 
7017   auto TargetCpuAttr = ParentFn->getFnAttribute("target-cpu");
7018   if (TargetCpuAttr.isStringAttribute())
7019     Func->addFnAttr(TargetCpuAttr);
7020 
7021   auto TargetFeaturesAttr = ParentFn->getFnAttribute("target-features");
7022   if (TargetFeaturesAttr.isStringAttribute())
7023     Func->addFnAttr(TargetFeaturesAttr);
7024 
       // On the device, emit the kernel's execution mode value and list it in
       // llvm.compiler.used.
7025   if (OMPBuilder.Config.isTargetDevice()) {
7026     Value *ExecMode =
7027         OMPBuilder.emitKernelExecutionMode(FuncName, DefaultAttrs.ExecFlags);
7028     OMPBuilder.emitUsed("llvm.compiler.used", {ExecMode});
7029   }
7030 
7031   // Save insert point.
7032   IRBuilder<>::InsertPointGuard IPG(Builder);
7033   // We will generate the entries in the outlined function but the debug
7034   // location may still be pointing to the parent function. Reset it now.
7035   Builder.SetCurrentDebugLocation(llvm::DebugLoc());
7036 
7037   // Generate the region into the function.
7038   BasicBlock *EntryBB = BasicBlock::Create(Builder.getContext(), "entry", Func);
7039   Builder.SetInsertPoint(EntryBB);
7040 
7041   // Insert target init call in the device compilation pass.
7042   if (OMPBuilder.Config.isTargetDevice())
7043     Builder.restoreIP(OMPBuilder.createTargetInit(Builder, DefaultAttrs));
7044 
       // Block in which the argument accessors are emitted further down.
7045   BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
7046 
7047   // As we embed the user code in the middle of our target region after we
7048   // generate entry code, we must move what allocas we can into the entry
7049   // block to avoid possible breaking optimisations for device
7050   if (OMPBuilder.Config.isTargetDevice())
7051     OMPBuilder.ConstantAllocaRaiseCandidates.emplace_back(Func);
7052 
7053   // Split off the "outlined.body" block and let the callback generate the
       // region body, starting at the beginning of that block.
7054   BasicBlock *OutlinedBodyBB =
7055       splitBB(Builder, /*CreateBranch=*/true, "outlined.body");
7056   llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = CBFunc(
7057       Builder.saveIP(),
7058       OpenMPIRBuilder::InsertPointTy(OutlinedBodyBB, OutlinedBodyBB->begin()));
7059   if (!AfterIP)
7060     return AfterIP.takeError();
7061   Builder.restoreIP(*AfterIP);
       // Insert target deinit call in the device compilation pass.
7062   if (OMPBuilder.Config.isTargetDevice())
7063     OMPBuilder.createTargetDeinit(Builder);
7064 
7065   // Insert return instruction.
7066   Builder.CreateRetVoid();
7067 
7068   // New Alloca IP at entry point of created device function.
7069   Builder.SetInsertPoint(EntryBB->getFirstNonPHIIt());
7070   auto AllocaIP = Builder.saveIP();
7071 
7072   Builder.SetInsertPoint(UserCodeEntryBB->getFirstNonPHIOrDbg());
7073 
7074   // Skip the artificial dyn_ptr on the device.
7075   const auto &ArgRange =
7076       OMPBuilder.Config.isTargetDevice()
7077           ? make_range(Func->arg_begin() + 1, Func->arg_end())
7078           : Func->args();
7079 
       // Input value -> (value used inside the function, argument index);
       // consumed by FixupDebugInfoForOutlinedFunction at the end.
7080   DenseMap<Value *, std::tuple<Value *, unsigned>> ValueReplacementMap;
7081 
7082   auto ReplaceValue = [](Value *Input, Value *InputCopy, Function *Func) {
7083     // Things like GEP's can come in the form of Constants. Constants and
7084     // ConstantExpr's do not have access to the knowledge of what they're
7085     // contained in, so we must dig a little to find an instruction so we
7086     // can tell if they're used inside of the function we're outlining. We
7087     // also replace the original constant expression with a new instruction
7088     // equivalent; an instruction as it allows easy modification in the
7089     // following loop, as we can now know the constant (instruction) is
7090     // owned by our target function and replaceUsesOfWith can now be invoked
7091     // on it (cannot do this with constants it seems). A brand new one also
7092     // allows us to be cautious as it is perhaps possible the old expression
7093     // was used inside of the function but exists and is used externally
7094     // (unlikely by the nature of a Constant, but still).
7095     // NOTE: We cannot remove dead constants that have been rewritten to
7096     // instructions at this stage, we run the risk of breaking later lowering
7097     // by doing so as we could still be in the process of lowering the module
7098     // from MLIR to LLVM-IR and the MLIR lowering may still require the original
7099     // constants we have created rewritten versions of.
7100     if (auto *Const = dyn_cast<Constant>(Input))
7101       convertUsersOfConstantsToInstructions(Const, Func, false);
7102 
7103     // Collect users before iterating over them to avoid invalidating the
7104     // iteration in case a user uses Input more than once (e.g. a call
7105     // instruction).
7106     SetVector<User *> Users(Input->users().begin(), Input->users().end());
7107     // Rewrite only the instructions that live in the outlined function.
7108     for (User *User : make_early_inc_range(Users))
7109       if (auto *Instr = dyn_cast<Instruction>(User))
7110         if (Instr->getFunction() == Func)
7111           Instr->replaceUsesOfWith(Input, InputCopy);
7112   };
7113 
7114   SmallVector<std::pair<Value *, Value *>> DeferredReplacement;
7115 
7116   // Rewrite uses of input values to parameters.
7117   for (auto InArg : zip(Inputs, ArgRange)) {
7118     Value *Input = std::get<0>(InArg);
7119     Argument &Arg = std::get<1>(InArg);
7120     Value *InputCopy = nullptr;
7121 
         // The callback emits the accessor code for this argument and sets
         // InputCopy to the value that stands in for Input.
7122     llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
7123         ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP());
7124     if (!AfterIP)
7125       return AfterIP.takeError();
7126     Builder.restoreIP(*AfterIP);
7127     ValueReplacementMap[Input] = std::make_tuple(InputCopy, Arg.getArgNo());
7128 
7129     // In certain cases a Global may be set up for replacement, however, this
7130     // Global may be used in multiple arguments to the kernel, just segmented
7131     // apart, for example, if we have a global array, that is sectioned into
7132     // multiple mappings (technically not legal in OpenMP, but there is a case
7133     // in Fortran for Common Blocks where this is necessary), we will end up
7134     // with GEP's into this array inside the kernel, that refer to the Global
7135     // but are technically separate arguments to the kernel for all intents and
7136     // purposes. If we have mapped a segment that requires a GEP into the 0-th
7137     // index, it will fold into a reference to the Global, if we then encounter
7138     // this folded GEP during replacement all of the references to the
7139     // Global in the kernel will be replaced with the argument we have generated
7140     // that corresponds to it, including any other GEP's that refer to the
7141     // Global that may be other arguments. This will invalidate all of the other
7142     // preceding mapped arguments that refer to the same global that may be
7143     // separate segments. To prevent this, we defer global processing until all
7144     // other processing has been performed.
7145     if (isa<GlobalValue>(Input)) {
7146       DeferredReplacement.push_back(std::make_pair(Input, InputCopy));
7147       continue;
7148     }
7149 
         // Uses of plain constant data are left untouched.
7150     if (isa<ConstantData>(Input))
7151       continue;
7152 
7153     ReplaceValue(Input, InputCopy, Func);
7154   }
7155 
7156   // Replace all of our deferred Input values, currently just Globals.
7157   for (auto Deferred : DeferredReplacement)
7158     ReplaceValue(std::get<0>(Deferred), std::get<1>(Deferred), Func);
7159 
7160   FixupDebugInfoForOutlinedFunction(OMPBuilder, Builder, Func,
7161                                     ValueReplacementMap);
7162   return Func;
7163 }
7164 /// Given a task descriptor, TaskWithPrivates, return the pointer to the block
7165 /// of pointers containing shared data between the parent task and the created
7166 /// task.
loadSharedDataFromTaskDescriptor(OpenMPIRBuilder & OMPIRBuilder,IRBuilderBase & Builder,Value * TaskWithPrivates,Type * TaskWithPrivatesTy)7167 static LoadInst *loadSharedDataFromTaskDescriptor(OpenMPIRBuilder &OMPIRBuilder,
7168                                                   IRBuilderBase &Builder,
7169                                                   Value *TaskWithPrivates,
7170                                                   Type *TaskWithPrivatesTy) {
7171 
7172   Type *TaskTy = OMPIRBuilder.Task;
7173   LLVMContext &Ctx = Builder.getContext();
7174   Value *TaskT =
7175       Builder.CreateStructGEP(TaskWithPrivatesTy, TaskWithPrivates, 0);
7176   Value *Shareds = TaskT;
7177   // TaskWithPrivatesTy can be one of the following
7178   // 1. %struct.task_with_privates = type { %struct.kmp_task_ompbuilder_t,
7179   //                                        %struct.privates }
7180   // 2. %struct.kmp_task_ompbuilder_t ;; This is simply TaskTy
7181   //
7182   // In the former case, that is when  TaskWithPrivatesTy != TaskTy,
7183   // its first member has to be the task descriptor. TaskTy is the type of the
7184   // task descriptor. TaskT is the pointer to the task descriptor. Loading the
7185   // first member of TaskT, gives us the pointer to shared data.
7186   if (TaskWithPrivatesTy != TaskTy)
7187     Shareds = Builder.CreateStructGEP(TaskTy, TaskT, 0);
7188   return Builder.CreateLoad(PointerType::getUnqual(Ctx), Shareds);
7189 }
7190 /// Create an entry point for a target task with the following.
7191 /// It'll have the following signature
7192 /// void @.omp_target_task_proxy_func(i32 %thread.id, ptr %task)
7193 /// This function is called from emitTargetTask once the
7194 /// code to launch the target kernel has been outlined already.
7195 /// NumOffloadingArrays is the number of offloading arrays that we need to copy
7196 /// into the task structure so that the deferred target task can access this
7197 /// data even after the stack frame of the generating task has been rolled
7198 /// back. Offloading arrays contain base pointers, pointers, sizes etc
7199 /// of the data that the target kernel will access. These in effect are the
7200 /// non-empty arrays of pointers held by OpenMPIRBuilder::TargetDataRTArgs.
emitTargetTaskProxyFunction(OpenMPIRBuilder & OMPBuilder,IRBuilderBase & Builder,CallInst * StaleCI,StructType * PrivatesTy,StructType * TaskWithPrivatesTy,const size_t NumOffloadingArrays,const int SharedArgsOperandNo)7201 static Function *emitTargetTaskProxyFunction(
7202     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, CallInst *StaleCI,
7203     StructType *PrivatesTy, StructType *TaskWithPrivatesTy,
7204     const size_t NumOffloadingArrays, const int SharedArgsOperandNo) {
7205 
7206   // If NumOffloadingArrays is non-zero, PrivatesTy better not be nullptr.
7207   // This is because PrivatesTy is the type of the structure in which
7208   // we pass the offloading arrays to the deferred target task.
7209   assert((!NumOffloadingArrays || PrivatesTy) &&
7210          "PrivatesTy cannot be nullptr when there are offloadingArrays"
7211          "to privatize");
7212 
7213   Module &M = OMPBuilder.M;
7214   // KernelLaunchFunction is the target launch function, i.e.
7215   // the function that sets up kernel arguments and calls
7216   // __tgt_target_kernel to launch the kernel on the device.
7217   //
7218   Function *KernelLaunchFunction = StaleCI->getCalledFunction();
7219 
7220   // StaleCI is the CallInst which is the call to the outlined
7221   // target kernel launch function. If there are local live-in values
7222   // that the outlined function uses then these are aggregated into a structure
7223   // which is passed as the second argument. If there are no local live-in
7224   // values or if all values used by the outlined kernel are global variables,
7225   // then there's only one argument, the threadID. So, StaleCI can be
7226   //
7227   // %structArg = alloca { ptr, ptr }, align 8
7228   // %gep_ = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 0
7229   // store ptr %20, ptr %gep_, align 8
7230   // %gep_8 = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 1
7231   // store ptr %21, ptr %gep_8, align 8
7232   // call void @_QQmain..omp_par.1(i32 %global.tid.val6, ptr %structArg)
7233   //
7234   // OR
7235   //
7236   // call void @_QQmain..omp_par.1(i32 %global.tid.val6)
7237   OpenMPIRBuilder::InsertPointTy IP(StaleCI->getParent(),
7238                                     StaleCI->getIterator());
7239 
7240   LLVMContext &Ctx = StaleCI->getParent()->getContext();
7241 
7242   Type *ThreadIDTy = Type::getInt32Ty(Ctx);
7243   Type *TaskPtrTy = OMPBuilder.TaskPtr;
7244   [[maybe_unused]] Type *TaskTy = OMPBuilder.Task;
7245 
7246   auto ProxyFnTy =
7247       FunctionType::get(Builder.getVoidTy(), {ThreadIDTy, TaskPtrTy},
7248                         /* isVarArg */ false);
7249   auto ProxyFn = Function::Create(ProxyFnTy, GlobalValue::InternalLinkage,
7250                                   ".omp_target_task_proxy_func",
7251                                   Builder.GetInsertBlock()->getModule());
7252   Value *ThreadId = ProxyFn->getArg(0);
7253   Value *TaskWithPrivates = ProxyFn->getArg(1);
7254   ThreadId->setName("thread.id");
7255   TaskWithPrivates->setName("task");
7256 
7257   bool HasShareds = SharedArgsOperandNo > 0;
7258   bool HasOffloadingArrays = NumOffloadingArrays > 0;
7259   BasicBlock *EntryBB =
7260       BasicBlock::Create(Builder.getContext(), "entry", ProxyFn);
7261   Builder.SetInsertPoint(EntryBB);
7262 
7263   SmallVector<Value *> KernelLaunchArgs;
7264   KernelLaunchArgs.reserve(StaleCI->arg_size());
7265   KernelLaunchArgs.push_back(ThreadId);
7266 
7267   if (HasOffloadingArrays) {
7268     assert(TaskTy != TaskWithPrivatesTy &&
7269            "If there are offloading arrays to pass to the target"
7270            "TaskTy cannot be the same as TaskWithPrivatesTy");
7271     (void)TaskTy;
7272     Value *Privates =
7273         Builder.CreateStructGEP(TaskWithPrivatesTy, TaskWithPrivates, 1);
7274     for (unsigned int i = 0; i < NumOffloadingArrays; ++i)
7275       KernelLaunchArgs.push_back(
7276           Builder.CreateStructGEP(PrivatesTy, Privates, i));
7277   }
7278 
7279   if (HasShareds) {
7280     auto *ArgStructAlloca =
7281         dyn_cast<AllocaInst>(StaleCI->getArgOperand(SharedArgsOperandNo));
7282     assert(ArgStructAlloca &&
7283            "Unable to find the alloca instruction corresponding to arguments "
7284            "for extracted function");
7285     auto *ArgStructType = cast<StructType>(ArgStructAlloca->getAllocatedType());
7286 
7287     AllocaInst *NewArgStructAlloca =
7288         Builder.CreateAlloca(ArgStructType, nullptr, "structArg");
7289 
7290     Value *SharedsSize =
7291         Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
7292 
7293     LoadInst *LoadShared = loadSharedDataFromTaskDescriptor(
7294         OMPBuilder, Builder, TaskWithPrivates, TaskWithPrivatesTy);
7295 
7296     Builder.CreateMemCpy(
7297         NewArgStructAlloca, NewArgStructAlloca->getAlign(), LoadShared,
7298         LoadShared->getPointerAlignment(M.getDataLayout()), SharedsSize);
7299     KernelLaunchArgs.push_back(NewArgStructAlloca);
7300   }
7301   Builder.CreateCall(KernelLaunchFunction, KernelLaunchArgs);
7302   Builder.CreateRetVoid();
7303   return ProxyFn;
7304 }
getOffloadingArrayType(Value * V)7305 static Type *getOffloadingArrayType(Value *V) {
7306 
7307   if (auto *GEP = dyn_cast<GetElementPtrInst>(V))
7308     return GEP->getSourceElementType();
7309   if (auto *Alloca = dyn_cast<AllocaInst>(V))
7310     return Alloca->getAllocatedType();
7311 
7312   llvm_unreachable("Unhandled Instruction type");
7313   return nullptr;
7314 }
7315 // This function returns a struct that has at most two members.
7316 // The first member is always %struct.kmp_task_ompbuilder_t, that is the task
7317 // descriptor. The second member, if needed, is a struct containing arrays
7318 // that need to be passed to the offloaded target kernel. For example,
7319 // if .offload_baseptrs, .offload_ptrs and .offload_sizes have to be passed to
7320 // the target kernel and their types are [3 x ptr], [3 x ptr] and [3 x i64]
7321 // respectively, then the types created  by this function are
7322 //
7323 // %struct.privates = type { [3 x ptr], [3 x ptr], [3 x i64] }
7324 // %struct.task_with_privates = type { %struct.kmp_task_ompbuilder_t,
7325 //                                     %struct.privates }
7326 // %struct.task_with_privates is returned by this function.
7327 // If there aren't any offloading arrays to pass to the target kernel,
7328 // %struct.kmp_task_ompbuilder_t is returned.
7329 static StructType *
createTaskWithPrivatesTy(OpenMPIRBuilder & OMPIRBuilder,ArrayRef<Value * > OffloadingArraysToPrivatize)7330 createTaskWithPrivatesTy(OpenMPIRBuilder &OMPIRBuilder,
7331                          ArrayRef<Value *> OffloadingArraysToPrivatize) {
7332 
7333   if (OffloadingArraysToPrivatize.empty())
7334     return OMPIRBuilder.Task;
7335 
7336   SmallVector<Type *, 4> StructFieldTypes;
7337   for (Value *V : OffloadingArraysToPrivatize) {
7338     assert(V->getType()->isPointerTy() &&
7339            "Expected pointer to array to privatize. Got a non-pointer value "
7340            "instead");
7341     Type *ArrayTy = getOffloadingArrayType(V);
7342     assert(ArrayTy && "ArrayType cannot be nullptr");
7343     StructFieldTypes.push_back(ArrayTy);
7344   }
7345   StructType *PrivatesStructTy =
7346       StructType::create(StructFieldTypes, "struct.privates");
7347   return StructType::create({OMPIRBuilder.Task, PrivatesStructTy},
7348                             "struct.task_with_privates");
7349 }
emitTargetOutlinedFunction(OpenMPIRBuilder & OMPBuilder,IRBuilderBase & Builder,bool IsOffloadEntry,TargetRegionEntryInfo & EntryInfo,const OpenMPIRBuilder::TargetKernelDefaultAttrs & DefaultAttrs,Function * & OutlinedFn,Constant * & OutlinedFnID,SmallVectorImpl<Value * > & Inputs,OpenMPIRBuilder::TargetBodyGenCallbackTy & CBFunc,OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy & ArgAccessorFuncCB)7350 static Error emitTargetOutlinedFunction(
7351     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
7352     TargetRegionEntryInfo &EntryInfo,
7353     const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7354     Function *&OutlinedFn, Constant *&OutlinedFnID,
7355     SmallVectorImpl<Value *> &Inputs,
7356     OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
7357     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
7358 
7359   OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
7360       [&](StringRef EntryFnName) {
7361         return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
7362                                       EntryFnName, Inputs, CBFunc,
7363                                       ArgAccessorFuncCB);
7364       };
7365 
7366   return OMPBuilder.emitTargetRegionFunction(
7367       EntryInfo, GenerateOutlinedFunction, IsOffloadEntry, OutlinedFn,
7368       OutlinedFnID);
7369 }
7370 
emitTargetTask(TargetTaskBodyCallbackTy TaskBodyCB,Value * DeviceID,Value * RTLoc,OpenMPIRBuilder::InsertPointTy AllocaIP,const SmallVector<llvm::OpenMPIRBuilder::DependData> & Dependencies,const TargetDataRTArgs & RTArgs,bool HasNoWait)7371 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
7372     TargetTaskBodyCallbackTy TaskBodyCB, Value *DeviceID, Value *RTLoc,
7373     OpenMPIRBuilder::InsertPointTy AllocaIP,
7374     const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
7375     const TargetDataRTArgs &RTArgs, bool HasNoWait) {
7376 
7377   // The following explains the code-gen scenario for the `target` directive. A
7378   // similar scneario is followed for other device-related directives (e.g.
7379   // `target enter data`) but in similar fashion since we only need to emit task
7380   // that encapsulates the proper runtime call.
7381   //
7382   // When we arrive at this function, the target region itself has been
7383   // outlined into the function OutlinedFn.
7384   // So at ths point, for
7385   // --------------------------------------------------------------
7386   //   void user_code_that_offloads(...) {
7387   //     omp target depend(..) map(from:a) map(to:b) private(i)
7388   //     do i = 1, 10
7389   //        a(i) = b(i) + n
7390   //   }
7391   //
7392   // --------------------------------------------------------------
7393   //
7394   // we have
7395   //
7396   // --------------------------------------------------------------
7397   //
7398   //   void user_code_that_offloads(...) {
7399   //     %.offload_baseptrs = alloca [2 x ptr], align 8
7400   //     %.offload_ptrs = alloca [2 x ptr], align 8
7401   //     %.offload_mappers = alloca [2 x ptr], align 8
7402   //     ;; target region has been outlined and now we need to
7403   //     ;; offload to it via a target task.
7404   //   }
7405   //   void outlined_device_function(ptr a, ptr b, ptr n) {
7406   //     n = *n_ptr;
7407   //     do i = 1, 10
7408   //       a(i) = b(i) +  n
7409   //   }
7410   //
7411   // We have to now do the following
7412   // (i)   Make an offloading call to outlined_device_function using the OpenMP
7413   //       RTL. See 'kernel_launch_function' in the pseudo code below. This is
7414   //       emitted by emitKernelLaunch
7415   // (ii)  Create a task entry point function that calls kernel_launch_function
7416   //       and is the entry point for the target task. See
7417   //       '@.omp_target_task_proxy_func in the pseudocode below.
7418   // (iii) Create a task with the task entry point created in (ii)
7419   //
7420   // That is we create the following
7421   //   struct task_with_privates {
7422   //      struct kmp_task_ompbuilder_t task_struct;
7423   //      struct privates {
7424   //         [2 x ptr] ; baseptrs
7425   //         [2 x ptr] ; ptrs
7426   //         [2 x i64] ; sizes
7427   //      }
7428   //   }
7429   //   void user_code_that_offloads(...) {
7430   //     %.offload_baseptrs = alloca [2 x ptr], align 8
7431   //     %.offload_ptrs = alloca [2 x ptr], align 8
7432   //     %.offload_sizes = alloca [2 x i64], align 8
7433   //
7434   //     %structArg = alloca { ptr, ptr, ptr }, align 8
7435   //     %strucArg[0] = a
7436   //     %strucArg[1] = b
7437   //     %strucArg[2] = &n
7438   //
7439   //     target_task_with_privates = @__kmpc_omp_target_task_alloc(...,
7440   //                                               sizeof(kmp_task_ompbuilder_t),
7441   //                                               sizeof(structArg),
7442   //                                               @.omp_target_task_proxy_func,
7443   //                                               ...)
7444   //     memcpy(target_task_with_privates->task_struct->shareds, %structArg,
7445   //            sizeof(structArg))
7446   //     memcpy(target_task_with_privates->privates->baseptrs,
7447   //            offload_baseptrs, sizeof(offload_baseptrs)
7448   //     memcpy(target_task_with_privates->privates->ptrs,
7449   //            offload_ptrs, sizeof(offload_ptrs)
7450   //     memcpy(target_task_with_privates->privates->sizes,
7451   //            offload_sizes, sizeof(offload_sizes)
7452   //     dependencies_array = ...
7453   //     ;; if nowait not present
7454   //     call @__kmpc_omp_wait_deps(..., dependencies_array)
7455   //     call @__kmpc_omp_task_begin_if0(...)
7456   //     call @ @.omp_target_task_proxy_func(i32 thread_id, ptr
7457   //     %target_task_with_privates)
7458   //     call @__kmpc_omp_task_complete_if0(...)
7459   //   }
7460   //
7461   //   define internal void @.omp_target_task_proxy_func(i32 %thread.id,
7462   //                                                     ptr %task) {
7463   //       %structArg = alloca {ptr, ptr, ptr}
7464   //       %task_ptr = getelementptr(%task, 0, 0)
7465   //       %shared_data = load (getelementptr %task_ptr, 0, 0)
7466   //       mempcy(%structArg, %shared_data, sizeof(%structArg))
7467   //
7468   //       %offloading_arrays = getelementptr(%task, 0, 1)
7469   //       %offload_baseptrs = getelementptr(%offloading_arrays, 0, 0)
7470   //       %offload_ptrs = getelementptr(%offloading_arrays, 0, 1)
7471   //       %offload_sizes = getelementptr(%offloading_arrays, 0, 2)
7472   //       kernel_launch_function(%thread.id, %offload_baseptrs, %offload_ptrs,
7473   //                              %offload_sizes, %structArg)
7474   //   }
7475   //
7476   //   We need the proxy function because the signature of the task entry point
7477   //   expected by kmpc_omp_task is always the same and will be different from
7478   //   that of the kernel_launch function.
7479   //
7480   //   kernel_launch_function is generated by emitKernelLaunch and has the
7481   //   always_inline attribute. For this example, it'll look like so:
7482   //   void kernel_launch_function(%thread_id, %offload_baseptrs, %offload_ptrs,
7483   //                               %offload_sizes,  %structArg) alwaysinline {
7484   //       %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
7485   //       ; load aggregated data from %structArg
7486   //       ; setup kernel_args using offload_baseptrs, offload_ptrs and
7487   //       ; offload_sizes
7488   //       call i32 @__tgt_target_kernel(...,
7489   //                                     outlined_device_function,
7490   //                                     ptr %kernel_args)
7491   //   }
7492   //   void outlined_device_function(ptr a, ptr b, ptr n) {
7493   //     n = *n_ptr;
7494   //     do i = 1, 10
7495   //       a(i) = b(i) +  n
7496   //   }
7497   //
7498   BasicBlock *TargetTaskBodyBB =
7499       splitBB(Builder, /*CreateBranch=*/true, "target.task.body");
7500   BasicBlock *TargetTaskAllocaBB =
7501       splitBB(Builder, /*CreateBranch=*/true, "target.task.alloca");
7502 
7503   InsertPointTy TargetTaskAllocaIP(TargetTaskAllocaBB,
7504                                    TargetTaskAllocaBB->begin());
7505   InsertPointTy TargetTaskBodyIP(TargetTaskBodyBB, TargetTaskBodyBB->begin());
7506 
7507   OutlineInfo OI;
7508   OI.EntryBB = TargetTaskAllocaBB;
7509   OI.OuterAllocaBB = AllocaIP.getBlock();
7510 
7511   // Add the thread ID argument.
7512   SmallVector<Instruction *, 4> ToBeDeleted;
7513   OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
7514       Builder, AllocaIP, ToBeDeleted, TargetTaskAllocaIP, "global.tid", false));
7515 
7516   // Generate the task body which will subsequently be outlined.
7517   Builder.restoreIP(TargetTaskBodyIP);
7518   if (Error Err = TaskBodyCB(DeviceID, RTLoc, TargetTaskAllocaIP))
7519     return Err;
7520 
7521   // The outliner (CodeExtractor) extract a sequence or vector of blocks that
7522   // it is given. These blocks are enumerated by
7523   // OpenMPIRBuilder::OutlineInfo::collectBlocks which expects the OI.ExitBlock
7524   // to be outside the region. In other words, OI.ExitBlock is expected to be
7525   // the start of the region after the outlining. We used to set OI.ExitBlock
7526   // to the InsertBlock after TaskBodyCB is done. This is fine in most cases
7527   // except when the task body is a single basic block. In that case,
7528   // OI.ExitBlock is set to the single task body block and will get left out of
7529   // the outlining process. So, simply create a new empty block to which we
7530   // uncoditionally branch from where TaskBodyCB left off
7531   OI.ExitBB = BasicBlock::Create(Builder.getContext(), "target.task.cont");
7532   emitBlock(OI.ExitBB, Builder.GetInsertBlock()->getParent(),
7533             /*IsFinished=*/true);
7534 
7535   SmallVector<Value *, 2> OffloadingArraysToPrivatize;
7536   bool NeedsTargetTask = HasNoWait && DeviceID;
7537   if (NeedsTargetTask) {
7538     for (auto *V :
7539          {RTArgs.BasePointersArray, RTArgs.PointersArray, RTArgs.MappersArray,
7540           RTArgs.MapNamesArray, RTArgs.MapTypesArray, RTArgs.MapTypesArrayEnd,
7541           RTArgs.SizesArray}) {
7542       if (V && !isa<ConstantPointerNull, GlobalVariable>(V)) {
7543         OffloadingArraysToPrivatize.push_back(V);
7544         OI.ExcludeArgsFromAggregate.push_back(V);
7545       }
7546     }
7547   }
7548   OI.PostOutlineCB = [this, ToBeDeleted, Dependencies, NeedsTargetTask,
7549                       DeviceID, OffloadingArraysToPrivatize](
7550                          Function &OutlinedFn) mutable {
7551     assert(OutlinedFn.hasOneUse() &&
7552            "there must be a single user for the outlined function");
7553 
7554     CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
7555 
7556     // The first argument of StaleCI is always the thread id.
7557     // The next few arguments are the pointers to offloading arrays
7558     // if any. (see OffloadingArraysToPrivatize)
7559     // Finally, all other local values that are live-in into the outlined region
7560     // end up in a structure whose pointer is passed as the last argument. This
7561     // piece of data is passed in the "shared" field of the task structure. So,
7562     // we know we have to pass shareds to the task if the number of arguments is
7563     // greater than OffloadingArraysToPrivatize.size() + 1 The 1 is for the
7564     // thread id. Further, for safety, we assert that the number of arguments of
7565     // StaleCI is exactly OffloadingArraysToPrivatize.size() + 2
7566     const unsigned int NumStaleCIArgs = StaleCI->arg_size();
7567     bool HasShareds = NumStaleCIArgs > OffloadingArraysToPrivatize.size() + 1;
7568     assert((!HasShareds ||
7569             NumStaleCIArgs == (OffloadingArraysToPrivatize.size() + 2)) &&
7570            "Wrong number of arguments for StaleCI when shareds are present");
7571     int SharedArgOperandNo =
7572         HasShareds ? OffloadingArraysToPrivatize.size() + 1 : 0;
7573 
7574     StructType *TaskWithPrivatesTy =
7575         createTaskWithPrivatesTy(*this, OffloadingArraysToPrivatize);
7576     StructType *PrivatesTy = nullptr;
7577 
7578     if (!OffloadingArraysToPrivatize.empty())
7579       PrivatesTy =
7580           static_cast<StructType *>(TaskWithPrivatesTy->getElementType(1));
7581 
7582     Function *ProxyFn = emitTargetTaskProxyFunction(
7583         *this, Builder, StaleCI, PrivatesTy, TaskWithPrivatesTy,
7584         OffloadingArraysToPrivatize.size(), SharedArgOperandNo);
7585 
7586     LLVM_DEBUG(dbgs() << "Proxy task entry function created: " << *ProxyFn
7587                       << "\n");
7588 
7589     Builder.SetInsertPoint(StaleCI);
7590 
7591     // Gather the arguments for emitting the runtime call.
7592     uint32_t SrcLocStrSize;
7593     Constant *SrcLocStr =
7594         getOrCreateSrcLocStr(LocationDescription(Builder), SrcLocStrSize);
7595     Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7596 
7597     // @__kmpc_omp_task_alloc or @__kmpc_omp_target_task_alloc
7598     //
7599     // If `HasNoWait == true`, we call  @__kmpc_omp_target_task_alloc to provide
7600     // the DeviceID to the deferred task and also since
7601     // @__kmpc_omp_target_task_alloc creates an untied/async task.
7602     Function *TaskAllocFn =
7603         !NeedsTargetTask
7604             ? getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
7605             : getOrCreateRuntimeFunctionPtr(
7606                   OMPRTL___kmpc_omp_target_task_alloc);
7607 
7608     // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID)
7609     // call.
7610     Value *ThreadID = getOrCreateThreadID(Ident);
7611 
7612     // Argument - `sizeof_kmp_task_t` (TaskSize)
7613     // Tasksize refers to the size in bytes of kmp_task_t data structure
7614     // plus any other data to be passed to the target task, if any, which
7615     // is packed into a struct. kmp_task_t and the struct so created are
7616     // packed into a wrapper struct whose type is TaskWithPrivatesTy.
7617     Value *TaskSize = Builder.getInt64(
7618         M.getDataLayout().getTypeStoreSize(TaskWithPrivatesTy));
7619 
7620     // Argument - `sizeof_shareds` (SharedsSize)
7621     // SharedsSize refers to the shareds array size in the kmp_task_t data
7622     // structure.
7623     Value *SharedsSize = Builder.getInt64(0);
7624     if (HasShareds) {
7625       auto *ArgStructAlloca =
7626           dyn_cast<AllocaInst>(StaleCI->getArgOperand(SharedArgOperandNo));
7627       assert(ArgStructAlloca &&
7628              "Unable to find the alloca instruction corresponding to arguments "
7629              "for extracted function");
7630       auto *ArgStructType =
7631           dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
7632       assert(ArgStructType && "Unable to find struct type corresponding to "
7633                               "arguments for extracted function");
7634       SharedsSize =
7635           Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
7636     }
7637 
7638     // Argument - `flags`
7639     // Task is tied iff (Flags & 1) == 1.
7640     // Task is untied iff (Flags & 1) == 0.
7641     // Task is final iff (Flags & 2) == 2.
7642     // Task is not final iff (Flags & 2) == 0.
7643     // A target task is not final and is untied.
7644     Value *Flags = Builder.getInt32(0);
7645 
7646     // Emit the @__kmpc_omp_task_alloc runtime call
7647     // The runtime call returns a pointer to an area where the task captured
7648     // variables must be copied before the task is run (TaskData)
7649     CallInst *TaskData = nullptr;
7650 
7651     SmallVector<llvm::Value *> TaskAllocArgs = {
7652         /*loc_ref=*/Ident,        /*gtid=*/ThreadID,
7653         /*flags=*/Flags,
7654         /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
7655         /*task_func=*/ProxyFn};
7656 
7657     if (NeedsTargetTask) {
7658       assert(DeviceID && "Expected non-empty device ID.");
7659       TaskAllocArgs.push_back(DeviceID);
7660     }
7661 
7662     TaskData = Builder.CreateCall(TaskAllocFn, TaskAllocArgs);
7663 
7664     Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
7665     if (HasShareds) {
7666       Value *Shareds = StaleCI->getArgOperand(SharedArgOperandNo);
7667       Value *TaskShareds = loadSharedDataFromTaskDescriptor(
7668           *this, Builder, TaskData, TaskWithPrivatesTy);
7669       Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
7670                            SharedsSize);
7671     }
7672     if (!OffloadingArraysToPrivatize.empty()) {
7673       Value *Privates =
7674           Builder.CreateStructGEP(TaskWithPrivatesTy, TaskData, 1);
7675       for (unsigned int i = 0; i < OffloadingArraysToPrivatize.size(); ++i) {
7676         Value *PtrToPrivatize = OffloadingArraysToPrivatize[i];
7677         [[maybe_unused]] Type *ArrayType =
7678             getOffloadingArrayType(PtrToPrivatize);
7679         assert(ArrayType && "ArrayType cannot be nullptr");
7680 
7681         Type *ElementType = PrivatesTy->getElementType(i);
7682         assert(ElementType == ArrayType &&
7683                "ElementType should match ArrayType");
7684         (void)ArrayType;
7685 
7686         Value *Dst = Builder.CreateStructGEP(PrivatesTy, Privates, i);
7687         Builder.CreateMemCpy(
7688             Dst, Alignment, PtrToPrivatize, Alignment,
7689             Builder.getInt64(M.getDataLayout().getTypeStoreSize(ElementType)));
7690       }
7691     }
7692 
7693     Value *DepArray = emitTaskDependencies(*this, Dependencies);
7694 
7695     // ---------------------------------------------------------------
7696     // V5.2 13.8 target construct
7697     // If the nowait clause is present, execution of the target task
7698     // may be deferred. If the nowait clause is not present, the target task is
7699     // an included task.
7700     // ---------------------------------------------------------------
7701     // The above means that the lack of a nowait on the target construct
7702     // translates to '#pragma omp task if(0)'
7703     if (!NeedsTargetTask) {
7704       if (DepArray) {
7705         Function *TaskWaitFn =
7706             getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_wait_deps);
7707         Builder.CreateCall(
7708             TaskWaitFn,
7709             {/*loc_ref=*/Ident, /*gtid=*/ThreadID,
7710              /*ndeps=*/Builder.getInt32(Dependencies.size()),
7711              /*dep_list=*/DepArray,
7712              /*ndeps_noalias=*/ConstantInt::get(Builder.getInt32Ty(), 0),
7713              /*noalias_dep_list=*/
7714              ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
7715       }
7716       // Included task.
7717       Function *TaskBeginFn =
7718           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0);
7719       Function *TaskCompleteFn =
7720           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
7721       Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
7722       CallInst *CI = Builder.CreateCall(ProxyFn, {ThreadID, TaskData});
7723       CI->setDebugLoc(StaleCI->getDebugLoc());
7724       Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
7725     } else if (DepArray) {
7726       // HasNoWait - meaning the task may be deferred. Call
7727       // __kmpc_omp_task_with_deps if there are dependencies,
7728       // else call __kmpc_omp_task
7729       Function *TaskFn =
7730           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
7731       Builder.CreateCall(
7732           TaskFn,
7733           {Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()),
7734            DepArray, ConstantInt::get(Builder.getInt32Ty(), 0),
7735            ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
7736     } else {
7737       // Emit the @__kmpc_omp_task runtime call to spawn the task
7738       Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
7739       Builder.CreateCall(TaskFn, {Ident, ThreadID, TaskData});
7740     }
7741 
7742     StaleCI->eraseFromParent();
7743     for (Instruction *I : llvm::reverse(ToBeDeleted))
7744       I->eraseFromParent();
7745   };
7746   addOutlineInfo(std::move(OI));
7747 
7748   LLVM_DEBUG(dbgs() << "Insert block after emitKernelLaunch = \n"
7749                     << *(Builder.GetInsertBlock()) << "\n");
7750   LLVM_DEBUG(dbgs() << "Module after emitKernelLaunch = \n"
7751                     << *(Builder.GetInsertBlock()->getParent()->getParent())
7752                     << "\n");
7753   return Builder.saveIP();
7754 }
7755 
emitOffloadingArraysAndArgs(InsertPointTy AllocaIP,InsertPointTy CodeGenIP,TargetDataInfo & Info,TargetDataRTArgs & RTArgs,MapInfosTy & CombinedInfo,CustomMapperCallbackTy CustomMapperCB,bool IsNonContiguous,bool ForEndCall,function_ref<void (unsigned int,Value *)> DeviceAddrCB)7756 Error OpenMPIRBuilder::emitOffloadingArraysAndArgs(
7757     InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
7758     TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
7759     CustomMapperCallbackTy CustomMapperCB, bool IsNonContiguous,
7760     bool ForEndCall, function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
7761   if (Error Err =
7762           emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info,
7763                                CustomMapperCB, IsNonContiguous, DeviceAddrCB))
7764     return Err;
7765   emitOffloadingArraysArgument(Builder, RTArgs, Info, ForEndCall);
7766   return Error::success();
7767 }
7768 
emitTargetCall(OpenMPIRBuilder & OMPBuilder,IRBuilderBase & Builder,OpenMPIRBuilder::InsertPointTy AllocaIP,OpenMPIRBuilder::TargetDataInfo & Info,const OpenMPIRBuilder::TargetKernelDefaultAttrs & DefaultAttrs,const OpenMPIRBuilder::TargetKernelRuntimeAttrs & RuntimeAttrs,Value * IfCond,Function * OutlinedFn,Constant * OutlinedFnID,SmallVectorImpl<Value * > & Args,OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,OpenMPIRBuilder::CustomMapperCallbackTy CustomMapperCB,const SmallVector<llvm::OpenMPIRBuilder::DependData> & Dependencies,bool HasNoWait)7769 static void emitTargetCall(
7770     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7771     OpenMPIRBuilder::InsertPointTy AllocaIP,
7772     OpenMPIRBuilder::TargetDataInfo &Info,
7773     const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7774     const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
7775     Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
7776     SmallVectorImpl<Value *> &Args,
7777     OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7778     OpenMPIRBuilder::CustomMapperCallbackTy CustomMapperCB,
7779     const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
7780     bool HasNoWait) {
7781   // Generate a function call to the host fallback implementation of the target
7782   // region. This is called by the host when no offload entry was generated for
7783   // the target region and when the offloading call fails at runtime.
7784   auto &&EmitTargetCallFallbackCB = [&](OpenMPIRBuilder::InsertPointTy IP)
7785       -> OpenMPIRBuilder::InsertPointOrErrorTy {
7786     Builder.restoreIP(IP);
7787     Builder.CreateCall(OutlinedFn, Args);
7788     return Builder.saveIP();
7789   };
7790 
7791   bool HasDependencies = Dependencies.size() > 0;
7792   bool RequiresOuterTargetTask = HasNoWait || HasDependencies;
7793 
7794   OpenMPIRBuilder::TargetKernelArgs KArgs;
7795 
7796   auto TaskBodyCB =
7797       [&](Value *DeviceID, Value *RTLoc,
7798           IRBuilderBase::InsertPoint TargetTaskAllocaIP) -> Error {
7799     // Assume no error was returned because EmitTargetCallFallbackCB doesn't
7800     // produce any.
7801     llvm::OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
7802       // emitKernelLaunch makes the necessary runtime call to offload the
7803       // kernel. We then outline all that code into a separate function
7804       // ('kernel_launch_function' in the pseudo code above). This function is
7805       // then called by the target task proxy function (see
7806       // '@.omp_target_task_proxy_func' in the pseudo code above)
7807       // "@.omp_target_task_proxy_func' is generated by
7808       // emitTargetTaskProxyFunction.
7809       if (OutlinedFnID && DeviceID)
7810         return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
7811                                            EmitTargetCallFallbackCB, KArgs,
7812                                            DeviceID, RTLoc, TargetTaskAllocaIP);
7813 
7814       // We only need to do the outlining if `DeviceID` is set to avoid calling
7815       // `emitKernelLaunch` if we want to code-gen for the host; e.g. if we are
7816       // generating the `else` branch of an `if` clause.
7817       //
7818       // When OutlinedFnID is set to nullptr, then it's not an offloading call.
7819       // In this case, we execute the host implementation directly.
7820       return EmitTargetCallFallbackCB(OMPBuilder.Builder.saveIP());
7821     }());
7822 
7823     OMPBuilder.Builder.restoreIP(AfterIP);
7824     return Error::success();
7825   };
7826 
7827   auto &&EmitTargetCallElse =
7828       [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
7829           OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
7830     // Assume no error was returned because EmitTargetCallFallbackCB doesn't
7831     // produce any.
7832     OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
7833       if (RequiresOuterTargetTask) {
7834         // Arguments that are intended to be directly forwarded to an
7835         // emitKernelLaunch call are pased as nullptr, since
7836         // OutlinedFnID=nullptr results in that call not being done.
7837         OpenMPIRBuilder::TargetDataRTArgs EmptyRTArgs;
7838         return OMPBuilder.emitTargetTask(TaskBodyCB, /*DeviceID=*/nullptr,
7839                                          /*RTLoc=*/nullptr, AllocaIP,
7840                                          Dependencies, EmptyRTArgs, HasNoWait);
7841       }
7842       return EmitTargetCallFallbackCB(Builder.saveIP());
7843     }());
7844 
7845     Builder.restoreIP(AfterIP);
7846     return Error::success();
7847   };
7848 
7849   auto &&EmitTargetCallThen =
7850       [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
7851           OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
7852     Info.HasNoWait = HasNoWait;
7853     OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
7854     OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7855     if (Error Err = OMPBuilder.emitOffloadingArraysAndArgs(
7856             AllocaIP, Builder.saveIP(), Info, RTArgs, MapInfo, CustomMapperCB,
7857             /*IsNonContiguous=*/true,
7858             /*ForEndCall=*/false))
7859       return Err;
7860 
7861     SmallVector<Value *, 3> NumTeamsC;
7862     for (auto [DefaultVal, RuntimeVal] :
7863          zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
7864       NumTeamsC.push_back(RuntimeVal ? RuntimeVal
7865                                      : Builder.getInt32(DefaultVal));
7866 
7867     // Calculate number of threads: 0 if no clauses specified, otherwise it is
7868     // the minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
7869     auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7870       if (Clause)
7871         Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
7872                                        /*isSigned=*/false);
7873       return Clause;
7874     };
7875     auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7876       if (Clause)
7877         Result =
7878             Result ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
7879                                           Result, Clause)
7880                    : Clause;
7881     };
7882 
7883     // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
7884     // the NUM_THREADS clause is overridden by THREAD_LIMIT.
7885     SmallVector<Value *, 3> NumThreadsC;
7886     Value *MaxThreadsClause =
7887         RuntimeAttrs.TeamsThreadLimit.size() == 1
7888             ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
7889             : nullptr;
7890 
7891     for (auto [TeamsVal, TargetVal] : zip_equal(
7892              RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) {
7893       Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
7894       Value *NumThreads = InitMaxThreadsClause(TargetVal);
7895 
7896       CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
7897       CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
7898 
7899       NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
7900     }
7901 
7902     unsigned NumTargetItems = Info.NumberOfPtrs;
7903     // TODO: Use correct device ID
7904     Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
7905     uint32_t SrcLocStrSize;
7906     Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
7907     Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
7908                                                llvm::omp::IdentFlag(0), 0);
7909 
7910     Value *TripCount = RuntimeAttrs.LoopTripCount
7911                            ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
7912                                                    Builder.getInt64Ty(),
7913                                                    /*isSigned=*/false)
7914                            : Builder.getInt64(0);
7915 
7916     // TODO: Use correct DynCGGroupMem
7917     Value *DynCGGroupMem = Builder.getInt32(0);
7918 
7919     KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
7920                                               NumTeamsC, NumThreadsC,
7921                                               DynCGGroupMem, HasNoWait);
7922 
7923     // Assume no error was returned because TaskBodyCB and
7924     // EmitTargetCallFallbackCB don't produce any.
7925     OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
7926       // The presence of certain clauses on the target directive require the
7927       // explicit generation of the target task.
7928       if (RequiresOuterTargetTask)
7929         return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
7930                                          Dependencies, KArgs.RTArgs,
7931                                          Info.HasNoWait);
7932 
7933       return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
7934                                          EmitTargetCallFallbackCB, KArgs,
7935                                          DeviceID, RTLoc, AllocaIP);
7936     }());
7937 
7938     Builder.restoreIP(AfterIP);
7939     return Error::success();
7940   };
7941 
7942   // If we don't have an ID for the target region, it means an offload entry
7943   // wasn't created. In this case we just run the host fallback directly and
7944   // ignore any potential 'if' clauses.
7945   if (!OutlinedFnID) {
7946     cantFail(EmitTargetCallElse(AllocaIP, Builder.saveIP()));
7947     return;
7948   }
7949 
7950   // If there's no 'if' clause, only generate the kernel launch code path.
7951   if (!IfCond) {
7952     cantFail(EmitTargetCallThen(AllocaIP, Builder.saveIP()));
7953     return;
7954   }
7955 
7956   cantFail(OMPBuilder.emitIfClause(IfCond, EmitTargetCallThen,
7957                                    EmitTargetCallElse, AllocaIP));
7958 }
7959 
createTarget(const LocationDescription & Loc,bool IsOffloadEntry,InsertPointTy AllocaIP,InsertPointTy CodeGenIP,TargetDataInfo & Info,TargetRegionEntryInfo & EntryInfo,const TargetKernelDefaultAttrs & DefaultAttrs,const TargetKernelRuntimeAttrs & RuntimeAttrs,Value * IfCond,SmallVectorImpl<Value * > & Inputs,GenMapInfoCallbackTy GenMapInfoCB,OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,CustomMapperCallbackTy CustomMapperCB,const SmallVector<DependData> & Dependencies,bool HasNowait)7960 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
7961     const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
7962     InsertPointTy CodeGenIP, TargetDataInfo &Info,
7963     TargetRegionEntryInfo &EntryInfo,
7964     const TargetKernelDefaultAttrs &DefaultAttrs,
7965     const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
7966     SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
7967     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
7968     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
7969     CustomMapperCallbackTy CustomMapperCB,
7970     const SmallVector<DependData> &Dependencies, bool HasNowait) {
7971 
7972   if (!updateToLocation(Loc))
7973     return InsertPointTy();
7974 
7975   Builder.restoreIP(CodeGenIP);
7976 
7977   Function *OutlinedFn;
7978   Constant *OutlinedFnID = nullptr;
7979   // The target region is outlined into its own function. The LLVM IR for
7980   // the target region itself is generated using the callbacks CBFunc
7981   // and ArgAccessorFuncCB
7982   if (Error Err = emitTargetOutlinedFunction(
7983           *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
7984           OutlinedFnID, Inputs, CBFunc, ArgAccessorFuncCB))
7985     return Err;
7986 
7987   // If we are not on the target device, then we need to generate code
7988   // to make a remote call (offload) to the previously outlined function
7989   // that represents the target region. Do that now.
7990   if (!Config.isTargetDevice())
7991     emitTargetCall(*this, Builder, AllocaIP, Info, DefaultAttrs, RuntimeAttrs,
7992                    IfCond, OutlinedFn, OutlinedFnID, Inputs, GenMapInfoCB,
7993                    CustomMapperCB, Dependencies, HasNowait);
7994   return Builder.saveIP();
7995 }
7996 
getNameWithSeparators(ArrayRef<StringRef> Parts,StringRef FirstSeparator,StringRef Separator)7997 std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
7998                                                    StringRef FirstSeparator,
7999                                                    StringRef Separator) {
8000   SmallString<128> Buffer;
8001   llvm::raw_svector_ostream OS(Buffer);
8002   StringRef Sep = FirstSeparator;
8003   for (StringRef Part : Parts) {
8004     OS << Sep << Part;
8005     Sep = Separator;
8006   }
8007   return OS.str().str();
8008 }
8009 
8010 std::string
createPlatformSpecificName(ArrayRef<StringRef> Parts) const8011 OpenMPIRBuilder::createPlatformSpecificName(ArrayRef<StringRef> Parts) const {
8012   return OpenMPIRBuilder::getNameWithSeparators(Parts, Config.firstSeparator(),
8013                                                 Config.separator());
8014 }
8015 
8016 GlobalVariable *
getOrCreateInternalVariable(Type * Ty,const StringRef & Name,unsigned AddressSpace)8017 OpenMPIRBuilder::getOrCreateInternalVariable(Type *Ty, const StringRef &Name,
8018                                              unsigned AddressSpace) {
8019   auto &Elem = *InternalVars.try_emplace(Name, nullptr).first;
8020   if (Elem.second) {
8021     assert(Elem.second->getValueType() == Ty &&
8022            "OMP internal variable has different type than requested");
8023   } else {
8024     // TODO: investigate the appropriate linkage type used for the global
8025     // variable for possibly changing that to internal or private, or maybe
8026     // create different versions of the function for different OMP internal
8027     // variables.
8028     auto Linkage = this->M.getTargetTriple().getArch() == Triple::wasm32
8029                        ? GlobalValue::InternalLinkage
8030                        : GlobalValue::CommonLinkage;
8031     auto *GV = new GlobalVariable(M, Ty, /*IsConstant=*/false, Linkage,
8032                                   Constant::getNullValue(Ty), Elem.first(),
8033                                   /*InsertBefore=*/nullptr,
8034                                   GlobalValue::NotThreadLocal, AddressSpace);
8035     const DataLayout &DL = M.getDataLayout();
8036     const llvm::Align TypeAlign = DL.getABITypeAlign(Ty);
8037     const llvm::Align PtrAlign = DL.getPointerABIAlignment(AddressSpace);
8038     GV->setAlignment(std::max(TypeAlign, PtrAlign));
8039     Elem.second = GV;
8040   }
8041 
8042   return Elem.second;
8043 }
8044 
getOMPCriticalRegionLock(StringRef CriticalName)8045 Value *OpenMPIRBuilder::getOMPCriticalRegionLock(StringRef CriticalName) {
8046   std::string Prefix = Twine("gomp_critical_user_", CriticalName).str();
8047   std::string Name = getNameWithSeparators({Prefix, "var"}, ".", ".");
8048   return getOrCreateInternalVariable(KmpCriticalNameTy, Name);
8049 }
8050 
getSizeInBytes(Value * BasePtr)8051 Value *OpenMPIRBuilder::getSizeInBytes(Value *BasePtr) {
8052   LLVMContext &Ctx = Builder.getContext();
8053   Value *Null =
8054       Constant::getNullValue(PointerType::getUnqual(BasePtr->getContext()));
8055   Value *SizeGep =
8056       Builder.CreateGEP(BasePtr->getType(), Null, Builder.getInt32(1));
8057   Value *SizePtrToInt = Builder.CreatePtrToInt(SizeGep, Type::getInt64Ty(Ctx));
8058   return SizePtrToInt;
8059 }
8060 
8061 GlobalVariable *
createOffloadMaptypes(SmallVectorImpl<uint64_t> & Mappings,std::string VarName)8062 OpenMPIRBuilder::createOffloadMaptypes(SmallVectorImpl<uint64_t> &Mappings,
8063                                        std::string VarName) {
8064   llvm::Constant *MaptypesArrayInit =
8065       llvm::ConstantDataArray::get(M.getContext(), Mappings);
8066   auto *MaptypesArrayGlobal = new llvm::GlobalVariable(
8067       M, MaptypesArrayInit->getType(),
8068       /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MaptypesArrayInit,
8069       VarName);
8070   MaptypesArrayGlobal->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
8071   return MaptypesArrayGlobal;
8072 }
8073 
createMapperAllocas(const LocationDescription & Loc,InsertPointTy AllocaIP,unsigned NumOperands,struct MapperAllocas & MapperAllocas)8074 void OpenMPIRBuilder::createMapperAllocas(const LocationDescription &Loc,
8075                                           InsertPointTy AllocaIP,
8076                                           unsigned NumOperands,
8077                                           struct MapperAllocas &MapperAllocas) {
8078   if (!updateToLocation(Loc))
8079     return;
8080 
8081   auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
8082   auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
8083   Builder.restoreIP(AllocaIP);
8084   AllocaInst *ArgsBase = Builder.CreateAlloca(
8085       ArrI8PtrTy, /* ArraySize = */ nullptr, ".offload_baseptrs");
8086   AllocaInst *Args = Builder.CreateAlloca(ArrI8PtrTy, /* ArraySize = */ nullptr,
8087                                           ".offload_ptrs");
8088   AllocaInst *ArgSizes = Builder.CreateAlloca(
8089       ArrI64Ty, /* ArraySize = */ nullptr, ".offload_sizes");
8090   Builder.restoreIP(Loc.IP);
8091   MapperAllocas.ArgsBase = ArgsBase;
8092   MapperAllocas.Args = Args;
8093   MapperAllocas.ArgSizes = ArgSizes;
8094 }
8095 
emitMapperCall(const LocationDescription & Loc,Function * MapperFunc,Value * SrcLocInfo,Value * MaptypesArg,Value * MapnamesArg,struct MapperAllocas & MapperAllocas,int64_t DeviceID,unsigned NumOperands)8096 void OpenMPIRBuilder::emitMapperCall(const LocationDescription &Loc,
8097                                      Function *MapperFunc, Value *SrcLocInfo,
8098                                      Value *MaptypesArg, Value *MapnamesArg,
8099                                      struct MapperAllocas &MapperAllocas,
8100                                      int64_t DeviceID, unsigned NumOperands) {
8101   if (!updateToLocation(Loc))
8102     return;
8103 
8104   auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
8105   auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
8106   Value *ArgsBaseGEP =
8107       Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.ArgsBase,
8108                                 {Builder.getInt32(0), Builder.getInt32(0)});
8109   Value *ArgsGEP =
8110       Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.Args,
8111                                 {Builder.getInt32(0), Builder.getInt32(0)});
8112   Value *ArgSizesGEP =
8113       Builder.CreateInBoundsGEP(ArrI64Ty, MapperAllocas.ArgSizes,
8114                                 {Builder.getInt32(0), Builder.getInt32(0)});
8115   Value *NullPtr =
8116       Constant::getNullValue(PointerType::getUnqual(Int8Ptr->getContext()));
8117   Builder.CreateCall(MapperFunc,
8118                      {SrcLocInfo, Builder.getInt64(DeviceID),
8119                       Builder.getInt32(NumOperands), ArgsBaseGEP, ArgsGEP,
8120                       ArgSizesGEP, MaptypesArg, MapnamesArg, NullPtr});
8121 }
8122 
emitOffloadingArraysArgument(IRBuilderBase & Builder,TargetDataRTArgs & RTArgs,TargetDataInfo & Info,bool ForEndCall)8123 void OpenMPIRBuilder::emitOffloadingArraysArgument(IRBuilderBase &Builder,
8124                                                    TargetDataRTArgs &RTArgs,
8125                                                    TargetDataInfo &Info,
8126                                                    bool ForEndCall) {
8127   assert((!ForEndCall || Info.separateBeginEndCalls()) &&
8128          "expected region end call to runtime only when end call is separate");
8129   auto UnqualPtrTy = PointerType::getUnqual(M.getContext());
8130   auto VoidPtrTy = UnqualPtrTy;
8131   auto VoidPtrPtrTy = UnqualPtrTy;
8132   auto Int64Ty = Type::getInt64Ty(M.getContext());
8133   auto Int64PtrTy = UnqualPtrTy;
8134 
8135   if (!Info.NumberOfPtrs) {
8136     RTArgs.BasePointersArray = ConstantPointerNull::get(VoidPtrPtrTy);
8137     RTArgs.PointersArray = ConstantPointerNull::get(VoidPtrPtrTy);
8138     RTArgs.SizesArray = ConstantPointerNull::get(Int64PtrTy);
8139     RTArgs.MapTypesArray = ConstantPointerNull::get(Int64PtrTy);
8140     RTArgs.MapNamesArray = ConstantPointerNull::get(VoidPtrPtrTy);
8141     RTArgs.MappersArray = ConstantPointerNull::get(VoidPtrPtrTy);
8142     return;
8143   }
8144 
8145   RTArgs.BasePointersArray = Builder.CreateConstInBoundsGEP2_32(
8146       ArrayType::get(VoidPtrTy, Info.NumberOfPtrs),
8147       Info.RTArgs.BasePointersArray,
8148       /*Idx0=*/0, /*Idx1=*/0);
8149   RTArgs.PointersArray = Builder.CreateConstInBoundsGEP2_32(
8150       ArrayType::get(VoidPtrTy, Info.NumberOfPtrs), Info.RTArgs.PointersArray,
8151       /*Idx0=*/0,
8152       /*Idx1=*/0);
8153   RTArgs.SizesArray = Builder.CreateConstInBoundsGEP2_32(
8154       ArrayType::get(Int64Ty, Info.NumberOfPtrs), Info.RTArgs.SizesArray,
8155       /*Idx0=*/0, /*Idx1=*/0);
8156   RTArgs.MapTypesArray = Builder.CreateConstInBoundsGEP2_32(
8157       ArrayType::get(Int64Ty, Info.NumberOfPtrs),
8158       ForEndCall && Info.RTArgs.MapTypesArrayEnd ? Info.RTArgs.MapTypesArrayEnd
8159                                                  : Info.RTArgs.MapTypesArray,
8160       /*Idx0=*/0,
8161       /*Idx1=*/0);
8162 
8163   // Only emit the mapper information arrays if debug information is
8164   // requested.
8165   if (!Info.EmitDebug)
8166     RTArgs.MapNamesArray = ConstantPointerNull::get(VoidPtrPtrTy);
8167   else
8168     RTArgs.MapNamesArray = Builder.CreateConstInBoundsGEP2_32(
8169         ArrayType::get(VoidPtrTy, Info.NumberOfPtrs), Info.RTArgs.MapNamesArray,
8170         /*Idx0=*/0,
8171         /*Idx1=*/0);
8172   // If there is no user-defined mapper, set the mapper array to nullptr to
8173   // avoid an unnecessary data privatization
8174   if (!Info.HasMapper)
8175     RTArgs.MappersArray = ConstantPointerNull::get(VoidPtrPtrTy);
8176   else
8177     RTArgs.MappersArray =
8178         Builder.CreatePointerCast(Info.RTArgs.MappersArray, VoidPtrPtrTy);
8179 }
8180 
emitNonContiguousDescriptor(InsertPointTy AllocaIP,InsertPointTy CodeGenIP,MapInfosTy & CombinedInfo,TargetDataInfo & Info)8181 void OpenMPIRBuilder::emitNonContiguousDescriptor(InsertPointTy AllocaIP,
8182                                                   InsertPointTy CodeGenIP,
8183                                                   MapInfosTy &CombinedInfo,
8184                                                   TargetDataInfo &Info) {
8185   MapInfosTy::StructNonContiguousInfo &NonContigInfo =
8186       CombinedInfo.NonContigInfo;
8187 
8188   // Build an array of struct descriptor_dim and then assign it to
8189   // offload_args.
8190   //
8191   // struct descriptor_dim {
8192   //  uint64_t offset;
8193   //  uint64_t count;
8194   //  uint64_t stride
8195   // };
8196   Type *Int64Ty = Builder.getInt64Ty();
8197   StructType *DimTy = StructType::create(
8198       M.getContext(), ArrayRef<Type *>({Int64Ty, Int64Ty, Int64Ty}),
8199       "struct.descriptor_dim");
8200 
8201   enum { OffsetFD = 0, CountFD, StrideFD };
8202   // We need two index variable here since the size of "Dims" is the same as
8203   // the size of Components, however, the size of offset, count, and stride is
8204   // equal to the size of base declaration that is non-contiguous.
8205   for (unsigned I = 0, L = 0, E = NonContigInfo.Dims.size(); I < E; ++I) {
8206     // Skip emitting ir if dimension size is 1 since it cannot be
8207     // non-contiguous.
8208     if (NonContigInfo.Dims[I] == 1)
8209       continue;
8210     Builder.restoreIP(AllocaIP);
8211     ArrayType *ArrayTy = ArrayType::get(DimTy, NonContigInfo.Dims[I]);
8212     AllocaInst *DimsAddr =
8213         Builder.CreateAlloca(ArrayTy, /* ArraySize = */ nullptr, "dims");
8214     Builder.restoreIP(CodeGenIP);
8215     for (unsigned II = 0, EE = NonContigInfo.Dims[I]; II < EE; ++II) {
8216       unsigned RevIdx = EE - II - 1;
8217       Value *DimsLVal = Builder.CreateInBoundsGEP(
8218           DimsAddr->getAllocatedType(), DimsAddr,
8219           {Builder.getInt64(0), Builder.getInt64(II)});
8220       // Offset
8221       Value *OffsetLVal = Builder.CreateStructGEP(DimTy, DimsLVal, OffsetFD);
8222       Builder.CreateAlignedStore(
8223           NonContigInfo.Offsets[L][RevIdx], OffsetLVal,
8224           M.getDataLayout().getPrefTypeAlign(OffsetLVal->getType()));
8225       // Count
8226       Value *CountLVal = Builder.CreateStructGEP(DimTy, DimsLVal, CountFD);
8227       Builder.CreateAlignedStore(
8228           NonContigInfo.Counts[L][RevIdx], CountLVal,
8229           M.getDataLayout().getPrefTypeAlign(CountLVal->getType()));
8230       // Stride
8231       Value *StrideLVal = Builder.CreateStructGEP(DimTy, DimsLVal, StrideFD);
8232       Builder.CreateAlignedStore(
8233           NonContigInfo.Strides[L][RevIdx], StrideLVal,
8234           M.getDataLayout().getPrefTypeAlign(CountLVal->getType()));
8235     }
8236     // args[I] = &dims
8237     Builder.restoreIP(CodeGenIP);
8238     Value *DAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
8239         DimsAddr, Builder.getPtrTy());
8240     Value *P = Builder.CreateConstInBoundsGEP2_32(
8241         ArrayType::get(Builder.getPtrTy(), Info.NumberOfPtrs),
8242         Info.RTArgs.PointersArray, 0, I);
8243     Builder.CreateAlignedStore(
8244         DAddr, P, M.getDataLayout().getPrefTypeAlign(Builder.getPtrTy()));
8245     ++L;
8246   }
8247 }
8248 
emitUDMapperArrayInitOrDel(Function * MapperFn,Value * MapperHandle,Value * Base,Value * Begin,Value * Size,Value * MapType,Value * MapName,TypeSize ElementSize,BasicBlock * ExitBB,bool IsInit)8249 void OpenMPIRBuilder::emitUDMapperArrayInitOrDel(
8250     Function *MapperFn, Value *MapperHandle, Value *Base, Value *Begin,
8251     Value *Size, Value *MapType, Value *MapName, TypeSize ElementSize,
8252     BasicBlock *ExitBB, bool IsInit) {
8253   StringRef Prefix = IsInit ? ".init" : ".del";
8254 
8255   // Evaluate if this is an array section.
8256   BasicBlock *BodyBB = BasicBlock::Create(
8257       M.getContext(), createPlatformSpecificName({"omp.array", Prefix}));
8258   Value *IsArray =
8259       Builder.CreateICmpSGT(Size, Builder.getInt64(1), "omp.arrayinit.isarray");
8260   Value *DeleteBit = Builder.CreateAnd(
8261       MapType,
8262       Builder.getInt64(
8263           static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8264               OpenMPOffloadMappingFlags::OMP_MAP_DELETE)));
8265   Value *DeleteCond;
8266   Value *Cond;
8267   if (IsInit) {
8268     // base != begin?
8269     Value *BaseIsBegin = Builder.CreateICmpNE(Base, Begin);
8270     // IsPtrAndObj?
8271     Value *PtrAndObjBit = Builder.CreateAnd(
8272         MapType,
8273         Builder.getInt64(
8274             static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8275                 OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ)));
8276     PtrAndObjBit = Builder.CreateIsNotNull(PtrAndObjBit);
8277     BaseIsBegin = Builder.CreateAnd(BaseIsBegin, PtrAndObjBit);
8278     Cond = Builder.CreateOr(IsArray, BaseIsBegin);
8279     DeleteCond = Builder.CreateIsNull(
8280         DeleteBit,
8281         createPlatformSpecificName({"omp.array", Prefix, ".delete"}));
8282   } else {
8283     Cond = IsArray;
8284     DeleteCond = Builder.CreateIsNotNull(
8285         DeleteBit,
8286         createPlatformSpecificName({"omp.array", Prefix, ".delete"}));
8287   }
8288   Cond = Builder.CreateAnd(Cond, DeleteCond);
8289   Builder.CreateCondBr(Cond, BodyBB, ExitBB);
8290 
8291   emitBlock(BodyBB, MapperFn);
8292   // Get the array size by multiplying element size and element number (i.e., \p
8293   // Size).
8294   Value *ArraySize = Builder.CreateNUWMul(Size, Builder.getInt64(ElementSize));
8295   // Remove OMP_MAP_TO and OMP_MAP_FROM from the map type, so that it achieves
8296   // memory allocation/deletion purpose only.
8297   Value *MapTypeArg = Builder.CreateAnd(
8298       MapType,
8299       Builder.getInt64(
8300           ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8301               OpenMPOffloadMappingFlags::OMP_MAP_TO |
8302               OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
8303   MapTypeArg = Builder.CreateOr(
8304       MapTypeArg,
8305       Builder.getInt64(
8306           static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8307               OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT)));
8308 
8309   // Call the runtime API __tgt_push_mapper_component to fill up the runtime
8310   // data structure.
8311   Value *OffloadingArgs[] = {MapperHandle, Base,       Begin,
8312                              ArraySize,    MapTypeArg, MapName};
8313   Builder.CreateCall(
8314       getOrCreateRuntimeFunction(M, OMPRTL___tgt_push_mapper_component),
8315       OffloadingArgs);
8316 }
8317 
emitUserDefinedMapper(function_ref<MapInfosOrErrorTy (InsertPointTy CodeGenIP,llvm::Value * PtrPHI,llvm::Value * BeginArg)> GenMapInfoCB,Type * ElemTy,StringRef FuncName,CustomMapperCallbackTy CustomMapperCB)8318 Expected<Function *> OpenMPIRBuilder::emitUserDefinedMapper(
8319     function_ref<MapInfosOrErrorTy(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
8320                                    llvm::Value *BeginArg)>
8321         GenMapInfoCB,
8322     Type *ElemTy, StringRef FuncName, CustomMapperCallbackTy CustomMapperCB) {
8323   SmallVector<Type *> Params;
8324   Params.emplace_back(Builder.getPtrTy());
8325   Params.emplace_back(Builder.getPtrTy());
8326   Params.emplace_back(Builder.getPtrTy());
8327   Params.emplace_back(Builder.getInt64Ty());
8328   Params.emplace_back(Builder.getInt64Ty());
8329   Params.emplace_back(Builder.getPtrTy());
8330 
8331   auto *FnTy =
8332       FunctionType::get(Builder.getVoidTy(), Params, /* IsVarArg */ false);
8333 
8334   SmallString<64> TyStr;
8335   raw_svector_ostream Out(TyStr);
8336   Function *MapperFn =
8337       Function::Create(FnTy, GlobalValue::InternalLinkage, FuncName, M);
8338   MapperFn->addFnAttr(Attribute::NoInline);
8339   MapperFn->addFnAttr(Attribute::NoUnwind);
8340   MapperFn->addParamAttr(0, Attribute::NoUndef);
8341   MapperFn->addParamAttr(1, Attribute::NoUndef);
8342   MapperFn->addParamAttr(2, Attribute::NoUndef);
8343   MapperFn->addParamAttr(3, Attribute::NoUndef);
8344   MapperFn->addParamAttr(4, Attribute::NoUndef);
8345   MapperFn->addParamAttr(5, Attribute::NoUndef);
8346 
8347   // Start the mapper function code generation.
8348   BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", MapperFn);
8349   auto SavedIP = Builder.saveIP();
8350   Builder.SetInsertPoint(EntryBB);
8351 
8352   Value *MapperHandle = MapperFn->getArg(0);
8353   Value *BaseIn = MapperFn->getArg(1);
8354   Value *BeginIn = MapperFn->getArg(2);
8355   Value *Size = MapperFn->getArg(3);
8356   Value *MapType = MapperFn->getArg(4);
8357   Value *MapName = MapperFn->getArg(5);
8358 
8359   // Compute the starting and end addresses of array elements.
8360   // Prepare common arguments for array initiation and deletion.
8361   // Convert the size in bytes into the number of array elements.
8362   TypeSize ElementSize = M.getDataLayout().getTypeStoreSize(ElemTy);
8363   Size = Builder.CreateExactUDiv(Size, Builder.getInt64(ElementSize));
8364   Value *PtrBegin = BeginIn;
8365   Value *PtrEnd = Builder.CreateGEP(ElemTy, PtrBegin, Size);
8366 
8367   // Emit array initiation if this is an array section and \p MapType indicates
8368   // that memory allocation is required.
8369   BasicBlock *HeadBB = BasicBlock::Create(M.getContext(), "omp.arraymap.head");
8370   emitUDMapperArrayInitOrDel(MapperFn, MapperHandle, BaseIn, BeginIn, Size,
8371                              MapType, MapName, ElementSize, HeadBB,
8372                              /*IsInit=*/true);
8373 
8374   // Emit a for loop to iterate through SizeArg of elements and map all of them.
8375 
8376   // Emit the loop header block.
8377   emitBlock(HeadBB, MapperFn);
8378   BasicBlock *BodyBB = BasicBlock::Create(M.getContext(), "omp.arraymap.body");
8379   BasicBlock *DoneBB = BasicBlock::Create(M.getContext(), "omp.done");
8380   // Evaluate whether the initial condition is satisfied.
8381   Value *IsEmpty =
8382       Builder.CreateICmpEQ(PtrBegin, PtrEnd, "omp.arraymap.isempty");
8383   Builder.CreateCondBr(IsEmpty, DoneBB, BodyBB);
8384 
8385   // Emit the loop body block.
8386   emitBlock(BodyBB, MapperFn);
8387   BasicBlock *LastBB = BodyBB;
8388   PHINode *PtrPHI =
8389       Builder.CreatePHI(PtrBegin->getType(), 2, "omp.arraymap.ptrcurrent");
8390   PtrPHI->addIncoming(PtrBegin, HeadBB);
8391 
8392   // Get map clause information. Fill up the arrays with all mapped variables.
8393   MapInfosOrErrorTy Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn);
8394   if (!Info)
8395     return Info.takeError();
8396 
8397   // Call the runtime API __tgt_mapper_num_components to get the number of
8398   // pre-existing components.
8399   Value *OffloadingArgs[] = {MapperHandle};
8400   Value *PreviousSize = Builder.CreateCall(
8401       getOrCreateRuntimeFunction(M, OMPRTL___tgt_mapper_num_components),
8402       OffloadingArgs);
8403   Value *ShiftedPreviousSize =
8404       Builder.CreateShl(PreviousSize, Builder.getInt64(getFlagMemberOffset()));
8405 
8406   // Fill up the runtime mapper handle for all components.
8407   for (unsigned I = 0; I < Info->BasePointers.size(); ++I) {
8408     Value *CurBaseArg = Info->BasePointers[I];
8409     Value *CurBeginArg = Info->Pointers[I];
8410     Value *CurSizeArg = Info->Sizes[I];
8411     Value *CurNameArg = Info->Names.size()
8412                             ? Info->Names[I]
8413                             : Constant::getNullValue(Builder.getPtrTy());
8414 
8415     // Extract the MEMBER_OF field from the map type.
8416     Value *OriMapType = Builder.getInt64(
8417         static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8418             Info->Types[I]));
8419     Value *MemberMapType =
8420         Builder.CreateNUWAdd(OriMapType, ShiftedPreviousSize);
8421 
8422     // Combine the map type inherited from user-defined mapper with that
8423     // specified in the program. According to the OMP_MAP_TO and OMP_MAP_FROM
8424     // bits of the \a MapType, which is the input argument of the mapper
8425     // function, the following code will set the OMP_MAP_TO and OMP_MAP_FROM
8426     // bits of MemberMapType.
8427     // [OpenMP 5.0], 1.2.6. map-type decay.
8428     //        | alloc |  to   | from  | tofrom | release | delete
8429     // ----------------------------------------------------------
8430     // alloc  | alloc | alloc | alloc | alloc  | release | delete
8431     // to     | alloc |  to   | alloc |   to   | release | delete
8432     // from   | alloc | alloc | from  |  from  | release | delete
8433     // tofrom | alloc |  to   | from  | tofrom | release | delete
8434     Value *LeftToFrom = Builder.CreateAnd(
8435         MapType,
8436         Builder.getInt64(
8437             static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8438                 OpenMPOffloadMappingFlags::OMP_MAP_TO |
8439                 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
8440     BasicBlock *AllocBB = BasicBlock::Create(M.getContext(), "omp.type.alloc");
8441     BasicBlock *AllocElseBB =
8442         BasicBlock::Create(M.getContext(), "omp.type.alloc.else");
8443     BasicBlock *ToBB = BasicBlock::Create(M.getContext(), "omp.type.to");
8444     BasicBlock *ToElseBB =
8445         BasicBlock::Create(M.getContext(), "omp.type.to.else");
8446     BasicBlock *FromBB = BasicBlock::Create(M.getContext(), "omp.type.from");
8447     BasicBlock *EndBB = BasicBlock::Create(M.getContext(), "omp.type.end");
8448     Value *IsAlloc = Builder.CreateIsNull(LeftToFrom);
8449     Builder.CreateCondBr(IsAlloc, AllocBB, AllocElseBB);
8450     // In case of alloc, clear OMP_MAP_TO and OMP_MAP_FROM.
8451     emitBlock(AllocBB, MapperFn);
8452     Value *AllocMapType = Builder.CreateAnd(
8453         MemberMapType,
8454         Builder.getInt64(
8455             ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8456                 OpenMPOffloadMappingFlags::OMP_MAP_TO |
8457                 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
8458     Builder.CreateBr(EndBB);
8459     emitBlock(AllocElseBB, MapperFn);
8460     Value *IsTo = Builder.CreateICmpEQ(
8461         LeftToFrom,
8462         Builder.getInt64(
8463             static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8464                 OpenMPOffloadMappingFlags::OMP_MAP_TO)));
8465     Builder.CreateCondBr(IsTo, ToBB, ToElseBB);
8466     // In case of to, clear OMP_MAP_FROM.
8467     emitBlock(ToBB, MapperFn);
8468     Value *ToMapType = Builder.CreateAnd(
8469         MemberMapType,
8470         Builder.getInt64(
8471             ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8472                 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
8473     Builder.CreateBr(EndBB);
8474     emitBlock(ToElseBB, MapperFn);
8475     Value *IsFrom = Builder.CreateICmpEQ(
8476         LeftToFrom,
8477         Builder.getInt64(
8478             static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8479                 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
8480     Builder.CreateCondBr(IsFrom, FromBB, EndBB);
8481     // In case of from, clear OMP_MAP_TO.
8482     emitBlock(FromBB, MapperFn);
8483     Value *FromMapType = Builder.CreateAnd(
8484         MemberMapType,
8485         Builder.getInt64(
8486             ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8487                 OpenMPOffloadMappingFlags::OMP_MAP_TO)));
8488     // In case of tofrom, do nothing.
8489     emitBlock(EndBB, MapperFn);
8490     LastBB = EndBB;
8491     PHINode *CurMapType =
8492         Builder.CreatePHI(Builder.getInt64Ty(), 4, "omp.maptype");
8493     CurMapType->addIncoming(AllocMapType, AllocBB);
8494     CurMapType->addIncoming(ToMapType, ToBB);
8495     CurMapType->addIncoming(FromMapType, FromBB);
8496     CurMapType->addIncoming(MemberMapType, ToElseBB);
8497 
8498     Value *OffloadingArgs[] = {MapperHandle, CurBaseArg, CurBeginArg,
8499                                CurSizeArg,   CurMapType, CurNameArg};
8500 
8501     auto ChildMapperFn = CustomMapperCB(I);
8502     if (!ChildMapperFn)
8503       return ChildMapperFn.takeError();
8504     if (*ChildMapperFn) {
8505       // Call the corresponding mapper function.
8506       Builder.CreateCall(*ChildMapperFn, OffloadingArgs)->setDoesNotThrow();
8507     } else {
8508       // Call the runtime API __tgt_push_mapper_component to fill up the runtime
8509       // data structure.
8510       Builder.CreateCall(
8511           getOrCreateRuntimeFunction(M, OMPRTL___tgt_push_mapper_component),
8512           OffloadingArgs);
8513     }
8514   }
8515 
8516   // Update the pointer to point to the next element that needs to be mapped,
8517   // and check whether we have mapped all elements.
8518   Value *PtrNext = Builder.CreateConstGEP1_32(ElemTy, PtrPHI, /*Idx0=*/1,
8519                                               "omp.arraymap.next");
8520   PtrPHI->addIncoming(PtrNext, LastBB);
8521   Value *IsDone = Builder.CreateICmpEQ(PtrNext, PtrEnd, "omp.arraymap.isdone");
8522   BasicBlock *ExitBB = BasicBlock::Create(M.getContext(), "omp.arraymap.exit");
8523   Builder.CreateCondBr(IsDone, ExitBB, BodyBB);
8524 
8525   emitBlock(ExitBB, MapperFn);
8526   // Emit array deletion if this is an array section and \p MapType indicates
8527   // that deletion is required.
8528   emitUDMapperArrayInitOrDel(MapperFn, MapperHandle, BaseIn, BeginIn, Size,
8529                              MapType, MapName, ElementSize, DoneBB,
8530                              /*IsInit=*/false);
8531 
8532   // Emit the function exit block.
8533   emitBlock(DoneBB, MapperFn, /*IsFinished=*/true);
8534 
8535   Builder.CreateRetVoid();
8536   Builder.restoreIP(SavedIP);
8537   return MapperFn;
8538 }
8539 
/// Emit the offloading argument arrays for the entries of \p CombinedInfo and
/// record them in \p Info.RTArgs (base pointers, pointers, sizes, map types,
/// map names and mappers).
///
/// Stack arrays are allocated at \p AllocaIP; the code that fills them is
/// emitted at \p CodeGenIP. Nothing is emitted when \p CombinedInfo has no
/// base pointers.
///
/// \param AllocaIP        Insertion point used for the allocas.
/// \param CodeGenIP       Insertion point used for the stores/memcpy that
///                        populate the arrays.
/// \param CombinedInfo    The per-entry base pointers, pointers, sizes, map
///                        types, names and device-pointer kinds.
/// \param Info            Receives the created arrays and the device pointer
///                        bookkeeping (\c DevicePtrInfoMap).
/// \param CustomMapperCB  Invoked with the entry index; a non-null result is
///                        stored as that entry's mapper, a null result leaves
///                        a null mapper pointer.
/// \param IsNonContiguous Whether non-contiguous entries are described by
///                        their dimension count and a separate descriptor.
/// \param DeviceAddrCB    If set, invoked with the entry index and the slot
///                        recorded for a device pointer/address entry.
/// \returns the error produced by \p CustomMapperCB, if any; success
///          otherwise.
Error OpenMPIRBuilder::emitOffloadingArrays(
    InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
    TargetDataInfo &Info, CustomMapperCallbackTy CustomMapperCB,
    bool IsNonContiguous,
    function_ref<void(unsigned int, Value *)> DeviceAddrCB) {

  // Reset the array information.
  Info.clearArrayInfo();
  Info.NumberOfPtrs = CombinedInfo.BasePointers.size();

  // Nothing is mapped, so there are no arrays to build.
  if (Info.NumberOfPtrs == 0)
    return Error::success();

  Builder.restoreIP(AllocaIP);
  // Detect if we have any capture size requiring runtime evaluation of the
  // size so that a constant array could be eventually used.
  ArrayType *PointerArrayType =
      ArrayType::get(Builder.getPtrTy(), Info.NumberOfPtrs);

  Info.RTArgs.BasePointersArray = Builder.CreateAlloca(
      PointerArrayType, /* ArraySize = */ nullptr, ".offload_baseptrs");

  Info.RTArgs.PointersArray = Builder.CreateAlloca(
      PointerArrayType, /* ArraySize = */ nullptr, ".offload_ptrs");
  AllocaInst *MappersArray = Builder.CreateAlloca(
      PointerArrayType, /* ArraySize = */ nullptr, ".offload_mappers");
  Info.RTArgs.MappersArray = MappersArray;

  // If we don't have any VLA types or other types that require runtime
  // evaluation, we can use a constant array for the map sizes, otherwise we
  // need to fill up the arrays as we do for the pointers.
  Type *Int64Ty = Builder.getInt64Ty();
  SmallVector<Constant *> ConstSizes(CombinedInfo.Sizes.size(),
                                     ConstantInt::get(Int64Ty, 0));
  SmallBitVector RuntimeSizes(CombinedInfo.Sizes.size());
  // Classify every size: constants that are neither constant expressions nor
  // global values go into ConstSizes; everything else is flagged in
  // RuntimeSizes and keeps a zero placeholder in ConstSizes.
  for (unsigned I = 0, E = CombinedInfo.Sizes.size(); I < E; ++I) {
    if (auto *CI = dyn_cast<Constant>(CombinedInfo.Sizes[I])) {
      if (!isa<ConstantExpr>(CI) && !isa<GlobalValue>(CI)) {
        // For a non-contiguous entry the size slot holds the dimension count
        // taken from NonContigInfo instead of the constant size.
        if (IsNonContiguous &&
            static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                CombinedInfo.Types[I] &
                OpenMPOffloadMappingFlags::OMP_MAP_NON_CONTIG))
          ConstSizes[I] =
              ConstantInt::get(Int64Ty, CombinedInfo.NonContigInfo.Dims[I]);
        else
          ConstSizes[I] = CI;
        continue;
      }
    }
    RuntimeSizes.set(I);
  }

  if (RuntimeSizes.all()) {
    // Every size is a runtime value: only a stack array is needed, it is
    // filled entry by entry in the loop below.
    ArrayType *SizeArrayType = ArrayType::get(Int64Ty, Info.NumberOfPtrs);
    Info.RTArgs.SizesArray = Builder.CreateAlloca(
        SizeArrayType, /* ArraySize = */ nullptr, ".offload_sizes");
    Builder.restoreIP(CodeGenIP);
  } else {
    // At least one size is constant: materialize the constant sizes as a
    // private constant global.
    auto *SizesArrayInit = ConstantArray::get(
        ArrayType::get(Int64Ty, ConstSizes.size()), ConstSizes);
    std::string Name = createPlatformSpecificName({"offload_sizes"});
    auto *SizesArrayGbl =
        new GlobalVariable(M, SizesArrayInit->getType(), /*isConstant=*/true,
                           GlobalValue::PrivateLinkage, SizesArrayInit, Name);
    SizesArrayGbl->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);

    if (!RuntimeSizes.any()) {
      // All sizes are constant: the global is used directly.
      Info.RTArgs.SizesArray = SizesArrayGbl;
    } else {
      // Mixed case: copy the constant global into a stack buffer; the
      // runtime-sized entries are overwritten in the loop below.
      unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(0);
      Align OffloadSizeAlign = M.getDataLayout().getABIIntegerTypeAlignment(64);
      ArrayType *SizeArrayType = ArrayType::get(Int64Ty, Info.NumberOfPtrs);
      AllocaInst *Buffer = Builder.CreateAlloca(
          SizeArrayType, /* ArraySize = */ nullptr, ".offload_sizes");
      Buffer->setAlignment(OffloadSizeAlign);
      Builder.restoreIP(CodeGenIP);
      Builder.CreateMemCpy(
          Buffer, M.getDataLayout().getPrefTypeAlign(Buffer->getType()),
          SizesArrayGbl, OffloadSizeAlign,
          Builder.getIntN(
              IndexSize,
              Buffer->getAllocationSize(M.getDataLayout())->getFixedValue()));

      Info.RTArgs.SizesArray = Buffer;
    }
    Builder.restoreIP(CodeGenIP);
  }

  // The map types are always constant so we don't need to generate code to
  // fill arrays. Instead, we create an array constant.
  SmallVector<uint64_t, 4> Mapping;
  for (auto mapFlag : CombinedInfo.Types)
    Mapping.push_back(
        static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
            mapFlag));
  std::string MaptypesName = createPlatformSpecificName({"offload_maptypes"});
  auto *MapTypesArrayGbl = createOffloadMaptypes(Mapping, MaptypesName);
  Info.RTArgs.MapTypesArray = MapTypesArrayGbl;

  // The information types are only built if provided.
  if (!CombinedInfo.Names.empty()) {
    auto *MapNamesArrayGbl = createOffloadMapnames(
        CombinedInfo.Names, createPlatformSpecificName({"offload_mapnames"}));
    Info.RTArgs.MapNamesArray = MapNamesArrayGbl;
    Info.EmitDebug = true;
  } else {
    // No names were provided: pass a null pointer as the names array.
    Info.RTArgs.MapNamesArray =
        Constant::getNullValue(PointerType::getUnqual(Builder.getContext()));
    Info.EmitDebug = false;
  }

  // If there's a present map type modifier, it must not be applied to the end
  // of a region, so generate a separate map type array in that case.
  if (Info.separateBeginEndCalls()) {
    bool EndMapTypesDiffer = false;
    // Mapping is modified in place; it is not needed in its original form
    // after the begin array above has been created.
    for (uint64_t &Type : Mapping) {
      if (Type & static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                     OpenMPOffloadMappingFlags::OMP_MAP_PRESENT)) {
        Type &= ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
            OpenMPOffloadMappingFlags::OMP_MAP_PRESENT);
        EndMapTypesDiffer = true;
      }
    }
    if (EndMapTypesDiffer) {
      MapTypesArrayGbl = createOffloadMaptypes(Mapping, MaptypesName);
      Info.RTArgs.MapTypesArrayEnd = MapTypesArrayGbl;
    }
  }

  // Fill the base pointer, pointer, (runtime) size and mapper slot of every
  // entry.
  PointerType *PtrTy = Builder.getPtrTy();
  for (unsigned I = 0; I < Info.NumberOfPtrs; ++I) {
    Value *BPVal = CombinedInfo.BasePointers[I];
    Value *BP = Builder.CreateConstInBoundsGEP2_32(
        ArrayType::get(PtrTy, Info.NumberOfPtrs), Info.RTArgs.BasePointersArray,
        0, I);
    Builder.CreateAlignedStore(BPVal, BP,
                               M.getDataLayout().getPrefTypeAlign(PtrTy));

    if (Info.requiresDevicePointerInfo()) {
      if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Pointer) {
        // A device pointer gets its own stack slot, allocated at AllocaIP;
        // the current position is saved and restored around the alloca.
        CodeGenIP = Builder.saveIP();
        Builder.restoreIP(AllocaIP);
        Info.DevicePtrInfoMap[BPVal] = {BP, Builder.CreateAlloca(PtrTy)};
        Builder.restoreIP(CodeGenIP);
        if (DeviceAddrCB)
          DeviceAddrCB(I, Info.DevicePtrInfoMap[BPVal].second);
      } else if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Address) {
        // A device address reuses the base pointer slot itself.
        Info.DevicePtrInfoMap[BPVal] = {BP, BP};
        if (DeviceAddrCB)
          DeviceAddrCB(I, BP);
      }
    }

    Value *PVal = CombinedInfo.Pointers[I];
    Value *P = Builder.CreateConstInBoundsGEP2_32(
        ArrayType::get(PtrTy, Info.NumberOfPtrs), Info.RTArgs.PointersArray, 0,
        I);
    // TODO: Check alignment correct.
    Builder.CreateAlignedStore(PVal, P,
                               M.getDataLayout().getPrefTypeAlign(PtrTy));

    if (RuntimeSizes.test(I)) {
      Value *S = Builder.CreateConstInBoundsGEP2_32(
          ArrayType::get(Int64Ty, Info.NumberOfPtrs), Info.RTArgs.SizesArray,
          /*Idx0=*/0,
          /*Idx1=*/I);
      // NOTE(review): the stored value is an i64 but the alignment is taken
      // from PtrTy - confirm this is intended.
      Builder.CreateAlignedStore(Builder.CreateIntCast(CombinedInfo.Sizes[I],
                                                       Int64Ty,
                                                       /*isSigned=*/true),
                                 S, M.getDataLayout().getPrefTypeAlign(PtrTy));
    }
    // Fill up the mapper array.
    unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(0);
    Value *MFunc = ConstantPointerNull::get(PtrTy);

    // The mapper stays null unless the callback yields one for this entry.
    auto CustomMFunc = CustomMapperCB(I);
    if (!CustomMFunc)
      return CustomMFunc.takeError();
    if (*CustomMFunc)
      MFunc = Builder.CreatePointerCast(*CustomMFunc, PtrTy);

    Value *MAddr = Builder.CreateInBoundsGEP(
        MappersArray->getAllocatedType(), MappersArray,
        {Builder.getIntN(IndexSize, 0), Builder.getIntN(IndexSize, I)});
    Builder.CreateAlignedStore(
        MFunc, MAddr, M.getDataLayout().getPrefTypeAlign(MAddr->getType()));
  }

  // The non-contiguous descriptor is only emitted when requested and when
  // offsets have been recorded.
  if (!IsNonContiguous || CombinedInfo.NonContigInfo.Offsets.empty() ||
      Info.NumberOfPtrs == 0)
    return Error::success();
  emitNonContiguousDescriptor(AllocaIP, CodeGenIP, CombinedInfo, Info);
  return Error::success();
}
8734 
emitBranch(BasicBlock * Target)8735 void OpenMPIRBuilder::emitBranch(BasicBlock *Target) {
8736   BasicBlock *CurBB = Builder.GetInsertBlock();
8737 
8738   if (!CurBB || CurBB->getTerminator()) {
8739     // If there is no insert point or the previous block is already
8740     // terminated, don't touch it.
8741   } else {
8742     // Otherwise, create a fall-through branch.
8743     Builder.CreateBr(Target);
8744   }
8745 
8746   Builder.ClearInsertionPoint();
8747 }
8748 
emitBlock(BasicBlock * BB,Function * CurFn,bool IsFinished)8749 void OpenMPIRBuilder::emitBlock(BasicBlock *BB, Function *CurFn,
8750                                 bool IsFinished) {
8751   BasicBlock *CurBB = Builder.GetInsertBlock();
8752 
8753   // Fall out of the current block (if necessary).
8754   emitBranch(BB);
8755 
8756   if (IsFinished && BB->use_empty()) {
8757     BB->eraseFromParent();
8758     return;
8759   }
8760 
8761   // Place the block after the current block, if possible, or else at
8762   // the end of the function.
8763   if (CurBB && CurBB->getParent())
8764     CurFn->insert(std::next(CurBB->getIterator()), BB);
8765   else
8766     CurFn->insert(CurFn->end(), BB);
8767   Builder.SetInsertPoint(BB);
8768 }
8769 
emitIfClause(Value * Cond,BodyGenCallbackTy ThenGen,BodyGenCallbackTy ElseGen,InsertPointTy AllocaIP)8770 Error OpenMPIRBuilder::emitIfClause(Value *Cond, BodyGenCallbackTy ThenGen,
8771                                     BodyGenCallbackTy ElseGen,
8772                                     InsertPointTy AllocaIP) {
8773   // If the condition constant folds and can be elided, try to avoid emitting
8774   // the condition and the dead arm of the if/else.
8775   if (auto *CI = dyn_cast<ConstantInt>(Cond)) {
8776     auto CondConstant = CI->getSExtValue();
8777     if (CondConstant)
8778       return ThenGen(AllocaIP, Builder.saveIP());
8779 
8780     return ElseGen(AllocaIP, Builder.saveIP());
8781   }
8782 
8783   Function *CurFn = Builder.GetInsertBlock()->getParent();
8784 
8785   // Otherwise, the condition did not fold, or we couldn't elide it.  Just
8786   // emit the conditional branch.
8787   BasicBlock *ThenBlock = BasicBlock::Create(M.getContext(), "omp_if.then");
8788   BasicBlock *ElseBlock = BasicBlock::Create(M.getContext(), "omp_if.else");
8789   BasicBlock *ContBlock = BasicBlock::Create(M.getContext(), "omp_if.end");
8790   Builder.CreateCondBr(Cond, ThenBlock, ElseBlock);
8791   // Emit the 'then' code.
8792   emitBlock(ThenBlock, CurFn);
8793   if (Error Err = ThenGen(AllocaIP, Builder.saveIP()))
8794     return Err;
8795   emitBranch(ContBlock);
8796   // Emit the 'else' code if present.
8797   // There is no need to emit line number for unconditional branch.
8798   emitBlock(ElseBlock, CurFn);
8799   if (Error Err = ElseGen(AllocaIP, Builder.saveIP()))
8800     return Err;
8801   // There is no need to emit line number for unconditional branch.
8802   emitBranch(ContBlock);
8803   // Emit the continuation block for code after the if.
8804   emitBlock(ContBlock, CurFn, /*IsFinished=*/true);
8805   return Error::success();
8806 }
8807 
checkAndEmitFlushAfterAtomic(const LocationDescription & Loc,llvm::AtomicOrdering AO,AtomicKind AK)8808 bool OpenMPIRBuilder::checkAndEmitFlushAfterAtomic(
8809     const LocationDescription &Loc, llvm::AtomicOrdering AO, AtomicKind AK) {
8810   assert(!(AO == AtomicOrdering::NotAtomic ||
8811            AO == llvm::AtomicOrdering::Unordered) &&
8812          "Unexpected Atomic Ordering.");
8813 
8814   bool Flush = false;
8815   llvm::AtomicOrdering FlushAO = AtomicOrdering::Monotonic;
8816 
8817   switch (AK) {
8818   case Read:
8819     if (AO == AtomicOrdering::Acquire || AO == AtomicOrdering::AcquireRelease ||
8820         AO == AtomicOrdering::SequentiallyConsistent) {
8821       FlushAO = AtomicOrdering::Acquire;
8822       Flush = true;
8823     }
8824     break;
8825   case Write:
8826   case Compare:
8827   case Update:
8828     if (AO == AtomicOrdering::Release || AO == AtomicOrdering::AcquireRelease ||
8829         AO == AtomicOrdering::SequentiallyConsistent) {
8830       FlushAO = AtomicOrdering::Release;
8831       Flush = true;
8832     }
8833     break;
8834   case Capture:
8835     switch (AO) {
8836     case AtomicOrdering::Acquire:
8837       FlushAO = AtomicOrdering::Acquire;
8838       Flush = true;
8839       break;
8840     case AtomicOrdering::Release:
8841       FlushAO = AtomicOrdering::Release;
8842       Flush = true;
8843       break;
8844     case AtomicOrdering::AcquireRelease:
8845     case AtomicOrdering::SequentiallyConsistent:
8846       FlushAO = AtomicOrdering::AcquireRelease;
8847       Flush = true;
8848       break;
8849     default:
8850       // do nothing - leave silently.
8851       break;
8852     }
8853   }
8854 
8855   if (Flush) {
8856     // Currently Flush RT call still doesn't take memory_ordering, so for when
8857     // that happens, this tries to do the resolution of which atomic ordering
8858     // to use with but issue the flush call
8859     // TODO: pass `FlushAO` after memory ordering support is added
8860     (void)FlushAO;
8861     emitFlush(Loc);
8862   }
8863 
8864   // for AO == AtomicOrdering::Monotonic and  all other case combinations
8865   // do nothing
8866   return Flush;
8867 }
8868 
8869 OpenMPIRBuilder::InsertPointTy
createAtomicRead(const LocationDescription & Loc,AtomicOpValue & X,AtomicOpValue & V,AtomicOrdering AO,InsertPointTy AllocaIP)8870 OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
8871                                   AtomicOpValue &X, AtomicOpValue &V,
8872                                   AtomicOrdering AO, InsertPointTy AllocaIP) {
8873   if (!updateToLocation(Loc))
8874     return Loc.IP;
8875 
8876   assert(X.Var->getType()->isPointerTy() &&
8877          "OMP Atomic expects a pointer to target memory");
8878   Type *XElemTy = X.ElemTy;
8879   assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
8880           XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
8881          "OMP atomic read expected a scalar type");
8882 
8883   Value *XRead = nullptr;
8884 
8885   if (XElemTy->isIntegerTy()) {
8886     LoadInst *XLD =
8887         Builder.CreateLoad(XElemTy, X.Var, X.IsVolatile, "omp.atomic.read");
8888     XLD->setAtomic(AO);
8889     XRead = cast<Value>(XLD);
8890   } else if (XElemTy->isStructTy()) {
8891     // FIXME: Add checks to ensure __atomic_load is emitted iff the
8892     // target does not support `atomicrmw` of the size of the struct
8893     LoadInst *OldVal = Builder.CreateLoad(XElemTy, X.Var, "omp.atomic.read");
8894     OldVal->setAtomic(AO);
8895     const DataLayout &LoadDL = OldVal->getModule()->getDataLayout();
8896     unsigned LoadSize =
8897         LoadDL.getTypeStoreSize(OldVal->getPointerOperand()->getType());
8898     OpenMPIRBuilder::AtomicInfo atomicInfo(
8899         &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
8900         OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X.Var);
8901     auto AtomicLoadRes = atomicInfo.EmitAtomicLoadLibcall(AO);
8902     XRead = AtomicLoadRes.first;
8903     OldVal->eraseFromParent();
8904   } else {
8905     // We need to perform atomic op as integer
8906     IntegerType *IntCastTy =
8907         IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
8908     LoadInst *XLoad =
8909         Builder.CreateLoad(IntCastTy, X.Var, X.IsVolatile, "omp.atomic.load");
8910     XLoad->setAtomic(AO);
8911     if (XElemTy->isFloatingPointTy()) {
8912       XRead = Builder.CreateBitCast(XLoad, XElemTy, "atomic.flt.cast");
8913     } else {
8914       XRead = Builder.CreateIntToPtr(XLoad, XElemTy, "atomic.ptr.cast");
8915     }
8916   }
8917   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Read);
8918   Builder.CreateStore(XRead, V.Var, V.IsVolatile);
8919   return Builder.saveIP();
8920 }
8921 
8922 OpenMPIRBuilder::InsertPointTy
createAtomicWrite(const LocationDescription & Loc,AtomicOpValue & X,Value * Expr,AtomicOrdering AO,InsertPointTy AllocaIP)8923 OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
8924                                    AtomicOpValue &X, Value *Expr,
8925                                    AtomicOrdering AO, InsertPointTy AllocaIP) {
8926   if (!updateToLocation(Loc))
8927     return Loc.IP;
8928 
8929   assert(X.Var->getType()->isPointerTy() &&
8930          "OMP Atomic expects a pointer to target memory");
8931   Type *XElemTy = X.ElemTy;
8932   assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
8933           XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
8934          "OMP atomic write expected a scalar type");
8935 
8936   if (XElemTy->isIntegerTy()) {
8937     StoreInst *XSt = Builder.CreateStore(Expr, X.Var, X.IsVolatile);
8938     XSt->setAtomic(AO);
8939   } else if (XElemTy->isStructTy()) {
8940     LoadInst *OldVal = Builder.CreateLoad(XElemTy, X.Var, "omp.atomic.read");
8941     const DataLayout &LoadDL = OldVal->getModule()->getDataLayout();
8942     unsigned LoadSize =
8943         LoadDL.getTypeStoreSize(OldVal->getPointerOperand()->getType());
8944     OpenMPIRBuilder::AtomicInfo atomicInfo(
8945         &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
8946         OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X.Var);
8947     atomicInfo.EmitAtomicStoreLibcall(AO, Expr);
8948     OldVal->eraseFromParent();
8949   } else {
8950     // We need to bitcast and perform atomic op as integers
8951     IntegerType *IntCastTy =
8952         IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
8953     Value *ExprCast =
8954         Builder.CreateBitCast(Expr, IntCastTy, "atomic.src.int.cast");
8955     StoreInst *XSt = Builder.CreateStore(ExprCast, X.Var, X.IsVolatile);
8956     XSt->setAtomic(AO);
8957   }
8958 
8959   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Write);
8960   return Builder.saveIP();
8961 }
8962 
createAtomicUpdate(const LocationDescription & Loc,InsertPointTy AllocaIP,AtomicOpValue & X,Value * Expr,AtomicOrdering AO,AtomicRMWInst::BinOp RMWOp,AtomicUpdateCallbackTy & UpdateOp,bool IsXBinopExpr)8963 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicUpdate(
8964     const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
8965     Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
8966     AtomicUpdateCallbackTy &UpdateOp, bool IsXBinopExpr) {
8967   assert(!isConflictIP(Loc.IP, AllocaIP) && "IPs must not be ambiguous");
8968   if (!updateToLocation(Loc))
8969     return Loc.IP;
8970 
8971   LLVM_DEBUG({
8972     Type *XTy = X.Var->getType();
8973     assert(XTy->isPointerTy() &&
8974            "OMP Atomic expects a pointer to target memory");
8975     Type *XElemTy = X.ElemTy;
8976     assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
8977             XElemTy->isPointerTy()) &&
8978            "OMP atomic update expected a scalar type");
8979     assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
8980            (RMWOp != AtomicRMWInst::UMax) && (RMWOp != AtomicRMWInst::UMin) &&
8981            "OpenMP atomic does not support LT or GT operations");
8982   });
8983 
8984   Expected<std::pair<Value *, Value *>> AtomicResult =
8985       emitAtomicUpdate(AllocaIP, X.Var, X.ElemTy, Expr, AO, RMWOp, UpdateOp,
8986                        X.IsVolatile, IsXBinopExpr);
8987   if (!AtomicResult)
8988     return AtomicResult.takeError();
8989   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Update);
8990   return Builder.saveIP();
8991 }
8992 
8993 // FIXME: Duplicating AtomicExpand
emitRMWOpAsInstruction(Value * Src1,Value * Src2,AtomicRMWInst::BinOp RMWOp)8994 Value *OpenMPIRBuilder::emitRMWOpAsInstruction(Value *Src1, Value *Src2,
8995                                                AtomicRMWInst::BinOp RMWOp) {
8996   switch (RMWOp) {
8997   case AtomicRMWInst::Add:
8998     return Builder.CreateAdd(Src1, Src2);
8999   case AtomicRMWInst::Sub:
9000     return Builder.CreateSub(Src1, Src2);
9001   case AtomicRMWInst::And:
9002     return Builder.CreateAnd(Src1, Src2);
9003   case AtomicRMWInst::Nand:
9004     return Builder.CreateNeg(Builder.CreateAnd(Src1, Src2));
9005   case AtomicRMWInst::Or:
9006     return Builder.CreateOr(Src1, Src2);
9007   case AtomicRMWInst::Xor:
9008     return Builder.CreateXor(Src1, Src2);
9009   case AtomicRMWInst::Xchg:
9010   case AtomicRMWInst::FAdd:
9011   case AtomicRMWInst::FSub:
9012   case AtomicRMWInst::BAD_BINOP:
9013   case AtomicRMWInst::Max:
9014   case AtomicRMWInst::Min:
9015   case AtomicRMWInst::UMax:
9016   case AtomicRMWInst::UMin:
9017   case AtomicRMWInst::FMax:
9018   case AtomicRMWInst::FMin:
9019   case AtomicRMWInst::FMaximum:
9020   case AtomicRMWInst::FMinimum:
9021   case AtomicRMWInst::UIncWrap:
9022   case AtomicRMWInst::UDecWrap:
9023   case AtomicRMWInst::USubCond:
9024   case AtomicRMWInst::USubSat:
9025     llvm_unreachable("Unsupported atomic update operation");
9026   }
9027   llvm_unreachable("Unsupported atomic update operation");
9028 }
9029 
emitAtomicUpdate(InsertPointTy AllocaIP,Value * X,Type * XElemTy,Value * Expr,AtomicOrdering AO,AtomicRMWInst::BinOp RMWOp,AtomicUpdateCallbackTy & UpdateOp,bool VolatileX,bool IsXBinopExpr)9030 Expected<std::pair<Value *, Value *>> OpenMPIRBuilder::emitAtomicUpdate(
9031     InsertPointTy AllocaIP, Value *X, Type *XElemTy, Value *Expr,
9032     AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
9033     AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, bool IsXBinopExpr) {
9034   // TODO: handle the case where XElemTy is not byte-sized or not a power of 2
9035   // or a complex datatype.
9036   bool emitRMWOp = false;
9037   switch (RMWOp) {
9038   case AtomicRMWInst::Add:
9039   case AtomicRMWInst::And:
9040   case AtomicRMWInst::Nand:
9041   case AtomicRMWInst::Or:
9042   case AtomicRMWInst::Xor:
9043   case AtomicRMWInst::Xchg:
9044     emitRMWOp = XElemTy;
9045     break;
9046   case AtomicRMWInst::Sub:
9047     emitRMWOp = (IsXBinopExpr && XElemTy);
9048     break;
9049   default:
9050     emitRMWOp = false;
9051   }
9052   emitRMWOp &= XElemTy->isIntegerTy();
9053 
9054   std::pair<Value *, Value *> Res;
9055   if (emitRMWOp) {
9056     Res.first = Builder.CreateAtomicRMW(RMWOp, X, Expr, llvm::MaybeAlign(), AO);
9057     // not needed except in case of postfix captures. Generate anyway for
9058     // consistency with the else part. Will be removed with any DCE pass.
9059     // AtomicRMWInst::Xchg does not have a coressponding instruction.
9060     if (RMWOp == AtomicRMWInst::Xchg)
9061       Res.second = Res.first;
9062     else
9063       Res.second = emitRMWOpAsInstruction(Res.first, Expr, RMWOp);
9064   } else if (RMWOp == llvm::AtomicRMWInst::BinOp::BAD_BINOP &&
9065              XElemTy->isStructTy()) {
9066     LoadInst *OldVal =
9067         Builder.CreateLoad(XElemTy, X, X->getName() + ".atomic.load");
9068     OldVal->setAtomic(AO);
9069     const DataLayout &LoadDL = OldVal->getModule()->getDataLayout();
9070     unsigned LoadSize =
9071         LoadDL.getTypeStoreSize(OldVal->getPointerOperand()->getType());
9072 
9073     OpenMPIRBuilder::AtomicInfo atomicInfo(
9074         &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
9075         OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X);
9076     auto AtomicLoadRes = atomicInfo.EmitAtomicLoadLibcall(AO);
9077     BasicBlock *CurBB = Builder.GetInsertBlock();
9078     Instruction *CurBBTI = CurBB->getTerminator();
9079     CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
9080     BasicBlock *ExitBB =
9081         CurBB->splitBasicBlock(CurBBTI, X->getName() + ".atomic.exit");
9082     BasicBlock *ContBB = CurBB->splitBasicBlock(CurBB->getTerminator(),
9083                                                 X->getName() + ".atomic.cont");
9084     ContBB->getTerminator()->eraseFromParent();
9085     Builder.restoreIP(AllocaIP);
9086     AllocaInst *NewAtomicAddr = Builder.CreateAlloca(XElemTy);
9087     NewAtomicAddr->setName(X->getName() + "x.new.val");
9088     Builder.SetInsertPoint(ContBB);
9089     llvm::PHINode *PHI = Builder.CreatePHI(OldVal->getType(), 2);
9090     PHI->addIncoming(AtomicLoadRes.first, CurBB);
9091     Value *OldExprVal = PHI;
9092     Expected<Value *> CBResult = UpdateOp(OldExprVal, Builder);
9093     if (!CBResult)
9094       return CBResult.takeError();
9095     Value *Upd = *CBResult;
9096     Builder.CreateStore(Upd, NewAtomicAddr);
9097     AtomicOrdering Failure =
9098         llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
9099     auto Result = atomicInfo.EmitAtomicCompareExchangeLibcall(
9100         AtomicLoadRes.second, NewAtomicAddr, AO, Failure);
9101     LoadInst *PHILoad = Builder.CreateLoad(XElemTy, Result.first);
9102     PHI->addIncoming(PHILoad, Builder.GetInsertBlock());
9103     Builder.CreateCondBr(Result.second, ExitBB, ContBB);
9104     OldVal->eraseFromParent();
9105     Res.first = OldExprVal;
9106     Res.second = Upd;
9107 
9108     if (UnreachableInst *ExitTI =
9109             dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
9110       CurBBTI->eraseFromParent();
9111       Builder.SetInsertPoint(ExitBB);
9112     } else {
9113       Builder.SetInsertPoint(ExitTI);
9114     }
9115   } else {
9116     IntegerType *IntCastTy =
9117         IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
9118     LoadInst *OldVal =
9119         Builder.CreateLoad(IntCastTy, X, X->getName() + ".atomic.load");
9120     OldVal->setAtomic(AO);
9121     // CurBB
9122     // |     /---\
9123 		// ContBB    |
9124     // |     \---/
9125     // ExitBB
9126     BasicBlock *CurBB = Builder.GetInsertBlock();
9127     Instruction *CurBBTI = CurBB->getTerminator();
9128     CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
9129     BasicBlock *ExitBB =
9130         CurBB->splitBasicBlock(CurBBTI, X->getName() + ".atomic.exit");
9131     BasicBlock *ContBB = CurBB->splitBasicBlock(CurBB->getTerminator(),
9132                                                 X->getName() + ".atomic.cont");
9133     ContBB->getTerminator()->eraseFromParent();
9134     Builder.restoreIP(AllocaIP);
9135     AllocaInst *NewAtomicAddr = Builder.CreateAlloca(XElemTy);
9136     NewAtomicAddr->setName(X->getName() + "x.new.val");
9137     Builder.SetInsertPoint(ContBB);
9138     llvm::PHINode *PHI = Builder.CreatePHI(OldVal->getType(), 2);
9139     PHI->addIncoming(OldVal, CurBB);
9140     bool IsIntTy = XElemTy->isIntegerTy();
9141     Value *OldExprVal = PHI;
9142     if (!IsIntTy) {
9143       if (XElemTy->isFloatingPointTy()) {
9144         OldExprVal = Builder.CreateBitCast(PHI, XElemTy,
9145                                            X->getName() + ".atomic.fltCast");
9146       } else {
9147         OldExprVal = Builder.CreateIntToPtr(PHI, XElemTy,
9148                                             X->getName() + ".atomic.ptrCast");
9149       }
9150     }
9151 
9152     Expected<Value *> CBResult = UpdateOp(OldExprVal, Builder);
9153     if (!CBResult)
9154       return CBResult.takeError();
9155     Value *Upd = *CBResult;
9156     Builder.CreateStore(Upd, NewAtomicAddr);
9157     LoadInst *DesiredVal = Builder.CreateLoad(IntCastTy, NewAtomicAddr);
9158     AtomicOrdering Failure =
9159         llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
9160     AtomicCmpXchgInst *Result = Builder.CreateAtomicCmpXchg(
9161         X, PHI, DesiredVal, llvm::MaybeAlign(), AO, Failure);
9162     Result->setVolatile(VolatileX);
9163     Value *PreviousVal = Builder.CreateExtractValue(Result, /*Idxs=*/0);
9164     Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
9165     PHI->addIncoming(PreviousVal, Builder.GetInsertBlock());
9166     Builder.CreateCondBr(SuccessFailureVal, ExitBB, ContBB);
9167 
9168     Res.first = OldExprVal;
9169     Res.second = Upd;
9170 
9171     // set Insertion point in exit block
9172     if (UnreachableInst *ExitTI =
9173             dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
9174       CurBBTI->eraseFromParent();
9175       Builder.SetInsertPoint(ExitBB);
9176     } else {
9177       Builder.SetInsertPoint(ExitTI);
9178     }
9179   }
9180 
9181   return Res;
9182 }
9183 
createAtomicCapture(const LocationDescription & Loc,InsertPointTy AllocaIP,AtomicOpValue & X,AtomicOpValue & V,Value * Expr,AtomicOrdering AO,AtomicRMWInst::BinOp RMWOp,AtomicUpdateCallbackTy & UpdateOp,bool UpdateExpr,bool IsPostfixUpdate,bool IsXBinopExpr)9184 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicCapture(
9185     const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
9186     AtomicOpValue &V, Value *Expr, AtomicOrdering AO,
9187     AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp,
9188     bool UpdateExpr, bool IsPostfixUpdate, bool IsXBinopExpr) {
9189   if (!updateToLocation(Loc))
9190     return Loc.IP;
9191 
9192   LLVM_DEBUG({
9193     Type *XTy = X.Var->getType();
9194     assert(XTy->isPointerTy() &&
9195            "OMP Atomic expects a pointer to target memory");
9196     Type *XElemTy = X.ElemTy;
9197     assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
9198             XElemTy->isPointerTy()) &&
9199            "OMP atomic capture expected a scalar type");
9200     assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
9201            "OpenMP atomic does not support LT or GT operations");
9202   });
9203 
9204   // If UpdateExpr is 'x' updated with some `expr` not based on 'x',
9205   // 'x' is simply atomically rewritten with 'expr'.
9206   AtomicRMWInst::BinOp AtomicOp = (UpdateExpr ? RMWOp : AtomicRMWInst::Xchg);
9207   Expected<std::pair<Value *, Value *>> AtomicResult =
9208       emitAtomicUpdate(AllocaIP, X.Var, X.ElemTy, Expr, AO, AtomicOp, UpdateOp,
9209                        X.IsVolatile, IsXBinopExpr);
9210   if (!AtomicResult)
9211     return AtomicResult.takeError();
9212   Value *CapturedVal =
9213       (IsPostfixUpdate ? AtomicResult->first : AtomicResult->second);
9214   Builder.CreateStore(CapturedVal, V.Var, V.IsVolatile);
9215 
9216   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Capture);
9217   return Builder.saveIP();
9218 }
9219 
createAtomicCompare(const LocationDescription & Loc,AtomicOpValue & X,AtomicOpValue & V,AtomicOpValue & R,Value * E,Value * D,AtomicOrdering AO,omp::OMPAtomicCompareOp Op,bool IsXBinopExpr,bool IsPostfixUpdate,bool IsFailOnly)9220 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
9221     const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
9222     AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
9223     omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
9224     bool IsFailOnly) {
9225 
9226   AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
9227   return createAtomicCompare(Loc, X, V, R, E, D, AO, Op, IsXBinopExpr,
9228                              IsPostfixUpdate, IsFailOnly, Failure);
9229 }
9230 
createAtomicCompare(const LocationDescription & Loc,AtomicOpValue & X,AtomicOpValue & V,AtomicOpValue & R,Value * E,Value * D,AtomicOrdering AO,omp::OMPAtomicCompareOp Op,bool IsXBinopExpr,bool IsPostfixUpdate,bool IsFailOnly,AtomicOrdering Failure)9231 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
9232     const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
9233     AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
9234     omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
9235     bool IsFailOnly, AtomicOrdering Failure) {
9236 
9237   if (!updateToLocation(Loc))
9238     return Loc.IP;
9239 
9240   assert(X.Var->getType()->isPointerTy() &&
9241          "OMP atomic expects a pointer to target memory");
9242   // compare capture
9243   if (V.Var) {
9244     assert(V.Var->getType()->isPointerTy() && "v.var must be of pointer type");
9245     assert(V.ElemTy == X.ElemTy && "x and v must be of same type");
9246   }
9247 
9248   bool IsInteger = E->getType()->isIntegerTy();
9249 
9250   if (Op == OMPAtomicCompareOp::EQ) {
9251     AtomicCmpXchgInst *Result = nullptr;
9252     if (!IsInteger) {
9253       IntegerType *IntCastTy =
9254           IntegerType::get(M.getContext(), X.ElemTy->getScalarSizeInBits());
9255       Value *EBCast = Builder.CreateBitCast(E, IntCastTy);
9256       Value *DBCast = Builder.CreateBitCast(D, IntCastTy);
9257       Result = Builder.CreateAtomicCmpXchg(X.Var, EBCast, DBCast, MaybeAlign(),
9258                                            AO, Failure);
9259     } else {
9260       Result =
9261           Builder.CreateAtomicCmpXchg(X.Var, E, D, MaybeAlign(), AO, Failure);
9262     }
9263 
9264     if (V.Var) {
9265       Value *OldValue = Builder.CreateExtractValue(Result, /*Idxs=*/0);
9266       if (!IsInteger)
9267         OldValue = Builder.CreateBitCast(OldValue, X.ElemTy);
9268       assert(OldValue->getType() == V.ElemTy &&
9269              "OldValue and V must be of same type");
9270       if (IsPostfixUpdate) {
9271         Builder.CreateStore(OldValue, V.Var, V.IsVolatile);
9272       } else {
9273         Value *SuccessOrFail = Builder.CreateExtractValue(Result, /*Idxs=*/1);
9274         if (IsFailOnly) {
9275           // CurBB----
9276           //   |     |
9277           //   v     |
9278           // ContBB  |
9279           //   |     |
9280           //   v     |
9281           // ExitBB <-
9282           //
9283           // where ContBB only contains the store of old value to 'v'.
9284           BasicBlock *CurBB = Builder.GetInsertBlock();
9285           Instruction *CurBBTI = CurBB->getTerminator();
9286           CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
9287           BasicBlock *ExitBB = CurBB->splitBasicBlock(
9288               CurBBTI, X.Var->getName() + ".atomic.exit");
9289           BasicBlock *ContBB = CurBB->splitBasicBlock(
9290               CurBB->getTerminator(), X.Var->getName() + ".atomic.cont");
9291           ContBB->getTerminator()->eraseFromParent();
9292           CurBB->getTerminator()->eraseFromParent();
9293 
9294           Builder.CreateCondBr(SuccessOrFail, ExitBB, ContBB);
9295 
9296           Builder.SetInsertPoint(ContBB);
9297           Builder.CreateStore(OldValue, V.Var);
9298           Builder.CreateBr(ExitBB);
9299 
9300           if (UnreachableInst *ExitTI =
9301                   dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
9302             CurBBTI->eraseFromParent();
9303             Builder.SetInsertPoint(ExitBB);
9304           } else {
9305             Builder.SetInsertPoint(ExitTI);
9306           }
9307         } else {
9308           Value *CapturedValue =
9309               Builder.CreateSelect(SuccessOrFail, E, OldValue);
9310           Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
9311         }
9312       }
9313     }
9314     // The comparison result has to be stored.
9315     if (R.Var) {
9316       assert(R.Var->getType()->isPointerTy() &&
9317              "r.var must be of pointer type");
9318       assert(R.ElemTy->isIntegerTy() && "r must be of integral type");
9319 
9320       Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
9321       Value *ResultCast = R.IsSigned
9322                               ? Builder.CreateSExt(SuccessFailureVal, R.ElemTy)
9323                               : Builder.CreateZExt(SuccessFailureVal, R.ElemTy);
9324       Builder.CreateStore(ResultCast, R.Var, R.IsVolatile);
9325     }
9326   } else {
9327     assert((Op == OMPAtomicCompareOp::MAX || Op == OMPAtomicCompareOp::MIN) &&
9328            "Op should be either max or min at this point");
9329     assert(!IsFailOnly && "IsFailOnly is only valid when the comparison is ==");
9330 
9331     // Reverse the ordop as the OpenMP forms are different from LLVM forms.
9332     // Let's take max as example.
9333     // OpenMP form:
9334     // x = x > expr ? expr : x;
9335     // LLVM form:
9336     // *ptr = *ptr > val ? *ptr : val;
9337     // We need to transform to LLVM form.
9338     // x = x <= expr ? x : expr;
9339     AtomicRMWInst::BinOp NewOp;
9340     if (IsXBinopExpr) {
9341       if (IsInteger) {
9342         if (X.IsSigned)
9343           NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
9344                                                 : AtomicRMWInst::Max;
9345         else
9346           NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
9347                                                 : AtomicRMWInst::UMax;
9348       } else {
9349         NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMin
9350                                               : AtomicRMWInst::FMax;
9351       }
9352     } else {
9353       if (IsInteger) {
9354         if (X.IsSigned)
9355           NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
9356                                                 : AtomicRMWInst::Min;
9357         else
9358           NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
9359                                                 : AtomicRMWInst::UMin;
9360       } else {
9361         NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMax
9362                                               : AtomicRMWInst::FMin;
9363       }
9364     }
9365 
9366     AtomicRMWInst *OldValue =
9367         Builder.CreateAtomicRMW(NewOp, X.Var, E, MaybeAlign(), AO);
9368     if (V.Var) {
9369       Value *CapturedValue = nullptr;
9370       if (IsPostfixUpdate) {
9371         CapturedValue = OldValue;
9372       } else {
9373         CmpInst::Predicate Pred;
9374         switch (NewOp) {
9375         case AtomicRMWInst::Max:
9376           Pred = CmpInst::ICMP_SGT;
9377           break;
9378         case AtomicRMWInst::UMax:
9379           Pred = CmpInst::ICMP_UGT;
9380           break;
9381         case AtomicRMWInst::FMax:
9382           Pred = CmpInst::FCMP_OGT;
9383           break;
9384         case AtomicRMWInst::Min:
9385           Pred = CmpInst::ICMP_SLT;
9386           break;
9387         case AtomicRMWInst::UMin:
9388           Pred = CmpInst::ICMP_ULT;
9389           break;
9390         case AtomicRMWInst::FMin:
9391           Pred = CmpInst::FCMP_OLT;
9392           break;
9393         default:
9394           llvm_unreachable("unexpected comparison op");
9395         }
9396         Value *NonAtomicCmp = Builder.CreateCmp(Pred, OldValue, E);
9397         CapturedValue = Builder.CreateSelect(NonAtomicCmp, E, OldValue);
9398       }
9399       Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
9400     }
9401   }
9402 
9403   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Compare);
9404 
9405   return Builder.saveIP();
9406 }
9407 
/// Emit the scaffolding for an OpenMP `teams` construct.
///
/// The current block is split into alloca/body/exit blocks, \p BodyGenCB is
/// invoked to populate the body, and the region is registered for outlining.
/// When not compiling for a target device:
///  - if any of \p NumTeamsLower, \p NumTeamsUpper, \p ThreadLimit or
///    \p IfExpr is given, a call to `__kmpc_push_num_teams_51` is emitted
///    before the region;
///  - a post-outline callback replaces the call to the outlined function with
///    a call to `__kmpc_fork_teams`.
///
/// \returns a default-constructed InsertPointTy if updateToLocation(Loc)
///          fails, the error returned by \p BodyGenCB if any, and otherwise
///          an insertion point at the beginning of the `teams.exit` block.
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
                             BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower,
                             Value *NumTeamsUpper, Value *ThreadLimit,
                             Value *IfExpr) {
  if (!updateToLocation(Loc))
    return InsertPointTy();

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Function *CurrentFunction = Builder.GetInsertBlock()->getParent();

  // Outer allocation basicblock is the entry block of the current function.
  BasicBlock &OuterAllocaBB = CurrentFunction->getEntryBlock();
  // If we are currently emitting into the entry block, move on to a fresh
  // block first so that the region does not start in the outer alloca block.
  if (&OuterAllocaBB == Builder.GetInsertBlock()) {
    BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.entry");
    Builder.SetInsertPoint(BodyBB, BodyBB->begin());
  }

  // The current basic block is split into four basic blocks. After outlining,
  // they will be mapped as follows:
  // ```
  // def current_fn() {
  //   current_basic_block:
  //     br label %teams.exit
  //   teams.exit:
  //     ; instructions after teams
  // }
  //
  // def outlined_fn() {
  //   teams.alloca:
  //     br label %teams.body
  //   teams.body:
  //     ; instructions within teams body
  // }
  // ```
  BasicBlock *ExitBB = splitBB(Builder, /*CreateBranch=*/true, "teams.exit");
  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.body");
  BasicBlock *AllocaBB =
      splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");

  bool SubClausesPresent =
      (NumTeamsLower || NumTeamsUpper || ThreadLimit || IfExpr);
  // Push num_teams
  if (!Config.isTargetDevice() && SubClausesPresent) {
    assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) &&
           "if lowerbound is non-null, then upperbound must also be non-null "
           "for bounds on num_teams");

    // A missing upper bound is passed to the runtime as 0.
    if (NumTeamsUpper == nullptr)
      NumTeamsUpper = Builder.getInt32(0);

    // A missing lower bound defaults to the upper bound.
    if (NumTeamsLower == nullptr)
      NumTeamsLower = NumTeamsUpper;

    if (IfExpr) {
      assert(IfExpr->getType()->isIntegerTy() &&
             "argument to if clause must be an integer value");

      // upper = ifexpr ? upper : 1
      // (wider integer conditions are first narrowed to i1 via `!= 0`)
      if (IfExpr->getType() != Int1)
        IfExpr = Builder.CreateICmpNE(IfExpr,
                                      ConstantInt::get(IfExpr->getType(), 0));
      NumTeamsUpper = Builder.CreateSelect(
          IfExpr, NumTeamsUpper, Builder.getInt32(1), "numTeamsUpper");

      // lower = ifexpr ? lower : 1
      NumTeamsLower = Builder.CreateSelect(
          IfExpr, NumTeamsLower, Builder.getInt32(1), "numTeamsLower");
    }

    // A missing thread limit is passed to the runtime as 0.
    if (ThreadLimit == nullptr)
      ThreadLimit = Builder.getInt32(0);

    Value *ThreadNum = getOrCreateThreadID(Ident);
    Builder.CreateCall(
        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams_51),
        {Ident, ThreadNum, NumTeamsLower, NumTeamsUpper, ThreadLimit});
  }
  // Generate the body of teams.
  InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
  InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
  if (Error Err = BodyGenCB(AllocaIP, CodeGenIP))
    return Err;

  // Describe the region [teams.alloca, teams.exit) for later outlining.
  OutlineInfo OI;
  OI.EntryBB = AllocaBB;
  OI.ExitBB = ExitBB;
  OI.OuterAllocaBB = &OuterAllocaBB;

  // Insert fake values for global tid and bound tid.
  // The instructions created for them are recorded in ToBeDeleted and erased
  // again by the post-outline callback below.
  SmallVector<Instruction *, 8> ToBeDeleted;
  InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
      Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "gid", true));
  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
      Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "tid", true));

  // ToBeDeleted is captured by copy (the lambda is `mutable` so that it can
  // append the stale call to its own copy), hence it must be fully populated
  // before this point.
  auto HostPostOutlineCB = [this, Ident,
                            ToBeDeleted](Function &OutlinedFn) mutable {
    // The stale call instruction will be replaced with a new call instruction
    // for runtime call with the outlined function.

    assert(OutlinedFn.hasOneUse() &&
           "there must be a single user for the outlined function");
    CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
    ToBeDeleted.push_back(StaleCI);

    assert((OutlinedFn.arg_size() == 2 || OutlinedFn.arg_size() == 3) &&
           "Outlined function must have two or three arguments only");

    // A third argument is present only when the outlined function receives
    // an additional data argument.
    bool HasShared = OutlinedFn.arg_size() == 3;

    OutlinedFn.getArg(0)->setName("global.tid.ptr");
    OutlinedFn.getArg(1)->setName("bound.tid.ptr");
    if (HasShared)
      OutlinedFn.getArg(2)->setName("data");

    // Call to the runtime function for teams in the current function.
    assert(StaleCI && "Error while outlining - no CallInst user found for the "
                      "outlined function.");
    Builder.SetInsertPoint(StaleCI);
    // The second argument is the number of arguments of the stale call beyond
    // the first two (global tid, bound tid).
    SmallVector<Value *> Args = {
        Ident, Builder.getInt32(StaleCI->arg_size() - 2), &OutlinedFn};
    if (HasShared)
      Args.push_back(StaleCI->getArgOperand(2));
    Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
                           omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
                       Args);

    // Erase in reverse order of creation.
    for (Instruction *I : llvm::reverse(ToBeDeleted))
      I->eraseFromParent();
  };

  // The fork_teams rewrite is only installed when not compiling for a target
  // device.
  if (!Config.isTargetDevice())
    OI.PostOutlineCB = HostPostOutlineCB;

  addOutlineInfo(std::move(OI));

  Builder.SetInsertPoint(ExitBB, ExitBB->begin());

  return Builder.saveIP();
}
9552 
9553 OpenMPIRBuilder::InsertPointOrErrorTy
createDistribute(const LocationDescription & Loc,InsertPointTy OuterAllocaIP,BodyGenCallbackTy BodyGenCB)9554 OpenMPIRBuilder::createDistribute(const LocationDescription &Loc,
9555                                   InsertPointTy OuterAllocaIP,
9556                                   BodyGenCallbackTy BodyGenCB) {
9557   if (!updateToLocation(Loc))
9558     return InsertPointTy();
9559 
9560   BasicBlock *OuterAllocaBB = OuterAllocaIP.getBlock();
9561 
9562   if (OuterAllocaBB == Builder.GetInsertBlock()) {
9563     BasicBlock *BodyBB =
9564         splitBB(Builder, /*CreateBranch=*/true, "distribute.entry");
9565     Builder.SetInsertPoint(BodyBB, BodyBB->begin());
9566   }
9567   BasicBlock *ExitBB =
9568       splitBB(Builder, /*CreateBranch=*/true, "distribute.exit");
9569   BasicBlock *BodyBB =
9570       splitBB(Builder, /*CreateBranch=*/true, "distribute.body");
9571   BasicBlock *AllocaBB =
9572       splitBB(Builder, /*CreateBranch=*/true, "distribute.alloca");
9573 
9574   // Generate the body of distribute clause
9575   InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
9576   InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
9577   if (Error Err = BodyGenCB(AllocaIP, CodeGenIP))
9578     return Err;
9579 
9580   OutlineInfo OI;
9581   OI.OuterAllocaBB = OuterAllocaIP.getBlock();
9582   OI.EntryBB = AllocaBB;
9583   OI.ExitBB = ExitBB;
9584 
9585   addOutlineInfo(std::move(OI));
9586   Builder.SetInsertPoint(ExitBB, ExitBB->begin());
9587 
9588   return Builder.saveIP();
9589 }
9590 
9591 GlobalVariable *
createOffloadMapnames(SmallVectorImpl<llvm::Constant * > & Names,std::string VarName)9592 OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
9593                                        std::string VarName) {
9594   llvm::Constant *MapNamesArrayInit = llvm::ConstantArray::get(
9595       llvm::ArrayType::get(llvm::PointerType::getUnqual(M.getContext()),
9596                            Names.size()),
9597       Names);
9598   auto *MapNamesArrayGlobal = new llvm::GlobalVariable(
9599       M, MapNamesArrayInit->getType(),
9600       /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MapNamesArrayInit,
9601       VarName);
9602   return MapNamesArrayGlobal;
9603 }
9604 
// Create all simple and struct types exposed by the runtime and remember
// the llvm::PointerTypes of them for easy access later.
//
// The work is driven by OMPKinds.def: each of the macros below is defined
// right before the file is included and expands to the assignment(s) for one
// kind of entry.
//  - OMP_TYPE:          assigns InitValue to VarName.
//  - OMP_ARRAY_TYPE:    sets VarName##Ty to the array type and VarName##PtrTy
//                       to an unqualified pointer type.
//  - OMP_FUNCTION_TYPE: sets VarName to the function type and VarName##Ptr to
//                       an unqualified pointer type.
//  - OMP_STRUCT_TYPE:   reuses the struct named StructName if the context
//                       already has one, otherwise creates it; sets VarName to
//                       it and VarName##Ptr to an unqualified pointer type.
void OpenMPIRBuilder::initializeTypes(Module &M) {
  LLVMContext &Ctx = M.getContext();
  // Scratch variable used by the OMP_STRUCT_TYPE expansion.
  StructType *T;
#define OMP_TYPE(VarName, InitValue) VarName = InitValue;
#define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize)                             \
  VarName##Ty = ArrayType::get(ElemTy, ArraySize);                             \
  VarName##PtrTy = PointerType::getUnqual(Ctx);
#define OMP_FUNCTION_TYPE(VarName, IsVarArg, ReturnType, ...)                  \
  VarName = FunctionType::get(ReturnType, {__VA_ARGS__}, IsVarArg);            \
  VarName##Ptr = PointerType::getUnqual(Ctx);
#define OMP_STRUCT_TYPE(VarName, StructName, Packed, ...)                      \
  T = StructType::getTypeByName(Ctx, StructName);                              \
  if (!T)                                                                      \
    T = StructType::create(Ctx, {__VA_ARGS__}, StructName, Packed);            \
  VarName = T;                                                                 \
  VarName##Ptr = PointerType::getUnqual(Ctx);
#include "llvm/Frontend/OpenMP/OMPKinds.def"
}
9625 
collectBlocks(SmallPtrSetImpl<BasicBlock * > & BlockSet,SmallVectorImpl<BasicBlock * > & BlockVector)9626 void OpenMPIRBuilder::OutlineInfo::collectBlocks(
9627     SmallPtrSetImpl<BasicBlock *> &BlockSet,
9628     SmallVectorImpl<BasicBlock *> &BlockVector) {
9629   SmallVector<BasicBlock *, 32> Worklist;
9630   BlockSet.insert(EntryBB);
9631   BlockSet.insert(ExitBB);
9632 
9633   Worklist.push_back(EntryBB);
9634   while (!Worklist.empty()) {
9635     BasicBlock *BB = Worklist.pop_back_val();
9636     BlockVector.push_back(BB);
9637     for (BasicBlock *SuccBB : successors(BB))
9638       if (BlockSet.insert(SuccBB).second)
9639         Worklist.push_back(SuccBB);
9640   }
9641 }
9642 
createOffloadEntry(Constant * ID,Constant * Addr,uint64_t Size,int32_t Flags,GlobalValue::LinkageTypes,StringRef Name)9643 void OpenMPIRBuilder::createOffloadEntry(Constant *ID, Constant *Addr,
9644                                          uint64_t Size, int32_t Flags,
9645                                          GlobalValue::LinkageTypes,
9646                                          StringRef Name) {
9647   if (!Config.isGPU()) {
9648     llvm::offloading::emitOffloadingEntry(
9649         M, object::OffloadKind::OFK_OpenMP, ID,
9650         Name.empty() ? Addr->getName() : Name, Size, Flags, /*Data=*/0);
9651     return;
9652   }
9653   // TODO: Add support for global variables on the device after declare target
9654   // support.
9655   Function *Fn = dyn_cast<Function>(Addr);
9656   if (!Fn)
9657     return;
9658 
9659   // Add a function attribute for the kernel.
9660   Fn->addFnAttr("kernel");
9661   if (T.isAMDGCN())
9662     Fn->addFnAttr("uniform-work-group-size", "true");
9663   Fn->addFnAttr(Attribute::MustProgress);
9664 }
9665 
// We only generate metadata for functions that contain target regions.
//
// Emits one operand of the `omp_offload.info` named metadata node per
// registered target region and per device global variable, and then creates
// the offloading entries in the order in which the entries were registered.
// Problems with individual entries are reported through \p ErrorFn and the
// offending entry is skipped.
void OpenMPIRBuilder::createOffloadEntriesAndInfoMetadata(
    EmitMetadataErrorReportFunctionTy &ErrorFn) {

  // If there are no entries, we don't need to do anything.
  if (OffloadInfoManager.empty())
    return;

  LLVMContext &C = M.getContext();
  // Indexed by the creation order of the entries; filled by the two emitters
  // below.
  SmallVector<std::pair<const OffloadEntriesInfoManager::OffloadEntryInfo *,
                        TargetRegionEntryInfo>,
              16>
      OrderedEntries(OffloadInfoManager.size());

  // Auxiliary methods to create metadata values and strings.
  auto &&GetMDInt = [this](unsigned V) {
    return ConstantAsMetadata::get(ConstantInt::get(Builder.getInt32Ty(), V));
  };

  auto &&GetMDString = [&C](StringRef V) { return MDString::get(C, V); };

  // Create the offloading info metadata node.
  NamedMDNode *MD = M.getOrInsertNamedMetadata("omp_offload.info");
  auto &&TargetRegionMetadataEmitter =
      [&C, MD, &OrderedEntries, &GetMDInt, &GetMDString](
          const TargetRegionEntryInfo &EntryInfo,
          const OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion &E) {
        // Generate metadata for target regions. Each entry of this metadata
        // contains:
        // - Entry 0 -> Kind of this type of metadata (0).
        // - Entry 1 -> Device ID of the file where the entry was identified.
        // - Entry 2 -> File ID of the file where the entry was identified.
        // - Entry 3 -> Mangled name of the function where the entry was
        // identified.
        // - Entry 4 -> Line in the file where the entry was identified.
        // - Entry 5 -> Count of regions at this DeviceID/FilesID/Line.
        // - Entry 6 -> Order the entry was created.
        // The first element of the metadata node is the kind.
        Metadata *Ops[] = {
            GetMDInt(E.getKind()),      GetMDInt(EntryInfo.DeviceID),
            GetMDInt(EntryInfo.FileID), GetMDString(EntryInfo.ParentName),
            GetMDInt(EntryInfo.Line),   GetMDInt(EntryInfo.Count),
            GetMDInt(E.getOrder())};

        // Save this entry in the right position of the ordered entries array.
        OrderedEntries[E.getOrder()] = std::make_pair(&E, EntryInfo);

        // Add metadata to the named metadata node.
        MD->addOperand(MDNode::get(C, Ops));
      };

  OffloadInfoManager.actOnTargetRegionEntriesInfo(TargetRegionMetadataEmitter);

  // Create function that emits metadata for each device global variable entry;
  auto &&DeviceGlobalVarMetadataEmitter =
      [&C, &OrderedEntries, &GetMDInt, &GetMDString, MD](
          StringRef MangledName,
          const OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar &E) {
        // Generate metadata for global variables. Each entry of this metadata
        // contains:
        // - Entry 0 -> Kind of this type of metadata (1).
        // - Entry 1 -> Mangled name of the variable.
        // - Entry 2 -> Declare target kind.
        // - Entry 3 -> Order the entry was created.
        // The first element of the metadata node is the kind.
        Metadata *Ops[] = {GetMDInt(E.getKind()), GetMDString(MangledName),
                           GetMDInt(E.getFlags()), GetMDInt(E.getOrder())};

        // Save this entry in the right position of the ordered entries array.
        // Only the name is meaningful for a variable; the remaining fields of
        // the entry info are zero.
        TargetRegionEntryInfo varInfo(MangledName, 0, 0, 0);
        OrderedEntries[E.getOrder()] = std::make_pair(&E, varInfo);

        // Add metadata to the named metadata node.
        MD->addOperand(MDNode::get(C, Ops));
      };

  OffloadInfoManager.actOnDeviceGlobalVarEntriesInfo(
      DeviceGlobalVarMetadataEmitter);

  // Walk the entries in creation order and create the actual offload entries.
  for (const auto &E : OrderedEntries) {
    assert(E.first && "All ordered entries must exist!");
    if (const auto *CE =
            dyn_cast<OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion>(
                E.first)) {
      if (!CE->getID() || !CE->getAddress()) {
        // Do not blame the entry if the parent function is not emitted.
        TargetRegionEntryInfo EntryInfo = E.second;
        StringRef FnName = EntryInfo.ParentName;
        if (!M.getNamedValue(FnName))
          continue;
        ErrorFn(EMIT_MD_TARGET_REGION_ERROR, EntryInfo);
        continue;
      }
      createOffloadEntry(CE->getID(), CE->getAddress(),
                         /*Size=*/0, CE->getFlags(),
                         GlobalValue::WeakAnyLinkage);
    } else if (const auto *CE = dyn_cast<
                   OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar>(
                   E.first)) {
      OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags =
          static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
              CE->getFlags());
      // Note: `continue` inside this switch skips to the next ordered entry,
      // `break` proceeds to the entry creation below.
      switch (Flags) {
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter:
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo:
        if (Config.isTargetDevice() && Config.hasRequiresUnifiedSharedMemory())
          continue;
        if (!CE->getAddress()) {
          ErrorFn(EMIT_MD_DECLARE_TARGET_ERROR, E.second);
          continue;
        }
        // The variable has no definition - no need to add the entry.
        if (CE->getVarSize() == 0)
          continue;
        break;
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink:
        // A link entry is expected to have an address exactly when not
        // compiling for a target device.
        assert(((Config.isTargetDevice() && !CE->getAddress()) ||
                (!Config.isTargetDevice() && CE->getAddress())) &&
               "Declaret target link address is set.");
        if (Config.isTargetDevice())
          continue;
        if (!CE->getAddress()) {
          ErrorFn(EMIT_MD_GLOBAL_VAR_LINK_ERROR, TargetRegionEntryInfo());
          continue;
        }
        break;
      default:
        break;
      }

      // Hidden or internal symbols on the device are not externally visible.
      // We should not attempt to register them by creating an offloading
      // entry. Indirect variables are handled separately on the device.
      if (auto *GV = dyn_cast<GlobalValue>(CE->getAddress()))
        if ((GV->hasLocalLinkage() || GV->hasHiddenVisibility()) &&
            Flags != OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
          continue;

      // Indirect globals need to use a special name that doesn't match the name
      // of the associated host global.
      if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
        createOffloadEntry(CE->getAddress(), CE->getAddress(), CE->getVarSize(),
                           Flags, CE->getLinkage(), CE->getVarName());
      else
        createOffloadEntry(CE->getAddress(), CE->getAddress(), CE->getVarSize(),
                           Flags, CE->getLinkage());

    } else {
      llvm_unreachable("Unsupported entry kind.");
    }
  }

  // Emit requires directive globals to a special entry so the runtime can
  // register them when the device image is loaded.
  // TODO: This reduces the offloading entries to a 32-bit integer. Offloading
  //       entries should be redesigned to better suit this use-case.
  if (Config.hasRequiresFlags() && !Config.isTargetDevice())
    offloading::emitOffloadingEntry(
        M, object::OffloadKind::OFK_OpenMP,
        Constant::getNullValue(PointerType::getUnqual(M.getContext())),
        ".requires", /*Size=*/0,
        OffloadEntriesInfoManager::OMPTargetGlobalRegisterRequires,
        Config.getRequiresFlags());
}
9830 
getTargetRegionEntryFnName(SmallVectorImpl<char> & Name,StringRef ParentName,unsigned DeviceID,unsigned FileID,unsigned Line,unsigned Count)9831 void TargetRegionEntryInfo::getTargetRegionEntryFnName(
9832     SmallVectorImpl<char> &Name, StringRef ParentName, unsigned DeviceID,
9833     unsigned FileID, unsigned Line, unsigned Count) {
9834   raw_svector_ostream OS(Name);
9835   OS << KernelNamePrefix << llvm::format("%x", DeviceID)
9836      << llvm::format("_%x_", FileID) << ParentName << "_l" << Line;
9837   if (Count)
9838     OS << "_" << Count;
9839 }
9840 
getTargetRegionEntryFnName(SmallVectorImpl<char> & Name,const TargetRegionEntryInfo & EntryInfo)9841 void OffloadEntriesInfoManager::getTargetRegionEntryFnName(
9842     SmallVectorImpl<char> &Name, const TargetRegionEntryInfo &EntryInfo) {
9843   unsigned NewCount = getTargetRegionEntryInfoCount(EntryInfo);
9844   TargetRegionEntryInfo::getTargetRegionEntryFnName(
9845       Name, EntryInfo.ParentName, EntryInfo.DeviceID, EntryInfo.FileID,
9846       EntryInfo.Line, NewCount);
9847 }
9848 
9849 TargetRegionEntryInfo
getTargetEntryUniqueInfo(FileIdentifierInfoCallbackTy CallBack,StringRef ParentName)9850 OpenMPIRBuilder::getTargetEntryUniqueInfo(FileIdentifierInfoCallbackTy CallBack,
9851                                           StringRef ParentName) {
9852   sys::fs::UniqueID ID(0xdeadf17e, 0);
9853   auto FileIDInfo = CallBack();
9854   uint64_t FileID = 0;
9855   std::error_code EC = sys::fs::getUniqueID(std::get<0>(FileIDInfo), ID);
9856   // If the inode ID could not be determined, create a hash value
9857   // the current file name and use that as an ID.
9858   if (EC)
9859     FileID = hash_value(std::get<0>(FileIDInfo));
9860   else
9861     FileID = ID.getFile();
9862 
9863   return TargetRegionEntryInfo(ParentName, ID.getDevice(), FileID,
9864                                std::get<1>(FileIDInfo));
9865 }
9866 
getFlagMemberOffset()9867 unsigned OpenMPIRBuilder::getFlagMemberOffset() {
9868   unsigned Offset = 0;
9869   for (uint64_t Remain =
9870            static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
9871                omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF);
9872        !(Remain & 1); Remain = Remain >> 1)
9873     Offset++;
9874   return Offset;
9875 }
9876 
9877 omp::OpenMPOffloadMappingFlags
getMemberOfFlag(unsigned Position)9878 OpenMPIRBuilder::getMemberOfFlag(unsigned Position) {
9879   // Rotate by getFlagMemberOffset() bits.
9880   return static_cast<omp::OpenMPOffloadMappingFlags>(((uint64_t)Position + 1)
9881                                                      << getFlagMemberOffset());
9882 }
9883 
setCorrectMemberOfFlag(omp::OpenMPOffloadMappingFlags & Flags,omp::OpenMPOffloadMappingFlags MemberOfFlag)9884 void OpenMPIRBuilder::setCorrectMemberOfFlag(
9885     omp::OpenMPOffloadMappingFlags &Flags,
9886     omp::OpenMPOffloadMappingFlags MemberOfFlag) {
9887   // If the entry is PTR_AND_OBJ but has not been marked with the special
9888   // placeholder value 0xFFFF in the MEMBER_OF field, then it should not be
9889   // marked as MEMBER_OF.
9890   if (static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
9891           Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ) &&
9892       static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
9893           (Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF) !=
9894           omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF))
9895     return;
9896 
9897   // Reset the placeholder value to prepare the flag for the assignment of the
9898   // proper MEMBER_OF value.
9899   Flags &= ~omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
9900   Flags |= MemberOfFlag;
9901 }
9902 
getAddrOfDeclareTargetVar(OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,bool IsDeclaration,bool IsExternallyVisible,TargetRegionEntryInfo EntryInfo,StringRef MangledName,std::vector<GlobalVariable * > & GeneratedRefs,bool OpenMPSIMD,std::vector<Triple> TargetTriple,Type * LlvmPtrTy,std::function<Constant * ()> GlobalInitializer,std::function<GlobalValue::LinkageTypes ()> VariableLinkage)9903 Constant *OpenMPIRBuilder::getAddrOfDeclareTargetVar(
9904     OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
9905     OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
9906     bool IsDeclaration, bool IsExternallyVisible,
9907     TargetRegionEntryInfo EntryInfo, StringRef MangledName,
9908     std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
9909     std::vector<Triple> TargetTriple, Type *LlvmPtrTy,
9910     std::function<Constant *()> GlobalInitializer,
9911     std::function<GlobalValue::LinkageTypes()> VariableLinkage) {
9912   // TODO: convert this to utilise the IRBuilder Config rather than
9913   // a passed down argument.
9914   if (OpenMPSIMD)
9915     return nullptr;
9916 
9917   if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink ||
9918       ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
9919         CaptureClause ==
9920             OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
9921        Config.hasRequiresUnifiedSharedMemory())) {
9922     SmallString<64> PtrName;
9923     {
9924       raw_svector_ostream OS(PtrName);
9925       OS << MangledName;
9926       if (!IsExternallyVisible)
9927         OS << format("_%x", EntryInfo.FileID);
9928       OS << "_decl_tgt_ref_ptr";
9929     }
9930 
9931     Value *Ptr = M.getNamedValue(PtrName);
9932 
9933     if (!Ptr) {
9934       GlobalValue *GlobalValue = M.getNamedValue(MangledName);
9935       Ptr = getOrCreateInternalVariable(LlvmPtrTy, PtrName);
9936 
9937       auto *GV = cast<GlobalVariable>(Ptr);
9938       GV->setLinkage(GlobalValue::WeakAnyLinkage);
9939 
9940       if (!Config.isTargetDevice()) {
9941         if (GlobalInitializer)
9942           GV->setInitializer(GlobalInitializer());
9943         else
9944           GV->setInitializer(GlobalValue);
9945       }
9946 
9947       registerTargetGlobalVariable(
9948           CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
9949           EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
9950           GlobalInitializer, VariableLinkage, LlvmPtrTy, cast<Constant>(Ptr));
9951     }
9952 
9953     return cast<Constant>(Ptr);
9954   }
9955 
9956   return nullptr;
9957 }
9958 
registerTargetGlobalVariable(OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,bool IsDeclaration,bool IsExternallyVisible,TargetRegionEntryInfo EntryInfo,StringRef MangledName,std::vector<GlobalVariable * > & GeneratedRefs,bool OpenMPSIMD,std::vector<Triple> TargetTriple,std::function<Constant * ()> GlobalInitializer,std::function<GlobalValue::LinkageTypes ()> VariableLinkage,Type * LlvmPtrTy,Constant * Addr)9959 void OpenMPIRBuilder::registerTargetGlobalVariable(
9960     OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
9961     OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
9962     bool IsDeclaration, bool IsExternallyVisible,
9963     TargetRegionEntryInfo EntryInfo, StringRef MangledName,
9964     std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
9965     std::vector<Triple> TargetTriple,
9966     std::function<Constant *()> GlobalInitializer,
9967     std::function<GlobalValue::LinkageTypes()> VariableLinkage, Type *LlvmPtrTy,
9968     Constant *Addr) {
9969   if (DeviceClause != OffloadEntriesInfoManager::OMPTargetDeviceClauseAny ||
9970       (TargetTriple.empty() && !Config.isTargetDevice()))
9971     return;
9972 
9973   OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags;
9974   StringRef VarName;
9975   int64_t VarSize;
9976   GlobalValue::LinkageTypes Linkage;
9977 
9978   if ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
9979        CaptureClause ==
9980            OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
9981       !Config.hasRequiresUnifiedSharedMemory()) {
9982     Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
9983     VarName = MangledName;
9984     GlobalValue *LlvmVal = M.getNamedValue(VarName);
9985 
9986     if (!IsDeclaration)
9987       VarSize = divideCeil(
9988           M.getDataLayout().getTypeSizeInBits(LlvmVal->getValueType()), 8);
9989     else
9990       VarSize = 0;
9991     Linkage = (VariableLinkage) ? VariableLinkage() : LlvmVal->getLinkage();
9992 
9993     // This is a workaround carried over from Clang which prevents undesired
9994     // optimisation of internal variables.
9995     if (Config.isTargetDevice() &&
9996         (!IsExternallyVisible || Linkage == GlobalValue::LinkOnceODRLinkage)) {
9997       // Do not create a "ref-variable" if the original is not also available
9998       // on the host.
9999       if (!OffloadInfoManager.hasDeviceGlobalVarEntryInfo(VarName))
10000         return;
10001 
10002       std::string RefName = createPlatformSpecificName({VarName, "ref"});
10003 
10004       if (!M.getNamedValue(RefName)) {
10005         Constant *AddrRef =
10006             getOrCreateInternalVariable(Addr->getType(), RefName);
10007         auto *GvAddrRef = cast<GlobalVariable>(AddrRef);
10008         GvAddrRef->setConstant(true);
10009         GvAddrRef->setLinkage(GlobalValue::InternalLinkage);
10010         GvAddrRef->setInitializer(Addr);
10011         GeneratedRefs.push_back(GvAddrRef);
10012       }
10013     }
10014   } else {
10015     if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink)
10016       Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
10017     else
10018       Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
10019 
10020     if (Config.isTargetDevice()) {
10021       VarName = (Addr) ? Addr->getName() : "";
10022       Addr = nullptr;
10023     } else {
10024       Addr = getAddrOfDeclareTargetVar(
10025           CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
10026           EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
10027           LlvmPtrTy, GlobalInitializer, VariableLinkage);
10028       VarName = (Addr) ? Addr->getName() : "";
10029     }
10030     VarSize = M.getDataLayout().getPointerSize();
10031     Linkage = GlobalValue::WeakAnyLinkage;
10032   }
10033 
10034   OffloadInfoManager.registerDeviceGlobalVarEntryInfo(VarName, Addr, VarSize,
10035                                                       Flags, Linkage);
10036 }
10037 
10038 /// Loads all the offload entries information from the host IR
10039 /// metadata.
loadOffloadInfoMetadata(Module & M)10040 void OpenMPIRBuilder::loadOffloadInfoMetadata(Module &M) {
10041   // If we are in target mode, load the metadata from the host IR. This code has
10042   // to match the metadata creation in createOffloadEntriesAndInfoMetadata().
10043 
10044   NamedMDNode *MD = M.getNamedMetadata(ompOffloadInfoName);
10045   if (!MD)
10046     return;
10047 
10048   for (MDNode *MN : MD->operands()) {
10049     auto &&GetMDInt = [MN](unsigned Idx) {
10050       auto *V = cast<ConstantAsMetadata>(MN->getOperand(Idx));
10051       return cast<ConstantInt>(V->getValue())->getZExtValue();
10052     };
10053 
10054     auto &&GetMDString = [MN](unsigned Idx) {
10055       auto *V = cast<MDString>(MN->getOperand(Idx));
10056       return V->getString();
10057     };
10058 
10059     switch (GetMDInt(0)) {
10060     default:
10061       llvm_unreachable("Unexpected metadata!");
10062       break;
10063     case OffloadEntriesInfoManager::OffloadEntryInfo::
10064         OffloadingEntryInfoTargetRegion: {
10065       TargetRegionEntryInfo EntryInfo(/*ParentName=*/GetMDString(3),
10066                                       /*DeviceID=*/GetMDInt(1),
10067                                       /*FileID=*/GetMDInt(2),
10068                                       /*Line=*/GetMDInt(4),
10069                                       /*Count=*/GetMDInt(5));
10070       OffloadInfoManager.initializeTargetRegionEntryInfo(EntryInfo,
10071                                                          /*Order=*/GetMDInt(6));
10072       break;
10073     }
10074     case OffloadEntriesInfoManager::OffloadEntryInfo::
10075         OffloadingEntryInfoDeviceGlobalVar:
10076       OffloadInfoManager.initializeDeviceGlobalVarEntryInfo(
10077           /*MangledName=*/GetMDString(1),
10078           static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
10079               /*Flags=*/GetMDInt(2)),
10080           /*Order=*/GetMDInt(3));
10081       break;
10082     }
10083   }
10084 }
10085 
loadOffloadInfoMetadata(StringRef HostFilePath)10086 void OpenMPIRBuilder::loadOffloadInfoMetadata(StringRef HostFilePath) {
10087   if (HostFilePath.empty())
10088     return;
10089 
10090   auto Buf = MemoryBuffer::getFile(HostFilePath);
10091   if (std::error_code Err = Buf.getError()) {
10092     report_fatal_error(("error opening host file from host file path inside of "
10093                         "OpenMPIRBuilder: " +
10094                         Err.message())
10095                            .c_str());
10096   }
10097 
10098   LLVMContext Ctx;
10099   auto M = expectedToErrorOrAndEmitErrors(
10100       Ctx, parseBitcodeFile(Buf.get()->getMemBufferRef(), Ctx));
10101   if (std::error_code Err = M.getError()) {
10102     report_fatal_error(
10103         ("error parsing host file inside of OpenMPIRBuilder: " + Err.message())
10104             .c_str());
10105   }
10106 
10107   loadOffloadInfoMetadata(*M.get());
10108 }
10109 
10110 //===----------------------------------------------------------------------===//
10111 // OffloadEntriesInfoManager
10112 //===----------------------------------------------------------------------===//
10113 
empty() const10114 bool OffloadEntriesInfoManager::empty() const {
10115   return OffloadEntriesTargetRegion.empty() &&
10116          OffloadEntriesDeviceGlobalVar.empty();
10117 }
10118 
getTargetRegionEntryInfoCount(const TargetRegionEntryInfo & EntryInfo) const10119 unsigned OffloadEntriesInfoManager::getTargetRegionEntryInfoCount(
10120     const TargetRegionEntryInfo &EntryInfo) const {
10121   auto It = OffloadEntriesTargetRegionCount.find(
10122       getTargetRegionEntryCountKey(EntryInfo));
10123   if (It == OffloadEntriesTargetRegionCount.end())
10124     return 0;
10125   return It->second;
10126 }
10127 
incrementTargetRegionEntryInfoCount(const TargetRegionEntryInfo & EntryInfo)10128 void OffloadEntriesInfoManager::incrementTargetRegionEntryInfoCount(
10129     const TargetRegionEntryInfo &EntryInfo) {
10130   OffloadEntriesTargetRegionCount[getTargetRegionEntryCountKey(EntryInfo)] =
10131       EntryInfo.Count + 1;
10132 }
10133 
10134 /// Initialize target region entry.
initializeTargetRegionEntryInfo(const TargetRegionEntryInfo & EntryInfo,unsigned Order)10135 void OffloadEntriesInfoManager::initializeTargetRegionEntryInfo(
10136     const TargetRegionEntryInfo &EntryInfo, unsigned Order) {
10137   OffloadEntriesTargetRegion[EntryInfo] =
10138       OffloadEntryInfoTargetRegion(Order, /*Addr=*/nullptr, /*ID=*/nullptr,
10139                                    OMPTargetRegionEntryTargetRegion);
10140   ++OffloadingEntriesNum;
10141 }
10142 
registerTargetRegionEntryInfo(TargetRegionEntryInfo EntryInfo,Constant * Addr,Constant * ID,OMPTargetRegionEntryKind Flags)10143 void OffloadEntriesInfoManager::registerTargetRegionEntryInfo(
10144     TargetRegionEntryInfo EntryInfo, Constant *Addr, Constant *ID,
10145     OMPTargetRegionEntryKind Flags) {
10146   assert(EntryInfo.Count == 0 && "expected default EntryInfo");
10147 
10148   // Update the EntryInfo with the next available count for this location.
10149   EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);
10150 
10151   // If we are emitting code for a target, the entry is already initialized,
10152   // only has to be registered.
10153   if (OMPBuilder->Config.isTargetDevice()) {
10154     // This could happen if the device compilation is invoked standalone.
10155     if (!hasTargetRegionEntryInfo(EntryInfo)) {
10156       return;
10157     }
10158     auto &Entry = OffloadEntriesTargetRegion[EntryInfo];
10159     Entry.setAddress(Addr);
10160     Entry.setID(ID);
10161     Entry.setFlags(Flags);
10162   } else {
10163     if (Flags == OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion &&
10164         hasTargetRegionEntryInfo(EntryInfo, /*IgnoreAddressId*/ true))
10165       return;
10166     assert(!hasTargetRegionEntryInfo(EntryInfo) &&
10167            "Target region entry already registered!");
10168     OffloadEntryInfoTargetRegion Entry(OffloadingEntriesNum, Addr, ID, Flags);
10169     OffloadEntriesTargetRegion[EntryInfo] = Entry;
10170     ++OffloadingEntriesNum;
10171   }
10172   incrementTargetRegionEntryInfoCount(EntryInfo);
10173 }
10174 
hasTargetRegionEntryInfo(TargetRegionEntryInfo EntryInfo,bool IgnoreAddressId) const10175 bool OffloadEntriesInfoManager::hasTargetRegionEntryInfo(
10176     TargetRegionEntryInfo EntryInfo, bool IgnoreAddressId) const {
10177 
10178   // Update the EntryInfo with the next available count for this location.
10179   EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);
10180 
10181   auto It = OffloadEntriesTargetRegion.find(EntryInfo);
10182   if (It == OffloadEntriesTargetRegion.end()) {
10183     return false;
10184   }
10185   // Fail if this entry is already registered.
10186   if (!IgnoreAddressId && (It->second.getAddress() || It->second.getID()))
10187     return false;
10188   return true;
10189 }
10190 
actOnTargetRegionEntriesInfo(const OffloadTargetRegionEntryInfoActTy & Action)10191 void OffloadEntriesInfoManager::actOnTargetRegionEntriesInfo(
10192     const OffloadTargetRegionEntryInfoActTy &Action) {
10193   // Scan all target region entries and perform the provided action.
10194   for (const auto &It : OffloadEntriesTargetRegion) {
10195     Action(It.first, It.second);
10196   }
10197 }
10198 
initializeDeviceGlobalVarEntryInfo(StringRef Name,OMPTargetGlobalVarEntryKind Flags,unsigned Order)10199 void OffloadEntriesInfoManager::initializeDeviceGlobalVarEntryInfo(
10200     StringRef Name, OMPTargetGlobalVarEntryKind Flags, unsigned Order) {
10201   OffloadEntriesDeviceGlobalVar.try_emplace(Name, Order, Flags);
10202   ++OffloadingEntriesNum;
10203 }
10204 
registerDeviceGlobalVarEntryInfo(StringRef VarName,Constant * Addr,int64_t VarSize,OMPTargetGlobalVarEntryKind Flags,GlobalValue::LinkageTypes Linkage)10205 void OffloadEntriesInfoManager::registerDeviceGlobalVarEntryInfo(
10206     StringRef VarName, Constant *Addr, int64_t VarSize,
10207     OMPTargetGlobalVarEntryKind Flags, GlobalValue::LinkageTypes Linkage) {
10208   if (OMPBuilder->Config.isTargetDevice()) {
10209     // This could happen if the device compilation is invoked standalone.
10210     if (!hasDeviceGlobalVarEntryInfo(VarName))
10211       return;
10212     auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
10213     if (Entry.getAddress() && hasDeviceGlobalVarEntryInfo(VarName)) {
10214       if (Entry.getVarSize() == 0) {
10215         Entry.setVarSize(VarSize);
10216         Entry.setLinkage(Linkage);
10217       }
10218       return;
10219     }
10220     Entry.setVarSize(VarSize);
10221     Entry.setLinkage(Linkage);
10222     Entry.setAddress(Addr);
10223   } else {
10224     if (hasDeviceGlobalVarEntryInfo(VarName)) {
10225       auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
10226       assert(Entry.isValid() && Entry.getFlags() == Flags &&
10227              "Entry not initialized!");
10228       if (Entry.getVarSize() == 0) {
10229         Entry.setVarSize(VarSize);
10230         Entry.setLinkage(Linkage);
10231       }
10232       return;
10233     }
10234     if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
10235       OffloadEntriesDeviceGlobalVar.try_emplace(VarName, OffloadingEntriesNum,
10236                                                 Addr, VarSize, Flags, Linkage,
10237                                                 VarName.str());
10238     else
10239       OffloadEntriesDeviceGlobalVar.try_emplace(
10240           VarName, OffloadingEntriesNum, Addr, VarSize, Flags, Linkage, "");
10241     ++OffloadingEntriesNum;
10242   }
10243 }
10244 
actOnDeviceGlobalVarEntriesInfo(const OffloadDeviceGlobalVarEntryInfoActTy & Action)10245 void OffloadEntriesInfoManager::actOnDeviceGlobalVarEntriesInfo(
10246     const OffloadDeviceGlobalVarEntryInfoActTy &Action) {
10247   // Scan all target region entries and perform the provided action.
10248   for (const auto &E : OffloadEntriesDeviceGlobalVar)
10249     Action(E.getKey(), E.getValue());
10250 }
10251 
10252 //===----------------------------------------------------------------------===//
10253 // CanonicalLoopInfo
10254 //===----------------------------------------------------------------------===//
10255 
collectControlBlocks(SmallVectorImpl<BasicBlock * > & BBs)10256 void CanonicalLoopInfo::collectControlBlocks(
10257     SmallVectorImpl<BasicBlock *> &BBs) {
10258   // We only count those BBs as control block for which we do not need to
10259   // reverse the CFG, i.e. not the loop body which can contain arbitrary control
10260   // flow. For consistency, this also means we do not add the Body block, which
10261   // is just the entry to the body code.
10262   BBs.reserve(BBs.size() + 6);
10263   BBs.append({getPreheader(), Header, Cond, Latch, Exit, getAfter()});
10264 }
10265 
getPreheader() const10266 BasicBlock *CanonicalLoopInfo::getPreheader() const {
10267   assert(isValid() && "Requires a valid canonical loop");
10268   for (BasicBlock *Pred : predecessors(Header)) {
10269     if (Pred != Latch)
10270       return Pred;
10271   }
10272   llvm_unreachable("Missing preheader");
10273 }
10274 
setTripCount(Value * TripCount)10275 void CanonicalLoopInfo::setTripCount(Value *TripCount) {
10276   assert(isValid() && "Requires a valid canonical loop");
10277 
10278   Instruction *CmpI = &getCond()->front();
10279   assert(isa<CmpInst>(CmpI) && "First inst must compare IV with TripCount");
10280   CmpI->setOperand(1, TripCount);
10281 
10282 #ifndef NDEBUG
10283   assertOK();
10284 #endif
10285 }
10286 
mapIndVar(llvm::function_ref<Value * (Instruction *)> Updater)10287 void CanonicalLoopInfo::mapIndVar(
10288     llvm::function_ref<Value *(Instruction *)> Updater) {
10289   assert(isValid() && "Requires a valid canonical loop");
10290 
10291   Instruction *OldIV = getIndVar();
10292 
10293   // Record all uses excluding those introduced by the updater. Uses by the
10294   // CanonicalLoopInfo itself to keep track of the number of iterations are
10295   // excluded.
10296   SmallVector<Use *> ReplacableUses;
10297   for (Use &U : OldIV->uses()) {
10298     auto *User = dyn_cast<Instruction>(U.getUser());
10299     if (!User)
10300       continue;
10301     if (User->getParent() == getCond())
10302       continue;
10303     if (User->getParent() == getLatch())
10304       continue;
10305     ReplacableUses.push_back(&U);
10306   }
10307 
10308   // Run the updater that may introduce new uses
10309   Value *NewIV = Updater(OldIV);
10310 
10311   // Replace the old uses with the value returned by the updater.
10312   for (Use *U : ReplacableUses)
10313     U->set(NewIV);
10314 
10315 #ifndef NDEBUG
10316   assertOK();
10317 #endif
10318 }
10319 
assertOK() const10320 void CanonicalLoopInfo::assertOK() const {
10321 #ifndef NDEBUG
10322   // No constraints if this object currently does not describe a loop.
10323   if (!isValid())
10324     return;
10325 
10326   BasicBlock *Preheader = getPreheader();
10327   BasicBlock *Body = getBody();
10328   BasicBlock *After = getAfter();
10329 
10330   // Verify standard control-flow we use for OpenMP loops.
10331   assert(Preheader);
10332   assert(isa<BranchInst>(Preheader->getTerminator()) &&
10333          "Preheader must terminate with unconditional branch");
10334   assert(Preheader->getSingleSuccessor() == Header &&
10335          "Preheader must jump to header");
10336 
10337   assert(Header);
10338   assert(isa<BranchInst>(Header->getTerminator()) &&
10339          "Header must terminate with unconditional branch");
10340   assert(Header->getSingleSuccessor() == Cond &&
10341          "Header must jump to exiting block");
10342 
10343   assert(Cond);
10344   assert(Cond->getSinglePredecessor() == Header &&
10345          "Exiting block only reachable from header");
10346 
10347   assert(isa<BranchInst>(Cond->getTerminator()) &&
10348          "Exiting block must terminate with conditional branch");
10349   assert(size(successors(Cond)) == 2 &&
10350          "Exiting block must have two successors");
10351   assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(0) == Body &&
10352          "Exiting block's first successor jump to the body");
10353   assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(1) == Exit &&
10354          "Exiting block's second successor must exit the loop");
10355 
10356   assert(Body);
10357   assert(Body->getSinglePredecessor() == Cond &&
10358          "Body only reachable from exiting block");
10359   assert(!isa<PHINode>(Body->front()));
10360 
10361   assert(Latch);
10362   assert(isa<BranchInst>(Latch->getTerminator()) &&
10363          "Latch must terminate with unconditional branch");
10364   assert(Latch->getSingleSuccessor() == Header && "Latch must jump to header");
10365   // TODO: To support simple redirecting of the end of the body code that has
10366   // multiple; introduce another auxiliary basic block like preheader and after.
10367   assert(Latch->getSinglePredecessor() != nullptr);
10368   assert(!isa<PHINode>(Latch->front()));
10369 
10370   assert(Exit);
10371   assert(isa<BranchInst>(Exit->getTerminator()) &&
10372          "Exit block must terminate with unconditional branch");
10373   assert(Exit->getSingleSuccessor() == After &&
10374          "Exit block must jump to after block");
10375 
10376   assert(After);
10377   assert(After->getSinglePredecessor() == Exit &&
10378          "After block only reachable from exit block");
10379   assert(After->empty() || !isa<PHINode>(After->front()));
10380 
10381   Instruction *IndVar = getIndVar();
10382   assert(IndVar && "Canonical induction variable not found?");
10383   assert(isa<IntegerType>(IndVar->getType()) &&
10384          "Induction variable must be an integer");
10385   assert(cast<PHINode>(IndVar)->getParent() == Header &&
10386          "Induction variable must be a PHI in the loop header");
10387   assert(cast<PHINode>(IndVar)->getIncomingBlock(0) == Preheader);
10388   assert(
10389       cast<ConstantInt>(cast<PHINode>(IndVar)->getIncomingValue(0))->isZero());
10390   assert(cast<PHINode>(IndVar)->getIncomingBlock(1) == Latch);
10391 
10392   auto *NextIndVar = cast<PHINode>(IndVar)->getIncomingValue(1);
10393   assert(cast<Instruction>(NextIndVar)->getParent() == Latch);
10394   assert(cast<BinaryOperator>(NextIndVar)->getOpcode() == BinaryOperator::Add);
10395   assert(cast<BinaryOperator>(NextIndVar)->getOperand(0) == IndVar);
10396   assert(cast<ConstantInt>(cast<BinaryOperator>(NextIndVar)->getOperand(1))
10397              ->isOne());
10398 
10399   Value *TripCount = getTripCount();
10400   assert(TripCount && "Loop trip count not found?");
10401   assert(IndVar->getType() == TripCount->getType() &&
10402          "Trip count and induction variable must have the same type");
10403 
10404   auto *CmpI = cast<CmpInst>(&Cond->front());
10405   assert(CmpI->getPredicate() == CmpInst::ICMP_ULT &&
10406          "Exit condition must be a signed less-than comparison");
10407   assert(CmpI->getOperand(0) == IndVar &&
10408          "Exit condition must compare the induction variable");
10409   assert(CmpI->getOperand(1) == TripCount &&
10410          "Exit condition must compare with the trip count");
10411 #endif
10412 }
10413 
invalidate()10414 void CanonicalLoopInfo::invalidate() {
10415   Header = nullptr;
10416   Cond = nullptr;
10417   Latch = nullptr;
10418   Exit = nullptr;
10419 }
10420