//===- OpenMPIRBuilder.cpp - Builder for LLVM-IR for OpenMP directives ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
///
/// This file implements the OpenMPIRBuilder class, which is used as a
/// convenient way to create LLVM instructions for OpenMP directives.
///
//===----------------------------------------------------------------------===//

#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Frontend/Offloading/Utility.h"
#include "llvm/Frontend/OpenMP/OMPGridValues.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/IR/Value.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"
#include "llvm/Transforms/Utils/LoopPeel.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"

#include <cstdint>
#include <optional>
#include <stack>

#define DEBUG_TYPE "openmp-ir-builder"

using namespace llvm;
using namespace omp;

static cl::opt<bool>
    OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
                         cl::desc("Use optimistic attributes describing "
                                  "'as-if' properties of runtime calls."),
                         cl::init(false));

static cl::opt<double> UnrollThresholdFactor(
    "openmp-ir-builder-unroll-threshold-factor", cl::Hidden,
    cl::desc("Factor for the unroll threshold to account for code "
             "simplifications still taking place"),
    cl::init(1.5));

#ifndef NDEBUG
/// Return whether IP1 and IP2 are ambiguous, i.e. that inserting instructions
/// at position IP1 may change the meaning of IP2 or vice-versa. This is
/// because an InsertPoint stores the instruction before something is
/// inserted. For instance, if both point to the same instruction, two
/// IRBuilders alternatingly creating instructions will cause them to be
/// interleaved.
static bool isConflictIP(IRBuilder<>::InsertPoint IP1,
                         IRBuilder<>::InsertPoint IP2) {
  if (!IP1.isSet() || !IP2.isSet())
    return false;
  return IP1.getBlock() == IP2.getBlock() && IP1.getPoint() == IP2.getPoint();
}

static bool isValidWorkshareLoopScheduleType(OMPScheduleType SchedType) {
  // Valid ordered/unordered and base algorithm combinations.
  switch (SchedType & ~OMPScheduleType::MonotonicityMask) {
  case OMPScheduleType::UnorderedStaticChunked:
  case OMPScheduleType::UnorderedStatic:
  case OMPScheduleType::UnorderedDynamicChunked:
  case OMPScheduleType::UnorderedGuidedChunked:
  case OMPScheduleType::UnorderedRuntime:
  case OMPScheduleType::UnorderedAuto:
  case OMPScheduleType::UnorderedTrapezoidal:
  case OMPScheduleType::UnorderedGreedy:
  case OMPScheduleType::UnorderedBalanced:
  case OMPScheduleType::UnorderedGuidedIterativeChunked:
  case OMPScheduleType::UnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::UnorderedSteal:
  case OMPScheduleType::UnorderedStaticBalancedChunked:
  case OMPScheduleType::UnorderedGuidedSimd:
  case OMPScheduleType::UnorderedRuntimeSimd:
  case OMPScheduleType::OrderedStaticChunked:
  case OMPScheduleType::OrderedStatic:
  case OMPScheduleType::OrderedDynamicChunked:
  case OMPScheduleType::OrderedGuidedChunked:
  case OMPScheduleType::OrderedRuntime:
  case OMPScheduleType::OrderedAuto:
  case OMPScheduleType::OrderdTrapezoidal:
  case OMPScheduleType::NomergeUnorderedStaticChunked:
  case OMPScheduleType::NomergeUnorderedStatic:
  case OMPScheduleType::NomergeUnorderedDynamicChunked:
  case OMPScheduleType::NomergeUnorderedGuidedChunked:
  case OMPScheduleType::NomergeUnorderedRuntime:
  case OMPScheduleType::NomergeUnorderedAuto:
  case OMPScheduleType::NomergeUnorderedTrapezoidal:
  case OMPScheduleType::NomergeUnorderedGreedy:
  case OMPScheduleType::NomergeUnorderedBalanced:
  case OMPScheduleType::NomergeUnorderedGuidedIterativeChunked:
  case OMPScheduleType::NomergeUnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::NomergeUnorderedSteal:
  case OMPScheduleType::NomergeOrderedStaticChunked:
  case OMPScheduleType::NomergeOrderedStatic:
  case OMPScheduleType::NomergeOrderedDynamicChunked:
  case OMPScheduleType::NomergeOrderedGuidedChunked:
  case OMPScheduleType::NomergeOrderedRuntime:
  case OMPScheduleType::NomergeOrderedAuto:
  case OMPScheduleType::NomergeOrderedTrapezoidal:
    break;
  default:
    return false;
  }

  // Must not set both monotonicity modifiers at the same time.
  OMPScheduleType MonotonicityFlags =
      SchedType & OMPScheduleType::MonotonicityMask;
  if (MonotonicityFlags == OMPScheduleType::MonotonicityMask)
    return false;

  return true;
}
#endif

static const omp::GV &getGridValue(const Triple &T, Function *Kernel) {
  if (T.isAMDGPU()) {
    StringRef Features =
        Kernel->getFnAttribute("target-features").getValueAsString();
    if (Features.count("+wavefrontsize64"))
      return omp::getAMDGPUGridValues<64>();
    return omp::getAMDGPUGridValues<32>();
  }
  if (T.isNVPTX())
    return omp::NVPTXGridValues;
  llvm_unreachable("No grid value available for this architecture!");
}

/// Determine which scheduling algorithm to use from the schedule clause
/// arguments.
static OMPScheduleType
getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier) {
  // Currently, the default schedule is static.
  switch (ClauseKind) {
  case OMP_SCHEDULE_Default:
  case OMP_SCHEDULE_Static:
    return HasChunks ? OMPScheduleType::BaseStaticChunked
                     : OMPScheduleType::BaseStatic;
  case OMP_SCHEDULE_Dynamic:
    return OMPScheduleType::BaseDynamicChunked;
  case OMP_SCHEDULE_Guided:
    return HasSimdModifier ? OMPScheduleType::BaseGuidedSimd
                           : OMPScheduleType::BaseGuidedChunked;
  case OMP_SCHEDULE_Auto:
    return llvm::omp::OMPScheduleType::BaseAuto;
  case OMP_SCHEDULE_Runtime:
    return HasSimdModifier ? OMPScheduleType::BaseRuntimeSimd
                           : OMPScheduleType::BaseRuntime;
  }
  llvm_unreachable("unhandled schedule clause argument");
}

/// Adds ordering modifier flags to schedule type.
static OMPScheduleType
getOpenMPOrderingScheduleType(OMPScheduleType BaseScheduleType,
                              bool HasOrderedClause) {
  assert((BaseScheduleType & OMPScheduleType::ModifierMask) ==
             OMPScheduleType::None &&
         "Must not have ordering nor monotonicity flags already set");

  OMPScheduleType OrderingModifier = HasOrderedClause
                                         ? OMPScheduleType::ModifierOrdered
                                         : OMPScheduleType::ModifierUnordered;
  OMPScheduleType OrderingScheduleType = BaseScheduleType | OrderingModifier;

  // Unsupported combinations
  if (OrderingScheduleType ==
      (OMPScheduleType::BaseGuidedSimd | OMPScheduleType::ModifierOrdered))
    return OMPScheduleType::OrderedGuidedChunked;
  else if (OrderingScheduleType == (OMPScheduleType::BaseRuntimeSimd |
                                    OMPScheduleType::ModifierOrdered))
    return OMPScheduleType::OrderedRuntime;

  return OrderingScheduleType;
}

/// Adds monotonicity modifier flags to schedule type.
static OMPScheduleType
getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,
                                  bool HasSimdModifier, bool HasMonotonic,
                                  bool HasNonmonotonic, bool HasOrderedClause) {
  assert((ScheduleType & OMPScheduleType::MonotonicityMask) ==
             OMPScheduleType::None &&
         "Must not have monotonicity flags already set");
  assert((!HasMonotonic || !HasNonmonotonic) &&
         "Monotonic and Nonmonotonic are contradicting each other");

  if (HasMonotonic) {
    return ScheduleType | OMPScheduleType::ModifierMonotonic;
  } else if (HasNonmonotonic) {
    return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
  } else {
    // OpenMP 5.1, 2.11.4 Worksharing-Loop Construct, Description.
    // If the static schedule kind is specified or if the ordered clause is
    // specified, and if the nonmonotonic modifier is not specified, the
    // effect is as if the monotonic modifier is specified. Otherwise, unless
    // the monotonic modifier is specified, the effect is as if the
    // nonmonotonic modifier is specified.
    OMPScheduleType BaseScheduleType =
        ScheduleType & ~OMPScheduleType::ModifierMask;
    if ((BaseScheduleType == OMPScheduleType::BaseStatic) ||
        (BaseScheduleType == OMPScheduleType::BaseStaticChunked) ||
        HasOrderedClause) {
      // Monotonic is the default in the OpenMP runtime library, so there is
      // no need to set it.
      return ScheduleType;
    } else {
      return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
    }
  }
}

/// Determine the schedule type using schedule and ordering clause arguments.
static OMPScheduleType
computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier, bool HasMonotonicModifier,
                          bool HasNonmonotonicModifier, bool HasOrderedClause) {
  OMPScheduleType BaseSchedule =
      getOpenMPBaseScheduleType(ClauseKind, HasChunks, HasSimdModifier);
  OMPScheduleType OrderedSchedule =
      getOpenMPOrderingScheduleType(BaseSchedule, HasOrderedClause);
  OMPScheduleType Result = getOpenMPMonotonicityScheduleType(
      OrderedSchedule, HasSimdModifier, HasMonotonicModifier,
      HasNonmonotonicModifier, HasOrderedClause);

  assert(isValidWorkshareLoopScheduleType(Result));
  return Result;
}
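
// For illustration (an assumed example, not normative): with the clause
// `schedule(nonmonotonic: dynamic, 4)` and no ordered clause, the helpers
// above compose roughly as
//   BaseDynamicChunked | ModifierUnordered | ModifierNonmonotonic
// i.e. OMPScheduleType::UnorderedDynamicChunked plus the nonmonotonic bit,
// which is the schedule value ultimately handed to the runtime.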

/// Make \p Source branch to \p Target.
///
/// Handles two situations:
/// * \p Source already has an unconditional branch.
/// * \p Source is a degenerate block (no terminator because the BB is
///             the current head of the IR construction).
static void redirectTo(BasicBlock *Source, BasicBlock *Target, DebugLoc DL) {
  if (Instruction *Term = Source->getTerminator()) {
    auto *Br = cast<BranchInst>(Term);
    assert(!Br->isConditional() &&
           "BB's terminator must be an unconditional branch (or degenerate)");
    BasicBlock *Succ = Br->getSuccessor(0);
    Succ->removePredecessor(Source, /*KeepOneInputPHIs=*/true);
    Br->setSuccessor(0, Target);
    return;
  }

  auto *NewBr = BranchInst::Create(Target, Source);
  NewBr->setDebugLoc(DL);
}

void llvm::spliceBB(IRBuilderBase::InsertPoint IP, BasicBlock *New,
                    bool CreateBranch) {
  assert(New->getFirstInsertionPt() == New->begin() &&
         "Target BB must not have PHI nodes");

  // Move instructions to new block.
  BasicBlock *Old = IP.getBlock();
  New->splice(New->begin(), Old, IP.getPoint(), Old->end());

  if (CreateBranch)
    BranchInst::Create(New, Old);
}

void llvm::spliceBB(IRBuilder<> &Builder, BasicBlock *New, bool CreateBranch) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *Old = Builder.GetInsertBlock();

  spliceBB(Builder.saveIP(), New, CreateBranch);
  if (CreateBranch)
    Builder.SetInsertPoint(Old->getTerminator());
  else
    Builder.SetInsertPoint(Old);

  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
}

BasicBlock *llvm::splitBB(IRBuilderBase::InsertPoint IP, bool CreateBranch,
                          llvm::Twine Name) {
  BasicBlock *Old = IP.getBlock();
  BasicBlock *New = BasicBlock::Create(
      Old->getContext(), Name.isTriviallyEmpty() ? Old->getName() : Name,
      Old->getParent(), Old->getNextNode());
  spliceBB(IP, New, CreateBranch);
  New->replaceSuccessorsPhiUsesWith(Old, New);
  return New;
}

BasicBlock *llvm::splitBB(IRBuilderBase &Builder, bool CreateBranch,
                          llvm::Twine Name) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, Name);
  if (CreateBranch)
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  else
    Builder.SetInsertPoint(Builder.GetInsertBlock());
  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
  return New;
}

BasicBlock *llvm::splitBB(IRBuilder<> &Builder, bool CreateBranch,
                          llvm::Twine Name) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, Name);
  if (CreateBranch)
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  else
    Builder.SetInsertPoint(Builder.GetInsertBlock());
  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
  return New;
}

BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch,
                                    llvm::Twine Suffix) {
  BasicBlock *Old = Builder.GetInsertBlock();
  return splitBB(Builder, CreateBranch, Old->getName() + Suffix);
}
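
// Usage sketch for the splitBB/spliceBB helpers (illustrative caller code,
// not taken from this file; SomeBB is an assumed existing block):
//   IRBuilder<> Builder(SomeBB);
//   BasicBlock *Cont = splitBBWithSuffix(Builder, /*CreateBranch=*/true,
//                                        ".cont");
// Afterwards everything from the insert point onwards lives in
// "<SomeBB-name>.cont", SomeBB ends in an unconditional branch to it, and
// the Builder keeps inserting in SomeBB just before that branch.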

// This function creates a fake integer value and a fake use for the integer
// value. It returns the fake value created. This is useful in modeling the
// extra arguments to the outlined functions.
Value *createFakeIntVal(IRBuilderBase &Builder,
                        OpenMPIRBuilder::InsertPointTy OuterAllocaIP,
                        llvm::SmallVectorImpl<Instruction *> &ToBeDeleted,
                        OpenMPIRBuilder::InsertPointTy InnerAllocaIP,
                        const Twine &Name = "", bool AsPtr = true) {
  Builder.restoreIP(OuterAllocaIP);
  Instruction *FakeVal;
  AllocaInst *FakeValAddr =
      Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, Name + ".addr");
  ToBeDeleted.push_back(FakeValAddr);

  if (AsPtr) {
    FakeVal = FakeValAddr;
  } else {
    FakeVal =
        Builder.CreateLoad(Builder.getInt32Ty(), FakeValAddr, Name + ".val");
    ToBeDeleted.push_back(FakeVal);
  }

  // Generate a fake use of this value
  Builder.restoreIP(InnerAllocaIP);
  Instruction *UseFakeVal;
  if (AsPtr) {
    UseFakeVal =
        Builder.CreateLoad(Builder.getInt32Ty(), FakeVal, Name + ".use");
  } else {
    UseFakeVal =
        cast<BinaryOperator>(Builder.CreateAdd(FakeVal, Builder.getInt32(10)));
  }
  ToBeDeleted.push_back(UseFakeVal);
  return FakeVal;
}
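
// For AsPtr == true and Name == "gtid", the IR created above looks roughly
// like this (illustrative only; all of it is erased again via ToBeDeleted):
//   %gtid.addr = alloca i32                  ; emitted at OuterAllocaIP
//   %gtid.use  = load i32, ptr %gtid.addr    ; emitted at InnerAllocaIP
// The fake use keeps %gtid.addr live inside the region so that the outliner
// treats it as a captured argument of the outlined function.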

//===----------------------------------------------------------------------===//
// OpenMPIRBuilderConfig
//===----------------------------------------------------------------------===//

namespace {
LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
/// Values for bit flags for marking which requires clauses have been used.
enum OpenMPOffloadingRequiresDirFlags {
  /// flag undefined.
  OMP_REQ_UNDEFINED = 0x000,
  /// no requires directive present.
  OMP_REQ_NONE = 0x001,
  /// reverse_offload clause.
  OMP_REQ_REVERSE_OFFLOAD = 0x002,
  /// unified_address clause.
  OMP_REQ_UNIFIED_ADDRESS = 0x004,
  /// unified_shared_memory clause.
  OMP_REQ_UNIFIED_SHARED_MEMORY = 0x008,
  /// dynamic_allocators clause.
  OMP_REQ_DYNAMIC_ALLOCATORS = 0x010,
  LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/OMP_REQ_DYNAMIC_ALLOCATORS)
};

} // anonymous namespace
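
// Example encoding (values follow directly from the flags above): a module
// compiled with
//   #pragma omp requires unified_shared_memory, dynamic_allocators
// carries OMP_REQ_UNIFIED_SHARED_MEMORY | OMP_REQ_DYNAMIC_ALLOCATORS
// (0x008 | 0x010 = 0x018), while a module without any requires directive
// reports OMP_REQ_NONE (0x001) through getRequiresFlags() below.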

OpenMPIRBuilderConfig::OpenMPIRBuilderConfig()
    : RequiresFlags(OMP_REQ_UNDEFINED) {}

OpenMPIRBuilderConfig::OpenMPIRBuilderConfig(
    bool IsTargetDevice, bool IsGPU, bool OpenMPOffloadMandatory,
    bool HasRequiresReverseOffload, bool HasRequiresUnifiedAddress,
    bool HasRequiresUnifiedSharedMemory, bool HasRequiresDynamicAllocators)
    : IsTargetDevice(IsTargetDevice), IsGPU(IsGPU),
      OpenMPOffloadMandatory(OpenMPOffloadMandatory),
      RequiresFlags(OMP_REQ_UNDEFINED) {
  if (HasRequiresReverseOffload)
    RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
  if (HasRequiresUnifiedAddress)
    RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
  if (HasRequiresUnifiedSharedMemory)
    RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
  if (HasRequiresDynamicAllocators)
    RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
}

bool OpenMPIRBuilderConfig::hasRequiresReverseOffload() const {
  return RequiresFlags & OMP_REQ_REVERSE_OFFLOAD;
}

bool OpenMPIRBuilderConfig::hasRequiresUnifiedAddress() const {
  return RequiresFlags & OMP_REQ_UNIFIED_ADDRESS;
}

bool OpenMPIRBuilderConfig::hasRequiresUnifiedSharedMemory() const {
  return RequiresFlags & OMP_REQ_UNIFIED_SHARED_MEMORY;
}

bool OpenMPIRBuilderConfig::hasRequiresDynamicAllocators() const {
  return RequiresFlags & OMP_REQ_DYNAMIC_ALLOCATORS;
}

int64_t OpenMPIRBuilderConfig::getRequiresFlags() const {
  return hasRequiresFlags() ? RequiresFlags
                            : static_cast<int64_t>(OMP_REQ_NONE);
}

void OpenMPIRBuilderConfig::setHasRequiresReverseOffload(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
  else
    RequiresFlags &= ~OMP_REQ_REVERSE_OFFLOAD;
}

void OpenMPIRBuilderConfig::setHasRequiresUnifiedAddress(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
  else
    RequiresFlags &= ~OMP_REQ_UNIFIED_ADDRESS;
}

void OpenMPIRBuilderConfig::setHasRequiresUnifiedSharedMemory(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
  else
    RequiresFlags &= ~OMP_REQ_UNIFIED_SHARED_MEMORY;
}

void OpenMPIRBuilderConfig::setHasRequiresDynamicAllocators(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
  else
    RequiresFlags &= ~OMP_REQ_DYNAMIC_ALLOCATORS;
}

//===----------------------------------------------------------------------===//
// OpenMPIRBuilder
//===----------------------------------------------------------------------===//

void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
                                          IRBuilderBase &Builder,
                                          SmallVector<Value *> &ArgsVector) {
  Value *Version = Builder.getInt32(OMP_KERNEL_ARG_VERSION);
  Value *PointerNum = Builder.getInt32(KernelArgs.NumTargetItems);
  auto Int32Ty = Type::getInt32Ty(Builder.getContext());
  Value *ZeroArray = Constant::getNullValue(ArrayType::get(Int32Ty, 3));
  Value *Flags = Builder.getInt64(KernelArgs.HasNoWait);

  Value *NumTeams3D =
      Builder.CreateInsertValue(ZeroArray, KernelArgs.NumTeams, {0});
  Value *NumThreads3D =
      Builder.CreateInsertValue(ZeroArray, KernelArgs.NumThreads, {0});

  ArgsVector = {Version,
                PointerNum,
                KernelArgs.RTArgs.BasePointersArray,
                KernelArgs.RTArgs.PointersArray,
                KernelArgs.RTArgs.SizesArray,
                KernelArgs.RTArgs.MapTypesArray,
                KernelArgs.RTArgs.MapNamesArray,
                KernelArgs.RTArgs.MappersArray,
                KernelArgs.NumIterations,
                Flags,
                NumTeams3D,
                NumThreads3D,
                KernelArgs.DynCGGroupMem};
}

void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
  LLVMContext &Ctx = Fn.getContext();

  // Get the function's current attributes.
  auto Attrs = Fn.getAttributes();
  auto FnAttrs = Attrs.getFnAttrs();
  auto RetAttrs = Attrs.getRetAttrs();
  SmallVector<AttributeSet, 4> ArgAttrs;
  for (size_t ArgNo = 0; ArgNo < Fn.arg_size(); ++ArgNo)
    ArgAttrs.emplace_back(Attrs.getParamAttrs(ArgNo));

  // Add AS to FnAS while taking special care with integer extensions.
  auto addAttrSet = [&](AttributeSet &FnAS, const AttributeSet &AS,
                        bool Param = true) -> void {
    bool HasSignExt = AS.hasAttribute(Attribute::SExt);
    bool HasZeroExt = AS.hasAttribute(Attribute::ZExt);
    if (HasSignExt || HasZeroExt) {
      assert(AS.getNumAttributes() == 1 &&
             "Currently not handling extension attr combined with others.");
      if (Param) {
        if (auto AK = TargetLibraryInfo::getExtAttrForI32Param(T, HasSignExt))
          FnAS = FnAS.addAttribute(Ctx, AK);
      } else if (auto AK =
                     TargetLibraryInfo::getExtAttrForI32Return(T, HasSignExt))
        FnAS = FnAS.addAttribute(Ctx, AK);
    } else {
      FnAS = FnAS.addAttributes(Ctx, AS);
    }
  };

#define OMP_ATTRS_SET(VarName, AttrSet) AttributeSet VarName = AttrSet;
#include "llvm/Frontend/OpenMP/OMPKinds.def"

  // Add attributes to the function declaration.
  switch (FnID) {
#define OMP_RTL_ATTRS(Enum, FnAttrSet, RetAttrSet, ArgAttrSets)                \
  case Enum:                                                                   \
    FnAttrs = FnAttrs.addAttributes(Ctx, FnAttrSet);                           \
    addAttrSet(RetAttrs, RetAttrSet, /*Param*/ false);                         \
    for (size_t ArgNo = 0; ArgNo < ArgAttrSets.size(); ++ArgNo)                \
      addAttrSet(ArgAttrs[ArgNo], ArgAttrSets[ArgNo]);                         \
    Fn.setAttributes(AttributeList::get(Ctx, FnAttrs, RetAttrs, ArgAttrs));    \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    // Attributes are optional.
    break;
  }
}

FunctionCallee
OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
  FunctionType *FnTy = nullptr;
  Function *Fn = nullptr;

  // Try to find the declaration in the module first.
  switch (FnID) {
#define OMP_RTL(Enum, Str, IsVarArg, ReturnType, ...)                          \
  case Enum:                                                                   \
    FnTy = FunctionType::get(ReturnType, ArrayRef<Type *>{__VA_ARGS__},        \
                             IsVarArg);                                        \
    Fn = M.getFunction(Str);                                                   \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  if (!Fn) {
    // Create a new declaration if we need one.
    switch (FnID) {
#define OMP_RTL(Enum, Str, ...)                                                \
  case Enum:                                                                   \
    Fn = Function::Create(FnTy, GlobalValue::ExternalLinkage, Str, M);         \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
    }

    // Add information if the runtime function takes a callback function.
    if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) {
      if (!Fn->hasMetadata(LLVMContext::MD_callback)) {
        LLVMContext &Ctx = Fn->getContext();
        MDBuilder MDB(Ctx);
        // Annotate the callback behavior of the runtime function:
        //  - The callback callee is argument number 2 (microtask).
        //  - The first two arguments of the callback callee are unknown (-1).
        //  - All variadic arguments to the runtime function are passed to the
        //    callback callee.
        Fn->addMetadata(
            LLVMContext::MD_callback,
            *MDNode::get(Ctx, {MDB.createCallbackEncoding(
                                  2, {-1, -1}, /* VarArgsArePassed */ true)}));
      }
    }

    LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
    addAttributes(FnID, *Fn);

  } else {
    LLVM_DEBUG(dbgs() << "Found OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
  }

  assert(Fn && "Failed to create OpenMP runtime function");

  return {FnTy, Fn};
}
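
// The callback annotation added above corresponds to IR of roughly this
// shape (illustrative; metadata numbering is assigned by the context):
//   declare !callback !0 void @__kmpc_fork_call(ptr, i32, ptr, ...)
//   !0 = !{!1}
//   !1 = !{i64 2, i64 -1, i64 -1, i1 true}
// i.e. argument 2 is the callee, its first two parameters are unknown, and
// all varargs of the runtime call are forwarded to the callee.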

Function *OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID) {
  FunctionCallee RTLFn = getOrCreateRuntimeFunction(M, FnID);
  auto *Fn = dyn_cast<llvm::Function>(RTLFn.getCallee());
  assert(Fn && "Failed to create OpenMP runtime function pointer");
  return Fn;
}

void OpenMPIRBuilder::initialize() { initializeTypes(M); }

static void raiseUserConstantDataAllocasToEntryBlock(IRBuilderBase &Builder,
                                                     Function *Function) {
  BasicBlock &EntryBlock = Function->getEntryBlock();
  Instruction *MoveLocInst = EntryBlock.getFirstNonPHI();

  // Loop over blocks looking for constant allocas, skipping the entry block
  // as any allocas there are already in the desired location.
  for (auto Block = std::next(Function->begin(), 1); Block != Function->end();
       Block++) {
    for (auto Inst = Block->getReverseIterator()->begin();
         Inst != Block->getReverseIterator()->end();) {
      if (auto *AllocaInst = dyn_cast_if_present<llvm::AllocaInst>(Inst)) {
        Inst++;
        if (!isa<ConstantData>(AllocaInst->getArraySize()))
          continue;
        AllocaInst->moveBeforePreserving(MoveLocInst);
      } else {
        Inst++;
      }
    }
  }
}

void OpenMPIRBuilder::finalize(Function *Fn) {
  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  SmallVector<OutlineInfo, 16> DeferredOutlines;
  for (OutlineInfo &OI : OutlineInfos) {
    // Skip functions that have not finalized yet; may happen with nested
    // function generation.
    if (Fn && OI.getFunction() != Fn) {
      DeferredOutlines.push_back(OI);
      continue;
    }

    ParallelRegionBlockSet.clear();
    Blocks.clear();
    OI.collectBlocks(ParallelRegionBlockSet, Blocks);

    Function *OuterFn = OI.getFunction();
    CodeExtractorAnalysisCache CEAC(*OuterFn);
    // If we generate code for the target device, we need to allocate the
    // struct for aggregate params in the device's default alloca address
    // space. The OpenMP runtime requires that the params of the extracted
    // functions are passed as zero address space pointers. This flag ensures
    // that CodeExtractor generates correct code for extracted functions
    // which are used by the OpenMP runtime.
    bool ArgsInZeroAddressSpace = Config.isTargetDevice();
    CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
                            /* AggregateArgs */ true,
                            /* BlockFrequencyInfo */ nullptr,
                            /* BranchProbabilityInfo */ nullptr,
                            /* AssumptionCache */ nullptr,
                            /* AllowVarArgs */ true,
                            /* AllowAlloca */ true,
                            /* AllocaBlock*/ OI.OuterAllocaBB,
                            /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);

    LLVM_DEBUG(dbgs() << "Before     outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
                      << " Exit: " << OI.ExitBB->getName() << "\n");
    assert(Extractor.isEligible() &&
           "Expected OpenMP outlining to be possible!");

    for (auto *V : OI.ExcludeArgsFromAggregate)
      Extractor.excludeArgFromAggregate(V);

    Function *OutlinedFn = Extractor.extractCodeRegion(CEAC);

    // Forward target-cpu, target-features attributes to the outlined function.
    auto TargetCpuAttr = OuterFn->getFnAttribute("target-cpu");
    if (TargetCpuAttr.isStringAttribute())
      OutlinedFn->addFnAttr(TargetCpuAttr);

    auto TargetFeaturesAttr = OuterFn->getFnAttribute("target-features");
    if (TargetFeaturesAttr.isStringAttribute())
      OutlinedFn->addFnAttr(TargetFeaturesAttr);

    LLVM_DEBUG(dbgs() << "After      outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << "   Outlined function: " << *OutlinedFn << "\n");
    assert(OutlinedFn->getReturnType()->isVoidTy() &&
           "OpenMP outlined functions should not return a value!");
    // For compatibility with the clang CG we move the outlined function after
    // the one with the parallel region.
    OutlinedFn->removeFromParent();
    M.getFunctionList().insertAfter(OuterFn->getIterator(), OutlinedFn);

    // Remove the artificial entry introduced by the extractor right away; we
    // made our own entry block after all.
    {
      BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
      assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
      assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
      // Move instructions from the to-be-deleted ArtificialEntry to the entry
      // basic block of the parallel region. CodeExtractor generates
      // instructions to unwrap the aggregate argument and may sink
      // allocas/bitcasts for values that are solely used in the outlined region
      // and do not escape.
      assert(!ArtificialEntry.empty() &&
             "Expected instructions to add in the outlined region entry");
      for (BasicBlock::reverse_iterator It = ArtificialEntry.rbegin(),
                                        End = ArtificialEntry.rend();
           It != End;) {
        Instruction &I = *It;
        It++;

        if (I.isTerminator())
          continue;

        I.moveBeforePreserving(*OI.EntryBB, OI.EntryBB->getFirstInsertionPt());
      }

      OI.EntryBB->moveBefore(&ArtificialEntry);
      ArtificialEntry.eraseFromParent();
    }
    assert(&OutlinedFn->getEntryBlock() == OI.EntryBB);
    assert(OutlinedFn && OutlinedFn->getNumUses() == 1);

    // Run a user callback, e.g. to add attributes.
    if (OI.PostOutlineCB)
      OI.PostOutlineCB(*OutlinedFn);
  }

  // Remove work items that have been completed.
  OutlineInfos = std::move(DeferredOutlines);

  // The createTarget functions embed user-written code into the target
  // region, which may inject allocas that need to be moved to the entry
  // block of our target, or later passes risk performing malformed
  // optimisations. This is only relevant for the device pass, which appears
  // to be a little more delicate when it comes to optimisations (however, we
  // do not block on that here; it is up to whoever inserts into the list to
  // do so). This notably has to occur after the OutlineInfo candidates have
  // been extracted, so the end product is not implicitly adversely affected
  // by any raises unless intentionally appended to the list.
  // NOTE: This only handles ConstantData; it could be extended to
  // ConstantExprs with further effort, although those should largely be
  // folded by the time they get here. Extending it to runtime-defined or
  // read+writeable allocation sizes would be non-trivial (movement of any
  // stores to the variables the allocation size depends on, as well as the
  // usual loads, would need to be factored in, otherwise it would yield the
  // wrong result after movement) and is likely more suitable as an LLVM
  // optimisation pass.
  for (Function *F : ConstantAllocaRaiseCandidates)
    raiseUserConstantDataAllocasToEntryBlock(Builder, F);

  EmitMetadataErrorReportFunctionTy &&ErrorReportFn =
      [](EmitMetadataErrorKind Kind,
         const TargetRegionEntryInfo &EntryInfo) -> void {
    errs() << "Error of kind: " << Kind
           << " when emitting offload entries and metadata during "
              "OMPIRBuilder finalization \n";
  };

  if (!OffloadInfoManager.empty())
    createOffloadEntriesAndInfoMetadata(ErrorReportFn);

  if (Config.EmitLLVMUsedMetaInfo.value_or(false)) {
    std::vector<WeakTrackingVH> LLVMCompilerUsed = {
        M.getGlobalVariable("__openmp_nvptx_data_transfer_temporary_storage")};
    emitUsed("llvm.compiler.used", LLVMCompilerUsed);
  }
}

OpenMPIRBuilder::~OpenMPIRBuilder() {
  assert(OutlineInfos.empty() && "There must be no outstanding outlinings");
}

GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
  IntegerType *I32Ty = Type::getInt32Ty(M.getContext());
  auto *GV =
      new GlobalVariable(M, I32Ty,
                         /* isConstant = */ true, GlobalValue::WeakODRLinkage,
                         ConstantInt::get(I32Ty, Value), Name);
  GV->setVisibility(GlobalValue::HiddenVisibility);

  return GV;
}

Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
                                            uint32_t SrcLocStrSize,
                                            IdentFlag LocFlags,
                                            unsigned Reserve2Flags) {
  // Enable "C-mode".
  LocFlags |= OMP_IDENT_FLAG_KMPC;

  Constant *&Ident =
      IdentMap[{SrcLocStr, uint64_t(LocFlags) << 31 | Reserve2Flags}];
  if (!Ident) {
    Constant *I32Null = ConstantInt::getNullValue(Int32);
    Constant *IdentData[] = {I32Null,
                             ConstantInt::get(Int32, uint32_t(LocFlags)),
                             ConstantInt::get(Int32, Reserve2Flags),
                             ConstantInt::get(Int32, SrcLocStrSize), SrcLocStr};
    Constant *Initializer =
        ConstantStruct::get(OpenMPIRBuilder::Ident, IdentData);

    // Look for existing encoding of the location + flags, not needed but
    // minimizes the difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.globals())
      if (GV.getValueType() == OpenMPIRBuilder::Ident && GV.hasInitializer())
        if (GV.getInitializer() == Initializer)
          Ident = &GV;

    if (!Ident) {
      auto *GV = new GlobalVariable(
          M, OpenMPIRBuilder::Ident,
          /* isConstant = */ true, GlobalValue::PrivateLinkage, Initializer, "",
          nullptr, GlobalValue::NotThreadLocal,
          M.getDataLayout().getDefaultGlobalsAddressSpace());
      GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
      GV->setAlignment(Align(8));
      Ident = GV;
    }
  }

  return ConstantExpr::getPointerBitCastOrAddrSpaceCast(Ident, IdentPtr);
}
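
// An ident_t global created above typically renders as (illustrative only;
// names and address spaces vary by target):
//   @0 = private unnamed_addr constant %struct.ident_t
//            { i32 0, i32 2, i32 0, i32 <SrcLocStrSize>, ptr @.str }, align 8
// where the second field carries OMP_IDENT_FLAG_KMPC plus any location
// flags such as the barrier kind.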

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef LocStr,
                                                uint32_t &SrcLocStrSize) {
  SrcLocStrSize = LocStr.size();
  Constant *&SrcLocStr = SrcLocStrMap[LocStr];
  if (!SrcLocStr) {
    Constant *Initializer =
        ConstantDataArray::getString(M.getContext(), LocStr);

    // Look for existing encoding of the location, not needed but minimizes the
    // difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.globals())
      if (GV.isConstant() && GV.hasInitializer() &&
          GV.getInitializer() == Initializer)
        return SrcLocStr = ConstantExpr::getPointerCast(&GV, Int8Ptr);

    SrcLocStr = Builder.CreateGlobalStringPtr(LocStr, /* Name */ "",
                                              /* AddressSpace */ 0, &M);
  }
  return SrcLocStr;
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef FunctionName,
                                                StringRef FileName,
                                                unsigned Line, unsigned Column,
                                                uint32_t &SrcLocStrSize) {
  SmallString<128> Buffer;
  Buffer.push_back(';');
  Buffer.append(FileName);
  Buffer.push_back(';');
  Buffer.append(FunctionName);
  Buffer.push_back(';');
  Buffer.append(std::to_string(Line));
  Buffer.push_back(';');
  Buffer.append(std::to_string(Column));
  Buffer.push_back(';');
  Buffer.push_back(';');
  return getOrCreateSrcLocStr(Buffer.str(), SrcLocStrSize);
}
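
// For example, a location in function "foo" at line 12, column 3 of file
// "test.c" is encoded as the string
//   ";test.c;foo;12;3;;"
// matching the ";file;function;line;column;;" layout also used by the
// default string ";unknown;unknown;0;0;;" below.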

Constant *
OpenMPIRBuilder::getOrCreateDefaultSrcLocStr(uint32_t &SrcLocStrSize) {
  StringRef UnknownLoc = ";unknown;unknown;0;0;;";
  return getOrCreateSrcLocStr(UnknownLoc, SrcLocStrSize);
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(DebugLoc DL,
                                                uint32_t &SrcLocStrSize,
                                                Function *F) {
  DILocation *DIL = DL.get();
  if (!DIL)
    return getOrCreateDefaultSrcLocStr(SrcLocStrSize);
  StringRef FileName = M.getName();
  if (DIFile *DIF = DIL->getFile())
    if (std::optional<StringRef> Source = DIF->getSource())
      FileName = *Source;
  StringRef Function = DIL->getScope()->getSubprogram()->getName();
  if (Function.empty() && F)
    Function = F->getName();
  return getOrCreateSrcLocStr(Function, FileName, DIL->getLine(),
                              DIL->getColumn(), SrcLocStrSize);
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(const LocationDescription &Loc,
                                                uint32_t &SrcLocStrSize) {
  return getOrCreateSrcLocStr(Loc.DL, SrcLocStrSize,
                              Loc.IP.getBlock()->getParent());
}

Value *OpenMPIRBuilder::getOrCreateThreadID(Value *Ident) {
  return Builder.CreateCall(
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num), Ident,
      "omp_global_thread_num");
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createBarrier(const LocationDescription &Loc, Directive Kind,
                               bool ForceSimpleCall, bool CheckCancelFlag) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // Build call __kmpc_cancel_barrier(loc, thread_id) or
  //            __kmpc_barrier(loc, thread_id);

  IdentFlag BarrierLocFlags;
  switch (Kind) {
  case OMPD_for:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_FOR;
    break;
  case OMPD_sections:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SECTIONS;
    break;
  case OMPD_single:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SINGLE;
    break;
  case OMPD_barrier:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_EXPL;
    break;
  default:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL;
    break;
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Args[] = {
      getOrCreateIdent(SrcLocStr, SrcLocStrSize, BarrierLocFlags),
      getOrCreateThreadID(getOrCreateIdent(SrcLocStr, SrcLocStrSize))};

  // If we are in a cancellable parallel region, barriers are cancellation
  // points.
  // TODO: Check why we would force simple calls or to ignore the cancel flag.
  bool UseCancelBarrier =
      !ForceSimpleCall && isLastFinalizationInfoCancellable(OMPD_parallel);

  Value *Result =
      Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
                             UseCancelBarrier ? OMPRTL___kmpc_cancel_barrier
                                              : OMPRTL___kmpc_barrier),
                         Args);

  if (UseCancelBarrier && CheckCancelFlag)
    emitCancelationCheckImpl(Result, OMPD_parallel);

  return Builder.saveIP();
}
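
// In the plain (non-cancellable) case the code built above reduces to
// roughly the following (illustrative):
//   %gtid = call i32 @__kmpc_global_thread_num(ptr @loc)
//   call void @__kmpc_barrier(ptr @loc, i32 %gtid)
// In a cancellable parallel region __kmpc_cancel_barrier is used instead,
// and its i32 result feeds emitCancelationCheckImpl().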

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createCancel(const LocationDescription &Loc,
                              Value *IfCondition,
                              omp::Directive CanceledDirective) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // LLVM utilities like blocks with terminators.
  auto *UI = Builder.CreateUnreachable();

  Instruction *ThenTI = UI, *ElseTI = nullptr;
  if (IfCondition)
    SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI);
  Builder.SetInsertPoint(ThenTI);

  Value *CancelKind = nullptr;
  switch (CanceledDirective) {
#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value)                       \
  case DirectiveEnum:                                                          \
    CancelKind = Builder.getInt32(Value);                                      \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    llvm_unreachable("Unknown cancel kind!");
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
  Value *Result = Builder.CreateCall(
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_cancel), Args);
  auto ExitCB = [this, CanceledDirective, Loc](InsertPointTy IP) {
    if (CanceledDirective == OMPD_parallel) {
      IRBuilder<>::InsertPointGuard IPG(Builder);
      Builder.restoreIP(IP);
      createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                    omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
                    /* CheckCancelFlag */ false);
    }
  };

  // The actual cancel logic is shared with others, e.g., cancel_barriers.
  emitCancelationCheckImpl(Result, CanceledDirective, ExitCB);

  // Update the insertion point and remove the terminator we introduced.
  Builder.SetInsertPoint(UI->getParent());
  UI->eraseFromParent();

  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
    const LocationDescription &Loc, InsertPointTy AllocaIP, Value *&Return,
    Value *Ident, Value *DeviceID, Value *NumTeams, Value *NumThreads,
    Value *HostPtr, ArrayRef<Value *> KernelArgs) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  Builder.restoreIP(AllocaIP);
  auto *KernelArgsPtr =
      Builder.CreateAlloca(OpenMPIRBuilder::KernelArgs, nullptr, "kernel_args");
  Builder.restoreIP(Loc.IP);

  for (unsigned I = 0, Size = KernelArgs.size(); I != Size; ++I) {
    llvm::Value *Arg =
        Builder.CreateStructGEP(OpenMPIRBuilder::KernelArgs, KernelArgsPtr, I);
    Builder.CreateAlignedStore(
        KernelArgs[I], Arg,
        M.getDataLayout().getPrefTypeAlign(KernelArgs[I]->getType()));
  }

  SmallVector<Value *> OffloadingArgs{Ident,      DeviceID, NumTeams,
                                      NumThreads, HostPtr,  KernelArgsPtr};

  Return = Builder.CreateCall(
      getOrCreateRuntimeFunction(M, OMPRTL___tgt_target_kernel),
      OffloadingArgs);

  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(
    const LocationDescription &Loc, Function *OutlinedFn, Value *OutlinedFnID,
    EmitFallbackCallbackTy emitTargetCallFallbackCB, TargetKernelArgs &Args,
    Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  Builder.restoreIP(Loc.IP);
  // On top of the arrays that were filled up, the target offloading call
  // takes as arguments the device id as well as the host pointer. The host
  // pointer is used by the runtime library to identify the current target
  // region, so it only has to be unique and not necessarily point to
  // anything. It could be the pointer to the outlined function that
  // implements the target region, but we aren't using that so that the
  // compiler doesn't need to keep that, and could therefore inline the host
  // function if proven worthwhile during optimization.

  // From this point on, we need to have an ID of the target region defined.
  assert(OutlinedFnID && "Invalid outlined function ID!");
  (void)OutlinedFnID;

  // Return value of the runtime offloading call.
  Value *Return = nullptr;

  // Arguments for the target kernel.
  SmallVector<Value *> ArgsVector;
  getKernelArgsVector(Args, Builder, ArgsVector);

  // The target region is an outlined function launched by the runtime
  // via calls to __tgt_target_kernel().
  //
  // Note that on the host and CPU targets, the runtime implementation of
  // these calls simply calls the outlined function without forking threads.
  // The outlined functions themselves have runtime calls to
  // __kmpc_fork_teams() and __kmpc_fork() for this purpose, codegen'd by
  // the compiler in emitTeamsCall() and emitParallelCall().
  //
  // In contrast, on the NVPTX target, the implementation of
  // __tgt_target_teams() launches a GPU kernel with the requested number
  // of teams and threads so no additional calls to the runtime are required.
  // Check the error code and execute the host version if required.
  Builder.restoreIP(emitTargetKernel(Builder, AllocaIP, Return, RTLoc, DeviceID,
                                     Args.NumTeams, Args.NumThreads,
                                     OutlinedFnID, ArgsVector));

  BasicBlock *OffloadFailedBlock =
      BasicBlock::Create(Builder.getContext(), "omp_offload.failed");
  BasicBlock *OffloadContBlock =
      BasicBlock::Create(Builder.getContext(), "omp_offload.cont");
  Value *Failed = Builder.CreateIsNotNull(Return);
  Builder.CreateCondBr(Failed, OffloadFailedBlock, OffloadContBlock);

  auto CurFn = Builder.GetInsertBlock()->getParent();
  emitBlock(OffloadFailedBlock, CurFn);
  Builder.restoreIP(emitTargetCallFallbackCB(Builder.saveIP()));
  emitBranch(OffloadContBlock);
  emitBlock(OffloadContBlock, CurFn, /*IsFinished=*/true);
  return Builder.saveIP();
}
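
// Schematically, emitKernelLaunch() produces control flow of this shape
// (illustrative; operand names and types abbreviated):
//   %rc = call i32 @__tgt_target_kernel(ptr %rtloc, i64 %device_id, ...)
//   %failed = icmp ne i32 %rc, 0
//   br i1 %failed, label %omp_offload.failed, label %omp_offload.cont
// where %omp_offload.failed holds the host fallback emitted through
// emitTargetCallFallbackCB and falls through to %omp_offload.cont.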

void OpenMPIRBuilder::emitCancelationCheckImpl(Value *CancelFlag,
                                               omp::Directive CanceledDirective,
                                               FinalizeCallbackTy ExitCB) {
  assert(isLastFinalizationInfoCancellable(CanceledDirective) &&
         "Unexpected cancellation!");

  // For a cancel barrier we create two new blocks.
  BasicBlock *BB = Builder.GetInsertBlock();
  BasicBlock *NonCancellationBlock;
  if (Builder.GetInsertPoint() == BB->end()) {
    // TODO: This branch will not be needed once we moved to the
    // OpenMPIRBuilder codegen completely.
    NonCancellationBlock = BasicBlock::Create(
        BB->getContext(), BB->getName() + ".cont", BB->getParent());
  } else {
    NonCancellationBlock = SplitBlock(BB, &*Builder.GetInsertPoint());
    BB->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(BB);
  }
  BasicBlock *CancellationBlock = BasicBlock::Create(
      BB->getContext(), BB->getName() + ".cncl", BB->getParent());

  // Jump to them based on the return value.
  Value *Cmp = Builder.CreateIsNull(CancelFlag);
  Builder.CreateCondBr(Cmp, NonCancellationBlock, CancellationBlock,
                       /* TODO weight */ nullptr, nullptr);

  // From the cancellation block we finalize all variables and go to the
  // post finalization block that is known to the FiniCB callback.
  Builder.SetInsertPoint(CancellationBlock);
  if (ExitCB)
    ExitCB(Builder.saveIP());
  auto &FI = FinalizationStack.back();
  FI.FiniCB(Builder.saveIP());

  // The continuation block is where code generation continues.
  Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
}
1168 
1169 // Callback used to create OpenMP runtime calls to support
1170 // omp parallel clause for the device.
1171 // We need to use this callback to replace call to the OutlinedFn in OuterFn
1172 // by the call to the OpenMP DeviceRTL runtime function (kmpc_parallel_51)
targetParallelCallback(OpenMPIRBuilder * OMPIRBuilder,Function & OutlinedFn,Function * OuterFn,BasicBlock * OuterAllocaBB,Value * Ident,Value * IfCondition,Value * NumThreads,Instruction * PrivTID,AllocaInst * PrivTIDAddr,Value * ThreadID,const SmallVector<Instruction *,4> & ToBeDeleted)1173 static void targetParallelCallback(
1174     OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn, Function *OuterFn,
1175     BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
1176     Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
1177     Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
1178   // Add some known attributes.
1179   IRBuilder<> &Builder = OMPIRBuilder->Builder;
1180   OutlinedFn.addParamAttr(0, Attribute::NoAlias);
1181   OutlinedFn.addParamAttr(1, Attribute::NoAlias);
1182   OutlinedFn.addParamAttr(0, Attribute::NoUndef);
1183   OutlinedFn.addParamAttr(1, Attribute::NoUndef);
1184   OutlinedFn.addFnAttr(Attribute::NoUnwind);
1185 
1186   assert(OutlinedFn.arg_size() >= 2 &&
1187          "Expected at least tid and bounded tid as arguments");
1188   unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1189 
1190   CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
1191   assert(CI && "Expected call instruction to outlined function");
1192   CI->getParent()->setName("omp_parallel");
1193 
1194   Builder.SetInsertPoint(CI);
1195   Type *PtrTy = OMPIRBuilder->VoidPtr;
1196   Value *NullPtrValue = Constant::getNullValue(PtrTy);
1197 
1198   // Add alloca for kernel args
1199   OpenMPIRBuilder::InsertPointTy CurrentIP = Builder.saveIP();
1200   Builder.SetInsertPoint(OuterAllocaBB, OuterAllocaBB->getFirstInsertionPt());
1201   AllocaInst *ArgsAlloca =
1202       Builder.CreateAlloca(ArrayType::get(PtrTy, NumCapturedVars));
1203   Value *Args = ArgsAlloca;
1204   // Add an address space cast if the array for storing arguments is not
1205   // allocated in address space 0.
1206   if (ArgsAlloca->getAddressSpace())
1207     Args = Builder.CreatePointerCast(ArgsAlloca, PtrTy);
1208   Builder.restoreIP(CurrentIP);
1209 
1210   // Store captured vars which are used by kmpc_parallel_51
1211   for (unsigned Idx = 0; Idx < NumCapturedVars; Idx++) {
1212     Value *V = *(CI->arg_begin() + 2 + Idx);
1213     Value *StoreAddress = Builder.CreateConstInBoundsGEP2_64(
1214         ArrayType::get(PtrTy, NumCapturedVars), Args, 0, Idx);
1215     Builder.CreateStore(V, StoreAddress);
1216   }
1217 
1218   Value *Cond =
1219       IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32)
1220                   : Builder.getInt32(1);
1221 
1222   // Build kmpc_parallel_51 call
1223   Value *Parallel51CallArgs[] = {
1224       /* identifier */ Ident,
1225       /* global thread num */ ThreadID,
1226       /* if expression */ Cond,
1227       /* number of threads */ NumThreads ? NumThreads : Builder.getInt32(-1),
1228       /* proc bind */ Builder.getInt32(-1),
1229       /* outlined function */
1230       Builder.CreateBitCast(&OutlinedFn, OMPIRBuilder->ParallelTaskPtr),
1231       /* wrapper function */ NullPtrValue,
1232       /* arguments of the outlined function */ Args,
1233       /* number of arguments */ Builder.getInt64(NumCapturedVars)};
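  // For reference, the DeviceRTL entry point targeted here is roughly:
  //   void __kmpc_parallel_51(ident_t *ident, i32 gtid, i32 if_expr,
  //                           i32 num_threads, i32 proc_bind, void *fn,
  //                           void *wrapper_fn, void **args, i64 nargs);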
1234 
1235   FunctionCallee RTLFn =
1236       OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_parallel_51);
1237 
1238   Builder.CreateCall(RTLFn, Parallel51CallArgs);
1239 
1240   LLVM_DEBUG(dbgs() << "With kmpc_parallel_51 placed: "
1241                     << *Builder.GetInsertBlock()->getParent() << "\n");
1242 
1243   // Initialize the local TID stack location with the argument value.
1244   Builder.SetInsertPoint(PrivTID);
1245   Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
1246   Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
1247                       PrivTIDAddr);
1248 
1249   // Remove redundant call to the outlined function.
1250   CI->eraseFromParent();
1251 
1252   for (Instruction *I : ToBeDeleted) {
1253     I->eraseFromParent();
1254   }
1255 }
1256 
1257 // Callback used to create OpenMP runtime calls to support
1258 // the omp parallel construct on the host.
1259 // We need to use this callback to replace the call to OutlinedFn in OuterFn
1260 // with a call to the OpenMP host runtime function (__kmpc_fork_call[_if]).
1261 static void
1262 hostParallelCallback(OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn,
1263                      Function *OuterFn, Value *Ident, Value *IfCondition,
1264                      Instruction *PrivTID, AllocaInst *PrivTIDAddr,
1265                      const SmallVector<Instruction *, 4> &ToBeDeleted) {
1266   IRBuilder<> &Builder = OMPIRBuilder->Builder;
1267   FunctionCallee RTLFn;
1268   if (IfCondition) {
1269     RTLFn =
1270         OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call_if);
1271   } else {
1272     RTLFn =
1273         OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call);
1274   }
1275   if (auto *F = dyn_cast<Function>(RTLFn.getCallee())) {
1276     if (!F->hasMetadata(LLVMContext::MD_callback)) {
1277       LLVMContext &Ctx = F->getContext();
1278       MDBuilder MDB(Ctx);
1279       // Annotate the callback behavior of the __kmpc_fork_call:
1280       //  - The callback callee is argument number 2 (microtask).
1281       //  - The first two arguments of the callback callee are unknown (-1).
1282       //  - All variadic arguments to the __kmpc_fork_call are passed to the
1283       //    callback callee.
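      //  The resulting encoding is, e.g., !{i64 2, i64 -1, i64 -1, i1 true}.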
1284       F->addMetadata(LLVMContext::MD_callback,
1285                      *MDNode::get(Ctx, {MDB.createCallbackEncoding(
1286                                            2, {-1, -1},
1287                                            /* VarArgsArePassed */ true)}));
1288     }
1289   }
1290   // Add some known attributes.
1291   OutlinedFn.addParamAttr(0, Attribute::NoAlias);
1292   OutlinedFn.addParamAttr(1, Attribute::NoAlias);
1293   OutlinedFn.addFnAttr(Attribute::NoUnwind);
1294 
1295   assert(OutlinedFn.arg_size() >= 2 &&
1296          "Expected at least tid and bounded tid as arguments");
1297   unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1298 
1299   CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
1300   CI->getParent()->setName("omp_parallel");
1301   Builder.SetInsertPoint(CI);
1302 
1303   // Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
1304   Value *ForkCallArgs[] = {
1305       Ident, Builder.getInt32(NumCapturedVars),
1306       Builder.CreateBitCast(&OutlinedFn, OMPIRBuilder->ParallelTaskPtr)};
1307 
1308   SmallVector<Value *, 16> RealArgs;
1309   RealArgs.append(std::begin(ForkCallArgs), std::end(ForkCallArgs));
1310   if (IfCondition) {
1311     Value *Cond = Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32);
1312     RealArgs.push_back(Cond);
1313   }
1314   RealArgs.append(CI->arg_begin() + /* tid & bound tid */ 2, CI->arg_end());
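  // At this point RealArgs holds (Ident, NumCapturedVars, microtask[, cond],
  // captured vars...), matching the variadic __kmpc_fork_call[_if] interface.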
1315 
1316   // __kmpc_fork_call_if always expects a void ptr as the last argument.
1317   // If there are no arguments, pass a null pointer.
1318   auto PtrTy = OMPIRBuilder->VoidPtr;
1319   if (IfCondition && NumCapturedVars == 0) {
1320     Value *NullPtrValue = Constant::getNullValue(PtrTy);
1321     RealArgs.push_back(NullPtrValue);
1322   }
1323   if (IfCondition && RealArgs.back()->getType() != PtrTy)
1324     RealArgs.back() = Builder.CreateBitCast(RealArgs.back(), PtrTy);
1325 
1326   Builder.CreateCall(RTLFn, RealArgs);
1327 
1328   LLVM_DEBUG(dbgs() << "With fork_call placed: "
1329                     << *Builder.GetInsertBlock()->getParent() << "\n");
1330 
1331   // Initialize the local TID stack location with the argument value.
1332   Builder.SetInsertPoint(PrivTID);
1333   Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
1334   Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
1335                       PrivTIDAddr);
1336 
1337   // Remove redundant call to the outlined function.
1338   CI->eraseFromParent();
1339 
1340   for (Instruction *I : ToBeDeleted) {
1341     I->eraseFromParent();
1342   }
1343 }
1344 
1345 IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
1346     const LocationDescription &Loc, InsertPointTy OuterAllocaIP,
1347     BodyGenCallbackTy BodyGenCB, PrivatizeCallbackTy PrivCB,
1348     FinalizeCallbackTy FiniCB, Value *IfCondition, Value *NumThreads,
1349     omp::ProcBindKind ProcBind, bool IsCancellable) {
1350   assert(!isConflictIP(Loc.IP, OuterAllocaIP) && "IPs must not be ambiguous");
1351 
1352   if (!updateToLocation(Loc))
1353     return Loc.IP;
1354 
1355   uint32_t SrcLocStrSize;
1356   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1357   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1358   Value *ThreadID = getOrCreateThreadID(Ident);
1359   // If we generate code for the target device, we need to allocate the
1360   // struct for aggregate params in the device default alloca address space.
1361   // The OpenMP runtime requires that the params of the extracted functions
1362   // are passed as zero address space pointers. This flag ensures that
1363   // extracted function arguments are declared in the zero address space.
1364   bool ArgsInZeroAddressSpace = Config.isTargetDevice();
1365 
1366   // Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
1367   // only if we compile for host side.
1368   if (NumThreads && !Config.isTargetDevice()) {
1369     Value *Args[] = {
1370         Ident, ThreadID,
1371         Builder.CreateIntCast(NumThreads, Int32, /*isSigned*/ false)};
1372     Builder.CreateCall(
1373         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_threads), Args);
1374   }
1375 
1376   if (ProcBind != OMP_PROC_BIND_default) {
1377     // Build call __kmpc_push_proc_bind(&Ident, global_tid, proc_bind)
1378     Value *Args[] = {
1379         Ident, ThreadID,
1380         ConstantInt::get(Int32, unsigned(ProcBind), /*isSigned=*/true)};
1381     Builder.CreateCall(
1382         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_proc_bind), Args);
1383   }
1384 
1385   BasicBlock *InsertBB = Builder.GetInsertBlock();
1386   Function *OuterFn = InsertBB->getParent();
1387 
1388   // Save the outer alloca block because the insertion iterator may get
1389   // invalidated and we still need this later.
1390   BasicBlock *OuterAllocaBlock = OuterAllocaIP.getBlock();
1391 
1392   // Vector to remember instructions we used only during the modeling but which
1393   // we want to delete at the end.
1394   SmallVector<Instruction *, 4> ToBeDeleted;
1395 
1396   // Change the location to the outer alloca insertion point to create and
1397   // initialize the allocas we pass into the parallel region.
1398   InsertPointTy NewOuter(OuterAllocaBlock, OuterAllocaBlock->begin());
1399   Builder.restoreIP(NewOuter);
1400   AllocaInst *TIDAddrAlloca = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
1401   AllocaInst *ZeroAddrAlloca =
1402       Builder.CreateAlloca(Int32, nullptr, "zero.addr");
1403   Instruction *TIDAddr = TIDAddrAlloca;
1404   Instruction *ZeroAddr = ZeroAddrAlloca;
1405   if (ArgsInZeroAddressSpace && M.getDataLayout().getAllocaAddrSpace() != 0) {
1406     // Add additional casts to enforce pointers in zero address space
1407     TIDAddr = new AddrSpaceCastInst(
1408         TIDAddrAlloca, PointerType::get(M.getContext(), 0), "tid.addr.ascast");
1409     TIDAddr->insertAfter(TIDAddrAlloca);
1410     ToBeDeleted.push_back(TIDAddr);
1411     ZeroAddr = new AddrSpaceCastInst(ZeroAddrAlloca,
1412                                      PointerType::get(M.getContext(), 0),
1413                                      "zero.addr.ascast");
1414     ZeroAddr->insertAfter(ZeroAddrAlloca);
1415     ToBeDeleted.push_back(ZeroAddr);
1416   }
1417 
1418   // We only need TIDAddr and ZeroAddr for modeling purposes to get the
1419   // associated arguments in the outlined function, so we delete them later.
1420   ToBeDeleted.push_back(TIDAddrAlloca);
1421   ToBeDeleted.push_back(ZeroAddrAlloca);
1422 
1423   // Create an artificial insertion point that will also ensure the blocks we
1424   // are about to split are not degenerated.
1425   auto *UI = new UnreachableInst(Builder.getContext(), InsertBB);
1426 
1427   BasicBlock *EntryBB = UI->getParent();
1428   BasicBlock *PRegEntryBB = EntryBB->splitBasicBlock(UI, "omp.par.entry");
1429   BasicBlock *PRegBodyBB = PRegEntryBB->splitBasicBlock(UI, "omp.par.region");
1430   BasicBlock *PRegPreFiniBB =
1431       PRegBodyBB->splitBasicBlock(UI, "omp.par.pre_finalize");
1432   BasicBlock *PRegExitBB = PRegPreFiniBB->splitBasicBlock(UI, "omp.par.exit");
1433 
1434   auto FiniCBWrapper = [&](InsertPointTy IP) {
1435     // Hide "open-ended" blocks from the given FiniCB by setting the right jump
1436     // target to the region exit block.
1437     if (IP.getBlock()->end() == IP.getPoint()) {
1438       IRBuilder<>::InsertPointGuard IPG(Builder);
1439       Builder.restoreIP(IP);
1440       Instruction *I = Builder.CreateBr(PRegExitBB);
1441       IP = InsertPointTy(I->getParent(), I->getIterator());
1442     }
1443     assert(IP.getBlock()->getTerminator()->getNumSuccessors() == 1 &&
1444            IP.getBlock()->getTerminator()->getSuccessor(0) == PRegExitBB &&
1445            "Unexpected insertion point for finalization call!");
1446     return FiniCB(IP);
1447   };
1448 
1449   FinalizationStack.push_back({FiniCBWrapper, OMPD_parallel, IsCancellable});
1450 
1451   // Generate the privatization allocas in the block that will become the entry
1452   // of the outlined function.
1453   Builder.SetInsertPoint(PRegEntryBB->getTerminator());
1454   InsertPointTy InnerAllocaIP = Builder.saveIP();
1455 
1456   AllocaInst *PrivTIDAddr =
1457       Builder.CreateAlloca(Int32, nullptr, "tid.addr.local");
1458   Instruction *PrivTID = Builder.CreateLoad(Int32, PrivTIDAddr, "tid");
1459 
1460   // Add some fake uses for OpenMP provided arguments.
1461   ToBeDeleted.push_back(Builder.CreateLoad(Int32, TIDAddr, "tid.addr.use"));
1462   Instruction *ZeroAddrUse =
1463       Builder.CreateLoad(Int32, ZeroAddr, "zero.addr.use");
1464   ToBeDeleted.push_back(ZeroAddrUse);
1465 
1466   // EntryBB
1467   //   |
1468   //   V
1469   // PRegionEntryBB         <- Privatization allocas are placed here.
1470   //   |
1471   //   V
1472   // PRegionBodyBB          <- BodyGen is invoked here.
1473   //   |
1474   //   V
1475   // PRegPreFiniBB          <- The block we will start finalization from.
1476   //   |
1477   //   V
1478   // PRegionExitBB          <- A common exit to simplify block collection.
1479   //
1480 
1481   LLVM_DEBUG(dbgs() << "Before body codegen: " << *OuterFn << "\n");
1482 
1483   // Let the caller create the body.
1484   assert(BodyGenCB && "Expected body generation callback!");
1485   InsertPointTy CodeGenIP(PRegBodyBB, PRegBodyBB->begin());
1486   BodyGenCB(InnerAllocaIP, CodeGenIP);
1487 
1488   LLVM_DEBUG(dbgs() << "After  body codegen: " << *OuterFn << "\n");
1489 
1490   OutlineInfo OI;
1491   if (Config.isTargetDevice()) {
1492     // Generate OpenMP target specific runtime call
1493     OI.PostOutlineCB = [=, ToBeDeletedVec =
1494                                std::move(ToBeDeleted)](Function &OutlinedFn) {
1495       targetParallelCallback(this, OutlinedFn, OuterFn, OuterAllocaBlock, Ident,
1496                              IfCondition, NumThreads, PrivTID, PrivTIDAddr,
1497                              ThreadID, ToBeDeletedVec);
1498     };
1499   } else {
1500     // Generate OpenMP host runtime call
1501     OI.PostOutlineCB = [=, ToBeDeletedVec =
1502                                std::move(ToBeDeleted)](Function &OutlinedFn) {
1503       hostParallelCallback(this, OutlinedFn, OuterFn, Ident, IfCondition,
1504                            PrivTID, PrivTIDAddr, ToBeDeletedVec);
1505     };
1506   }
1507 
1508   OI.OuterAllocaBB = OuterAllocaBlock;
1509   OI.EntryBB = PRegEntryBB;
1510   OI.ExitBB = PRegExitBB;
1511 
1512   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
1513   SmallVector<BasicBlock *, 32> Blocks;
1514   OI.collectBlocks(ParallelRegionBlockSet, Blocks);
1515 
1516   // Ensure a single exit node for the outlined region by creating one.
1517   // We might have multiple incoming edges to the exit now due to finalizations,
1518   // e.g., cancel calls that cause the control flow to leave the region.
1519   BasicBlock *PRegOutlinedExitBB = PRegExitBB;
1520   PRegExitBB = SplitBlock(PRegExitBB, &*PRegExitBB->getFirstInsertionPt());
1521   PRegOutlinedExitBB->setName("omp.par.outlined.exit");
1522   Blocks.push_back(PRegOutlinedExitBB);
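  // The renamed block stays inside the outlined region as its single exit,
  // while the block returned by SplitBlock remains in the caller.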
1523 
1524   CodeExtractorAnalysisCache CEAC(*OuterFn);
1525   CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
1526                           /* AggregateArgs */ false,
1527                           /* BlockFrequencyInfo */ nullptr,
1528                           /* BranchProbabilityInfo */ nullptr,
1529                           /* AssumptionCache */ nullptr,
1530                           /* AllowVarArgs */ true,
1531                           /* AllowAlloca */ true,
1532                           /* AllocationBlock */ OuterAllocaBlock,
1533                           /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
1534 
1535   // Find inputs to and outputs from the code region.
1536   BasicBlock *CommonExit = nullptr;
1537   SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
1538   Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
1539   Extractor.findInputsOutputs(Inputs, Outputs, SinkingCands);
1540 
1541   LLVM_DEBUG(dbgs() << "Before privatization: " << *OuterFn << "\n");
1542 
1543   FunctionCallee TIDRTLFn =
1544       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num);
1545 
1546   auto PrivHelper = [&](Value &V) {
1547     if (&V == TIDAddr || &V == ZeroAddr) {
1548       OI.ExcludeArgsFromAggregate.push_back(&V);
1549       return;
1550     }
1551 
1552     SetVector<Use *> Uses;
1553     for (Use &U : V.uses())
1554       if (auto *UserI = dyn_cast<Instruction>(U.getUser()))
1555         if (ParallelRegionBlockSet.count(UserI->getParent()))
1556           Uses.insert(&U);
1557 
1558     // __kmpc_fork_call expects extra arguments as pointers. If the input
1559     // already has a pointer type, everything is fine. Otherwise, store the
1560     // value onto the stack and load it back inside the to-be-outlined region.
1561     // This ensures that only the pointer is passed to the function.
1562     // FIXME: if there are more than 15 trailing arguments, they must be
1563     // additionally packed in a struct.
1564     Value *Inner = &V;
1565     if (!V.getType()->isPointerTy()) {
1566       IRBuilder<>::InsertPointGuard Guard(Builder);
1567       LLVM_DEBUG(llvm::dbgs() << "Forwarding input as pointer: " << V << "\n");
1568 
1569       Builder.restoreIP(OuterAllocaIP);
1570       Value *Ptr =
1571           Builder.CreateAlloca(V.getType(), nullptr, V.getName() + ".reloaded");
1572 
1573       // Store to stack at end of the block that currently branches to the entry
1574       // block of the to-be-outlined region.
1575       Builder.SetInsertPoint(InsertBB,
1576                              InsertBB->getTerminator()->getIterator());
1577       Builder.CreateStore(&V, Ptr);
1578 
1579       // Load back next to allocations in the to-be-outlined region.
1580       Builder.restoreIP(InnerAllocaIP);
1581       Inner = Builder.CreateLoad(V.getType(), Ptr);
1582     }
1583 
1584     Value *ReplacementValue = nullptr;
1585     CallInst *CI = dyn_cast<CallInst>(&V);
1586     if (CI && CI->getCalledFunction() == TIDRTLFn.getCallee()) {
1587       ReplacementValue = PrivTID;
1588     } else {
1589       Builder.restoreIP(
1590           PrivCB(InnerAllocaIP, Builder.saveIP(), V, *Inner, ReplacementValue));
1591       InnerAllocaIP = {
1592           InnerAllocaIP.getBlock(),
1593           InnerAllocaIP.getBlock()->getTerminator()->getIterator()};
1594 
1595       assert(ReplacementValue &&
1596              "Expected copy/create callback to set replacement value!");
1597       if (ReplacementValue == &V)
1598         return;
1599     }
1600 
1601     for (Use *UPtr : Uses)
1602       UPtr->set(ReplacementValue);
1603   };
1604 
1605   // Reset the inner alloca insertion as it will be used for loading the values
1606   // wrapped into pointers before passing them into the to-be-outlined region.
1607   // Configure it to insert immediately after the fake use of zero address so
1608   // that they are available in the generated body and so that the
1609   // OpenMP-related values (thread ID and zero address pointers) remain leading
1610   // in the argument list.
1611   InnerAllocaIP = IRBuilder<>::InsertPoint(
1612       ZeroAddrUse->getParent(), ZeroAddrUse->getNextNode()->getIterator());
1613 
1614   // Reset the outer alloca insertion point to the entry of the relevant block
1615   // in case it was invalidated.
1616   OuterAllocaIP = IRBuilder<>::InsertPoint(
1617       OuterAllocaBlock, OuterAllocaBlock->getFirstInsertionPt());
1618 
1619   for (Value *Input : Inputs) {
1620     LLVM_DEBUG(dbgs() << "Captured input: " << *Input << "\n");
1621     PrivHelper(*Input);
1622   }
1623   LLVM_DEBUG({
1624     for (Value *Output : Outputs)
1625       LLVM_DEBUG(dbgs() << "Captured output: " << *Output << "\n");
1626   });
1627   assert(Outputs.empty() &&
1628          "OpenMP outlining should not produce live-out values!");
1629 
1630   LLVM_DEBUG(dbgs() << "After  privatization: " << *OuterFn << "\n");
1631   LLVM_DEBUG({
1632     for (auto *BB : Blocks)
1633       dbgs() << " PBR: " << BB->getName() << "\n";
1634   });
1635 
1636   // Adjust the finalization stack, verify the adjustment, and call the
1637   // finalize function one last time to finalize values between the pre-fini
1638   // block and the exit block if we left the parallel region "the normal way".
1639   auto FiniInfo = FinalizationStack.pop_back_val();
1640   (void)FiniInfo;
1641   assert(FiniInfo.DK == OMPD_parallel &&
1642          "Unexpected finalization stack state!");
1643 
1644   Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator();
1645 
1646   InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
1647   FiniCB(PreFiniIP);
1648 
1649   // Register the outlined info.
1650   addOutlineInfo(std::move(OI));
1651 
1652   InsertPointTy AfterIP(UI->getParent(), UI->getParent()->end());
1653   UI->eraseFromParent();
1654 
1655   return AfterIP;
1656 }
1657 
1658 void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) {
1659   // Build call void __kmpc_flush(ident_t *loc)
1660   uint32_t SrcLocStrSize;
1661   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1662   Value *Args[] = {getOrCreateIdent(SrcLocStr, SrcLocStrSize)};
1663 
1664   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_flush), Args);
1665 }
1666 
1667 void OpenMPIRBuilder::createFlush(const LocationDescription &Loc) {
1668   if (!updateToLocation(Loc))
1669     return;
1670   emitFlush(Loc);
1671 }
1672 
1673 void OpenMPIRBuilder::emitTaskwaitImpl(const LocationDescription &Loc) {
1674   // Build call kmp_int32 __kmpc_omp_taskwait(ident_t *loc, kmp_int32
1675   // global_tid);
1676   uint32_t SrcLocStrSize;
1677   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1678   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1679   Value *Args[] = {Ident, getOrCreateThreadID(Ident)};
1680 
1681   // Ignore return result until untied tasks are supported.
1682   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_taskwait),
1683                      Args);
1684 }
1685 
1686 void OpenMPIRBuilder::createTaskwait(const LocationDescription &Loc) {
1687   if (!updateToLocation(Loc))
1688     return;
1689   emitTaskwaitImpl(Loc);
1690 }
1691 
1692 void OpenMPIRBuilder::emitTaskyieldImpl(const LocationDescription &Loc) {
1693   // Build call __kmpc_omp_taskyield(loc, thread_id, 0);
1694   uint32_t SrcLocStrSize;
1695   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1696   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1697   Constant *I32Null = ConstantInt::getNullValue(Int32);
1698   Value *Args[] = {Ident, getOrCreateThreadID(Ident), I32Null};
1699 
1700   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_taskyield),
1701                      Args);
1702 }
1703 
1704 void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
1705   if (!updateToLocation(Loc))
1706     return;
1707   emitTaskyieldImpl(Loc);
1708 }
1709 
1710 // Processes the dependencies in Dependencies and does the following
1711 // - Allocates stack space for an array of DependInfo objects
1712 // - Populates each DependInfo object with the relevant information of
1713 //   the corresponding dependence.
1714 // - All code is inserted in the entry block of the current function.
1715 static Value *emitTaskDependencies(
1716     OpenMPIRBuilder &OMPBuilder,
1717     SmallVectorImpl<OpenMPIRBuilder::DependData> &Dependencies) {
1718   // Early return if we have no dependencies to process
1719   if (Dependencies.empty())
1720     return nullptr;
1721 
1722   // Given a vector of DependData objects, in this function we create an
1723   // array on the stack that holds kmp_depend_info objects corresponding
1724   // to each dependency. This is then passed to the OpenMP runtime.
1725   // For example, if there are 'n' dependencies then the following pseudo
1726   // code is generated. Assume the first dependence is on a variable 'a'.
1727   //
1728   // \code{c}
1729   // DepArray = alloc(n x sizeof(kmp_depend_info));
1730   // idx = 0;
1731   // DepArray[idx].base_addr = ptrtoint(&a);
1732   // DepArray[idx].len = 8;
1733   // DepArray[idx].flags = Dep.DepKind; /*(See OMPContants.h for DepKind)*/
1734   // ++idx;
1735   // DepArray[idx].base_addr = ...;
1736   // \endcode
1737 
1738   IRBuilderBase &Builder = OMPBuilder.Builder;
1739   Type *DependInfo = OMPBuilder.DependInfo;
1740   Module &M = OMPBuilder.M;
1741 
1742   Value *DepArray = nullptr;
1743   OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
1744   Builder.SetInsertPoint(
1745       OldIP.getBlock()->getParent()->getEntryBlock().getTerminator());
1746 
1747   Type *DepArrayTy = ArrayType::get(DependInfo, Dependencies.size());
1748   DepArray = Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");
1749 
1750   for (const auto &[DepIdx, Dep] : enumerate(Dependencies)) {
1751     Value *Base =
1752         Builder.CreateConstInBoundsGEP2_64(DepArrayTy, DepArray, 0, DepIdx);
1753     // Store the pointer to the variable
1754     Value *Addr = Builder.CreateStructGEP(
1755         DependInfo, Base,
1756         static_cast<unsigned int>(RTLDependInfoFields::BaseAddr));
1757     Value *DepValPtr = Builder.CreatePtrToInt(Dep.DepVal, Builder.getInt64Ty());
1758     Builder.CreateStore(DepValPtr, Addr);
1759     // Store the size of the variable
1760     Value *Size = Builder.CreateStructGEP(
1761         DependInfo, Base, static_cast<unsigned int>(RTLDependInfoFields::Len));
1762     Builder.CreateStore(
1763         Builder.getInt64(M.getDataLayout().getTypeStoreSize(Dep.DepValueType)),
1764         Size);
1765     // Store the dependency kind
1766     Value *Flags = Builder.CreateStructGEP(
1767         DependInfo, Base,
1768         static_cast<unsigned int>(RTLDependInfoFields::Flags));
1769     Builder.CreateStore(
1770         ConstantInt::get(Builder.getInt8Ty(),
1771                          static_cast<unsigned int>(Dep.DepKind)),
1772         Flags);
1773   }
1774   Builder.restoreIP(OldIP);
1775   return DepArray;
1776 }
1777 
1778 OpenMPIRBuilder::InsertPointTy
1779 OpenMPIRBuilder::createTask(const LocationDescription &Loc,
1780                             InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
1781                             bool Tied, Value *Final, Value *IfCondition,
1782                             SmallVector<DependData> Dependencies) {
1783 
1784   if (!updateToLocation(Loc))
1785     return InsertPointTy();
1786 
1787   uint32_t SrcLocStrSize;
1788   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1789   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1790   // The current basic block is split into four basic blocks. After outlining,
1791   // they will be mapped as follows:
1792   // ```
1793   // def current_fn() {
1794   //   current_basic_block:
1795   //     br label %task.exit
1796   //   task.exit:
1797   //     ; instructions after task
1798   // }
1799   // def outlined_fn() {
1800   //   task.alloca:
1801   //     br label %task.body
1802   //   task.body:
1803   //     ret void
1804   // }
1805   // ```
1806   BasicBlock *TaskExitBB = splitBB(Builder, /*CreateBranch=*/true, "task.exit");
1807   BasicBlock *TaskBodyBB = splitBB(Builder, /*CreateBranch=*/true, "task.body");
1808   BasicBlock *TaskAllocaBB =
1809       splitBB(Builder, /*CreateBranch=*/true, "task.alloca");
1810 
1811   InsertPointTy TaskAllocaIP =
1812       InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
1813   InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
1814   BodyGenCB(TaskAllocaIP, TaskBodyIP);
1815 
1816   OutlineInfo OI;
1817   OI.EntryBB = TaskAllocaBB;
1818   OI.OuterAllocaBB = AllocaIP.getBlock();
1819   OI.ExitBB = TaskExitBB;
1820 
1821   // Add the thread ID argument.
1822   SmallVector<Instruction *, 4> ToBeDeleted;
1823   OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
1824       Builder, AllocaIP, ToBeDeleted, TaskAllocaIP, "global.tid", false));
1825 
1826   OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
1827                       TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) mutable {
1828     // Replace the stale CI with the appropriate RTL function call.
1829     assert(OutlinedFn.getNumUses() == 1 &&
1830            "there must be a single user for the outlined function");
1831     CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
1832 
1833     // HasShareds is true if any variables are captured in the outlined region,
1834     // false otherwise.
1835     bool HasShareds = StaleCI->arg_size() > 1;
1836     Builder.SetInsertPoint(StaleCI);
1837 
1838     // Gather the arguments for emitting the runtime call for
1839     // @__kmpc_omp_task_alloc
1840     Function *TaskAllocFn =
1841         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);
1842 
1843     // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID) for the runtime
1844     // call.
1845     Value *ThreadID = getOrCreateThreadID(Ident);
1846 
1847     // Argument - `flags`
1848     // Task is tied iff (Flags & 1) == 1.
1849     // Task is untied iff (Flags & 1) == 0.
1850     // Task is final iff (Flags & 2) == 2.
1851     // Task is not final iff (Flags & 2) == 0.
1852     // TODO: Handle the other flags.
1853     Value *Flags = Builder.getInt32(Tied);
1854     if (Final) {
1855       Value *FinalFlag =
1856           Builder.CreateSelect(Final, Builder.getInt32(2), Builder.getInt32(0));
1857       Flags = Builder.CreateOr(FinalFlag, Flags);
1858     }
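    // E.g., a tied task whose `final` clause evaluates to true ends up with
    // Flags == 3 (tied bit | final bit).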
1859 
1860     // Argument - `sizeof_kmp_task_t` (TaskSize)
1861     // TaskSize refers to the size in bytes of the kmp_task_t data structure
1862     // including private vars accessed in the task.
1863     // TODO: add kmp_task_t_with_privates (privates)
1864     Value *TaskSize = Builder.getInt64(
1865         divideCeil(M.getDataLayout().getTypeSizeInBits(Task), 8));
1866 
1867     // Argument - `sizeof_shareds` (SharedsSize)
1868     // SharedsSize refers to the shareds array size in the kmp_task_t data
1869     // structure.
1870     Value *SharedsSize = Builder.getInt64(0);
1871     if (HasShareds) {
1872       AllocaInst *ArgStructAlloca =
1873           dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
1874       assert(ArgStructAlloca &&
1875              "Unable to find the alloca instruction corresponding to arguments "
1876              "for extracted function");
1877       StructType *ArgStructType =
1878           dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
1879       assert(ArgStructType && "Unable to find struct type corresponding to "
1880                               "arguments for extracted function");
1881       SharedsSize =
1882           Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
1883     }
1884     // Emit the @__kmpc_omp_task_alloc runtime call
1885     // The runtime call returns a pointer to an area where the task captured
1886     // variables must be copied before the task is run (TaskData)
1887     CallInst *TaskData = Builder.CreateCall(
1888         TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
1889                       /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
1890                       /*task_func=*/&OutlinedFn});
1891 
1892     // Copy the arguments for outlined function
1893     if (HasShareds) {
1894       Value *Shareds = StaleCI->getArgOperand(1);
1895       Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
1896       Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData);
1897       Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
1898                            SharedsSize);
1899     }
1900 
1901     Value *DepArray = nullptr;
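    // Emit the dependency array in the function entry block; this mirrors
    // the standalone emitTaskDependencies() helper above.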
1902     if (Dependencies.size()) {
1903       InsertPointTy OldIP = Builder.saveIP();
1904       Builder.SetInsertPoint(
1905           &OldIP.getBlock()->getParent()->getEntryBlock().back());
1906 
1907       Type *DepArrayTy = ArrayType::get(DependInfo, Dependencies.size());
1908       DepArray = Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");
1909 
1910       unsigned P = 0;
1911       for (const DependData &Dep : Dependencies) {
1912         Value *Base =
1913             Builder.CreateConstInBoundsGEP2_64(DepArrayTy, DepArray, 0, P);
1914         // Store the pointer to the variable
1915         Value *Addr = Builder.CreateStructGEP(
1916             DependInfo, Base,
1917             static_cast<unsigned int>(RTLDependInfoFields::BaseAddr));
1918         Value *DepValPtr =
1919             Builder.CreatePtrToInt(Dep.DepVal, Builder.getInt64Ty());
1920         Builder.CreateStore(DepValPtr, Addr);
1921         // Store the size of the variable
1922         Value *Size = Builder.CreateStructGEP(
1923             DependInfo, Base,
1924             static_cast<unsigned int>(RTLDependInfoFields::Len));
1925         Builder.CreateStore(Builder.getInt64(M.getDataLayout().getTypeStoreSize(
1926                                 Dep.DepValueType)),
1927                             Size);
1928         // Store the dependency kind
1929         Value *Flags = Builder.CreateStructGEP(
1930             DependInfo, Base,
1931             static_cast<unsigned int>(RTLDependInfoFields::Flags));
1932         Builder.CreateStore(
1933             ConstantInt::get(Builder.getInt8Ty(),
1934                              static_cast<unsigned int>(Dep.DepKind)),
1935             Flags);
1936         ++P;
1937       }
1938 
1939       Builder.restoreIP(OldIP);
1940     }
1941 
1942     // In the presence of the `if` clause, the following IR is generated:
1943     //    ...
1944     //    %data = call @__kmpc_omp_task_alloc(...)
1945     //    br i1 %if_condition, label %then, label %else
1946     //  then:
1947     //    call @__kmpc_omp_task(...)
1948     //    br label %exit
1949     //  else:
1950     //    ;; Wait for resolution of dependencies, if any, before
1951     //    ;; beginning the task
1952     //    call @__kmpc_omp_wait_deps(...)
1953     //    call @__kmpc_omp_task_begin_if0(...)
1954     //    call @outlined_fn(...)
1955     //    call @__kmpc_omp_task_complete_if0(...)
1956     //    br label %exit
1957     //  exit:
1958     //    ...
1959     if (IfCondition) {
1960       // `SplitBlockAndInsertIfThenElse` requires the block to have a
1961       // terminator.
1962       splitBB(Builder, /*CreateBranch=*/true, "if.end");
1963       Instruction *IfTerminator =
1964           Builder.GetInsertPoint()->getParent()->getTerminator();
1965       Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
1966       Builder.SetInsertPoint(IfTerminator);
1967       SplitBlockAndInsertIfThenElse(IfCondition, IfTerminator, &ThenTI,
1968                                     &ElseTI);
1969       Builder.SetInsertPoint(ElseTI);
1970 
1971       if (Dependencies.size()) {
1972         Function *TaskWaitFn =
1973             getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_wait_deps);
1974         Builder.CreateCall(
1975             TaskWaitFn,
1976             {Ident, ThreadID, Builder.getInt32(Dependencies.size()), DepArray,
1977              ConstantInt::get(Builder.getInt32Ty(), 0),
1978              ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
1979       }
1980       Function *TaskBeginFn =
1981           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0);
1982       Function *TaskCompleteFn =
1983           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
1984       Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
1985       CallInst *CI = nullptr;
1986       if (HasShareds)
1987         CI = Builder.CreateCall(&OutlinedFn, {ThreadID, TaskData});
1988       else
1989         CI = Builder.CreateCall(&OutlinedFn, {ThreadID});
1990       CI->setDebugLoc(StaleCI->getDebugLoc());
1991       Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
1992       Builder.SetInsertPoint(ThenTI);
1993     }
1994 
1995     if (Dependencies.size()) {
1996       Function *TaskFn =
1997           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
1998       Builder.CreateCall(
1999           TaskFn,
2000           {Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()),
2001            DepArray, ConstantInt::get(Builder.getInt32Ty(), 0),
2002            ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
2003 
2004     } else {
2005       // Emit the @__kmpc_omp_task runtime call to spawn the task
2006       Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
2007       Builder.CreateCall(TaskFn, {Ident, ThreadID, TaskData});
2008     }
2009 
2010     StaleCI->eraseFromParent();
2011 
2012     Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin());
2013     if (HasShareds) {
2014       LoadInst *Shareds = Builder.CreateLoad(VoidPtr, OutlinedFn.getArg(1));
2015       OutlinedFn.getArg(1)->replaceUsesWithIf(
2016           Shareds, [Shareds](Use &U) { return U.getUser() != Shareds; });
2017     }
2018 
2019     llvm::for_each(llvm::reverse(ToBeDeleted),
2020                    [](Instruction *I) { I->eraseFromParent(); });
2021   };
2022 
2023   addOutlineInfo(std::move(OI));
2024   Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin());
2025 
2026   return Builder.saveIP();
2027 }
2028 
2029 OpenMPIRBuilder::InsertPointTy
2030 OpenMPIRBuilder::createTaskgroup(const LocationDescription &Loc,
2031                                  InsertPointTy AllocaIP,
2032                                  BodyGenCallbackTy BodyGenCB) {
2033   if (!updateToLocation(Loc))
2034     return InsertPointTy();
2035 
2036   uint32_t SrcLocStrSize;
2037   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2038   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2039   Value *ThreadID = getOrCreateThreadID(Ident);
2040 
2041   // Emit the @__kmpc_taskgroup runtime call to start the taskgroup
2042   Function *TaskgroupFn =
2043       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_taskgroup);
2044   Builder.CreateCall(TaskgroupFn, {Ident, ThreadID});
2045 
2046   BasicBlock *TaskgroupExitBB = splitBB(Builder, true, "taskgroup.exit");
2047   BodyGenCB(AllocaIP, Builder.saveIP());
2048 
2049   Builder.SetInsertPoint(TaskgroupExitBB);
2050   // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup
2051   Function *EndTaskgroupFn =
2052       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_taskgroup);
2053   Builder.CreateCall(EndTaskgroupFn, {Ident, ThreadID});
2054 
2055   return Builder.saveIP();
2056 }
2057 
2058 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSections(
2059     const LocationDescription &Loc, InsertPointTy AllocaIP,
2060     ArrayRef<StorableBodyGenCallbackTy> SectionCBs, PrivatizeCallbackTy PrivCB,
2061     FinalizeCallbackTy FiniCB, bool IsCancellable, bool IsNowait) {
2062   assert(!isConflictIP(AllocaIP, Loc.IP) && "Dedicated IP allocas required");
2063 
2064   if (!updateToLocation(Loc))
2065     return Loc.IP;
2066 
2067   auto FiniCBWrapper = [&](InsertPointTy IP) {
2068     if (IP.getBlock()->end() != IP.getPoint())
2069       return FiniCB(IP);
2070     // This must be done, otherwise any nested constructs using
2071     // FinalizeOMPRegion will fail because that function requires the
2072     // finalization basic block to have a terminator, which is already
2073     // removed by EmitOMPRegionBody.
2074     // IP is currently at the cancellation block.
2075     // We need to backtrack to the condition block to fetch the exit block
2076     // and create a branch from the cancellation block to the exit block.
2077     IRBuilder<>::InsertPointGuard IPG(Builder);
2078     Builder.restoreIP(IP);
2079     auto *CaseBB = IP.getBlock()->getSinglePredecessor();
2080     auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
2081     auto *ExitBB = CondBB->getTerminator()->getSuccessor(1);
2082     Instruction *I = Builder.CreateBr(ExitBB);
2083     IP = InsertPointTy(I->getParent(), I->getIterator());
2084     return FiniCB(IP);
2085   };
2086 
2087   FinalizationStack.push_back({FiniCBWrapper, OMPD_sections, IsCancellable});
2088 
2089   // Each section is emitted as a switch case
2090   // Each finalization callback is handled from clang.EmitOMPSectionDirective()
2091   // -> OMP.createSection(), which generates the IR for each section.
2092   // Iterate through all sections and emit a switch construct:
2093   // switch (IV) {
2094   //   case 0:
2095   //     <SectionStmt[0]>;
2096   //     break;
2097   // ...
2098   //   case <NumSection> - 1:
2099   //     <SectionStmt[<NumSection> - 1]>;
2100   //     break;
2101   // }
2102   // ...
2103   // section_loop.after:
2104   // <FiniCB>;
2105   auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, Value *IndVar) {
2106     Builder.restoreIP(CodeGenIP);
2107     BasicBlock *Continue =
2108         splitBBWithSuffix(Builder, /*CreateBranch=*/false, ".sections.after");
2109     Function *CurFn = Continue->getParent();
2110     SwitchInst *SwitchStmt = Builder.CreateSwitch(IndVar, Continue);
2111 
2112     unsigned CaseNumber = 0;
2113     for (auto SectionCB : SectionCBs) {
2114       BasicBlock *CaseBB = BasicBlock::Create(
2115           M.getContext(), "omp_section_loop.body.case", CurFn, Continue);
2116       SwitchStmt->addCase(Builder.getInt32(CaseNumber), CaseBB);
2117       Builder.SetInsertPoint(CaseBB);
2118       BranchInst *CaseEndBr = Builder.CreateBr(Continue);
2119       SectionCB(InsertPointTy(),
2120                 {CaseEndBr->getParent(), CaseEndBr->getIterator()});
2121       CaseNumber++;
2122     }
2123     // remove the existing terminator from body BB since there can be no
2124     // terminators after switch/case
2125   };
2126   // Loop body ends here
2127   // LowerBound, UpperBound, and Stride for createCanonicalLoop
2128   Type *I32Ty = Type::getInt32Ty(M.getContext());
2129   Value *LB = ConstantInt::get(I32Ty, 0);
2130   Value *UB = ConstantInt::get(I32Ty, SectionCBs.size());
2131   Value *ST = ConstantInt::get(I32Ty, 1);
2132   llvm::CanonicalLoopInfo *LoopInfo = createCanonicalLoop(
2133       Loc, LoopBodyGenCB, LB, UB, ST, true, false, AllocaIP, "section_loop");
2134   InsertPointTy AfterIP =
2135       applyStaticWorkshareLoop(Loc.DL, LoopInfo, AllocaIP, !IsNowait);
2136 
2137   // Apply the finalization callback in LoopAfterBB
2138   auto FiniInfo = FinalizationStack.pop_back_val();
2139   assert(FiniInfo.DK == OMPD_sections &&
2140          "Unexpected finalization stack state!");
2141   if (FinalizeCallbackTy &CB = FiniInfo.FiniCB) {
2142     Builder.restoreIP(AfterIP);
2143     BasicBlock *FiniBB =
2144         splitBBWithSuffix(Builder, /*CreateBranch=*/true, "sections.fini");
2145     CB(Builder.saveIP());
2146     AfterIP = {FiniBB, FiniBB->begin()};
2147   }
2148 
2149   return AfterIP;
2150 }
2151 
2152 OpenMPIRBuilder::InsertPointTy
2153 OpenMPIRBuilder::createSection(const LocationDescription &Loc,
2154                                BodyGenCallbackTy BodyGenCB,
2155                                FinalizeCallbackTy FiniCB) {
2156   if (!updateToLocation(Loc))
2157     return Loc.IP;
2158 
2159   auto FiniCBWrapper = [&](InsertPointTy IP) {
2160     if (IP.getBlock()->end() != IP.getPoint())
2161       return FiniCB(IP);
2162     // This must be done, otherwise any nested constructs using
2163     // FinalizeOMPRegion will fail because that function requires the
2164     // finalization basic block to have a terminator, which is already
2165     // removed by EmitOMPRegionBody.
2166     // IP is currently at the cancellation block.
2167     // We need to backtrack to the condition block to fetch the exit block
2168     // and create a branch from the cancellation block to the exit block.
2169     IRBuilder<>::InsertPointGuard IPG(Builder);
2170     Builder.restoreIP(IP);
2171     auto *CaseBB = Loc.IP.getBlock();
2172     auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
2173     auto *ExitBB = CondBB->getTerminator()->getSuccessor(1);
2174     Instruction *I = Builder.CreateBr(ExitBB);
2175     IP = InsertPointTy(I->getParent(), I->getIterator());
2176     return FiniCB(IP);
2177   };
2178 
2179   Directive OMPD = Directive::OMPD_sections;
2180   // Since we are using a finalization callback here, HasFinalize
2181   // and IsCancellable have to be true.
2182   return EmitOMPInlinedRegion(OMPD, nullptr, nullptr, BodyGenCB, FiniCBWrapper,
2183                               /*Conditional*/ false, /*hasFinalize*/ true,
2184                               /*IsCancellable*/ true);
2185 }
2186 
2187 static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
2188   BasicBlock::iterator IT(I);
2189   IT++;
2190   return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT);
2191 }
2192 
2193 void OpenMPIRBuilder::emitUsed(StringRef Name,
2194                                std::vector<WeakTrackingVH> &List) {
2195   if (List.empty())
2196     return;
2197 
2198   // Convert List to what ConstantArray needs.
2199   SmallVector<Constant *, 8> UsedArray;
2200   UsedArray.resize(List.size());
2201   for (unsigned I = 0, E = List.size(); I != E; ++I)
2202     UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2203         cast<Constant>(&*List[I]), Builder.getPtrTy());
2204 
2205   if (UsedArray.empty())
2206     return;
2207   ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());
2208 
2209   auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
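  // The emitted global has the shape (name and contents illustrative):
  //   @name = appending global [N x ptr] [...], section "llvm.metadata"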
2210                                 ConstantArray::get(ATy, UsedArray), Name);
2211 
2212   GV->setSection("llvm.metadata");
2213 }
2214 
2215 Value *OpenMPIRBuilder::getGPUThreadID() {
2216   return Builder.CreateCall(
2217       getOrCreateRuntimeFunction(M,
2218                                  OMPRTL___kmpc_get_hardware_thread_id_in_block),
2219       {});
2220 }
2221 
2222 Value *OpenMPIRBuilder::getGPUWarpSize() {
2223   return Builder.CreateCall(
2224       getOrCreateRuntimeFunction(M, OMPRTL___kmpc_get_warp_size), {});
2225 }
2226 
2227 Value *OpenMPIRBuilder::getNVPTXWarpID() {
2228   unsigned LaneIDBits = Log2_32(Config.getGridValue().GV_Warp_Size);
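  // The warp ID is the hardware thread ID divided by the warp size, i.e. an
  // arithmetic shift right by log2(warp size), e.g. by 5 for 32-lane warps.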
2229   return Builder.CreateAShr(getGPUThreadID(), LaneIDBits, "nvptx_warp_id");
2230 }
2231 
2232 Value *OpenMPIRBuilder::getNVPTXLaneID() {
2233   unsigned LaneIDBits = Log2_32(Config.getGridValue().GV_Warp_Size);
2234   assert(LaneIDBits < 32 && "Invalid LaneIDBits size in NVPTX device.");
2235   unsigned LaneIDMask = ~0u >> (32u - LaneIDBits);
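  // E.g., for a 32-lane warp, LaneIDBits is 5 and LaneIDMask is 0x1f.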
2236   return Builder.CreateAnd(getGPUThreadID(), Builder.getInt32(LaneIDMask),
2237                            "nvptx_lane_id");
2238 }
2239 
2240 Value *OpenMPIRBuilder::castValueToType(InsertPointTy AllocaIP, Value *From,
2241                                         Type *ToType) {
2242   Type *FromType = From->getType();
2243   uint64_t FromSize = M.getDataLayout().getTypeStoreSize(FromType);
2244   uint64_t ToSize = M.getDataLayout().getTypeStoreSize(ToType);
2245   assert(FromSize > 0 && "From size must be greater than zero");
2246   assert(ToSize > 0 && "To size must be greater than zero");
2247   if (FromType == ToType)
2248     return From;
2249   if (FromSize == ToSize)
2250     return Builder.CreateBitCast(From, ToType);
2251   if (ToType->isIntegerTy() && FromType->isIntegerTy())
2252     return Builder.CreateIntCast(From, ToType, /*isSigned*/ true);
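  // Otherwise fall back to a memory round-trip: spill From into a stack slot
  // of the destination type and reload it, reinterpreting the bits in place.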
2253   InsertPointTy SaveIP = Builder.saveIP();
2254   Builder.restoreIP(AllocaIP);
2255   Value *CastItem = Builder.CreateAlloca(ToType);
2256   Builder.restoreIP(SaveIP);
2257 
2258   Value *ValCastItem = Builder.CreatePointerBitCastOrAddrSpaceCast(
2259       CastItem, FromType->getPointerTo());
2260   Builder.CreateStore(From, ValCastItem);
2261   return Builder.CreateLoad(ToType, CastItem);
2262 }
2263 
2264 Value *OpenMPIRBuilder::createRuntimeShuffleFunction(InsertPointTy AllocaIP,
2265                                                      Value *Element,
2266                                                      Type *ElementType,
2267                                                      Value *Offset) {
2268   uint64_t Size = M.getDataLayout().getTypeStoreSize(ElementType);
2269   assert(Size <= 8 && "Unsupported bitwidth in shuffle instruction");
2270 
2271   // Cast all types to 32- or 64-bit values before calling shuffle routines.
2272   Type *CastTy = Builder.getIntNTy(Size <= 4 ? 32 : 64);
2273   Value *ElemCast = castValueToType(AllocaIP, Element, CastTy);
2274   Value *WarpSize =
2275       Builder.CreateIntCast(getGPUWarpSize(), Builder.getInt16Ty(), true);
2276   Function *ShuffleFunc = getOrCreateRuntimeFunctionPtr(
2277       Size <= 4 ? RuntimeFunction::OMPRTL___kmpc_shuffle_int32
2278                 : RuntimeFunction::OMPRTL___kmpc_shuffle_int64);
2279   Value *WarpSizeCast =
2280       Builder.CreateIntCast(WarpSize, Builder.getInt16Ty(), /*isSigned=*/true);
2281   Value *ShuffleCall =
2282       Builder.CreateCall(ShuffleFunc, {ElemCast, Offset, WarpSizeCast});
2283   return castValueToType(AllocaIP, ShuffleCall, CastTy);
2284 }
2285 
2286 void OpenMPIRBuilder::shuffleAndStore(InsertPointTy AllocaIP, Value *SrcAddr,
2287                                       Value *DstAddr, Type *ElemType,
2288                                       Value *Offset, Type *ReductionArrayTy) {
2289   uint64_t Size = M.getDataLayout().getTypeStoreSize(ElemType);
2290   // Create the loop over the big sized data.
2291   // ptr = (void*)Elem;
2292   // ptrEnd = (void*) Elem + 1;
2293   // Step = 8;
2294   // while (ptr + Step < ptrEnd)
2295   //   shuffle((int64_t)*ptr);
2296   // Step = 4;
2297   // while (ptr + Step < ptrEnd)
2298   //   shuffle((int32_t)*ptr);
2299   // ...
2300   Type *IndexTy = Builder.getIndexTy(
2301       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
2302   Value *ElemPtr = DstAddr;
2303   Value *Ptr = SrcAddr;
2304   for (unsigned IntSize = 8; IntSize >= 1; IntSize /= 2) {
2305     if (Size < IntSize)
2306       continue;
2307     Type *IntType = Builder.getIntNTy(IntSize * 8);
2308     Ptr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2309         Ptr, IntType->getPointerTo(), Ptr->getName() + ".ascast");
2310     Value *SrcAddrGEP =
2311         Builder.CreateGEP(ElemType, SrcAddr, {ConstantInt::get(IndexTy, 1)});
2312     ElemPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2313         ElemPtr, IntType->getPointerTo(), ElemPtr->getName() + ".ascast");
2314 
2315     Function *CurFunc = Builder.GetInsertBlock()->getParent();
2316     if ((Size / IntSize) > 1) {
2317       Value *PtrEnd = Builder.CreatePointerBitCastOrAddrSpaceCast(
2318           SrcAddrGEP, Builder.getPtrTy());
2319       BasicBlock *PreCondBB =
2320           BasicBlock::Create(M.getContext(), ".shuffle.pre_cond");
2321       BasicBlock *ThenBB = BasicBlock::Create(M.getContext(), ".shuffle.then");
2322       BasicBlock *ExitBB = BasicBlock::Create(M.getContext(), ".shuffle.exit");
2323       BasicBlock *CurrentBB = Builder.GetInsertBlock();
2324       emitBlock(PreCondBB, CurFunc);
2325       PHINode *PhiSrc =
2326           Builder.CreatePHI(Ptr->getType(), /*NumReservedValues=*/2);
2327       PhiSrc->addIncoming(Ptr, CurrentBB);
2328       PHINode *PhiDest =
2329           Builder.CreatePHI(ElemPtr->getType(), /*NumReservedValues=*/2);
2330       PhiDest->addIncoming(ElemPtr, CurrentBB);
2331       Ptr = PhiSrc;
2332       ElemPtr = PhiDest;
2333       Value *PtrDiff = Builder.CreatePtrDiff(
2334           Builder.getInt8Ty(), PtrEnd,
2335           Builder.CreatePointerBitCastOrAddrSpaceCast(Ptr, Builder.getPtrTy()));
2336       Builder.CreateCondBr(
2337           Builder.CreateICmpSGT(PtrDiff, Builder.getInt64(IntSize - 1)), ThenBB,
2338           ExitBB);
2339       emitBlock(ThenBB, CurFunc);
2340       Value *Res = createRuntimeShuffleFunction(
2341           AllocaIP,
2342           Builder.CreateAlignedLoad(
2343               IntType, Ptr, M.getDataLayout().getPrefTypeAlign(ElemType)),
2344           IntType, Offset);
2345       Builder.CreateAlignedStore(Res, ElemPtr,
2346                                  M.getDataLayout().getPrefTypeAlign(ElemType));
2347       Value *LocalPtr =
2348           Builder.CreateGEP(IntType, Ptr, {ConstantInt::get(IndexTy, 1)});
2349       Value *LocalElemPtr =
2350           Builder.CreateGEP(IntType, ElemPtr, {ConstantInt::get(IndexTy, 1)});
2351       PhiSrc->addIncoming(LocalPtr, ThenBB);
2352       PhiDest->addIncoming(LocalElemPtr, ThenBB);
2353       emitBranch(PreCondBB);
2354       emitBlock(ExitBB, CurFunc);
2355     } else {
2356       Value *Res = createRuntimeShuffleFunction(
2357           AllocaIP, Builder.CreateLoad(IntType, Ptr), IntType, Offset);
2358       if (ElemType->isIntegerTy() && ElemType->getScalarSizeInBits() <
2359                                          Res->getType()->getScalarSizeInBits())
2360         Res = Builder.CreateTrunc(Res, ElemType);
2361       Builder.CreateStore(Res, ElemPtr);
2362       Ptr = Builder.CreateGEP(IntType, Ptr, {ConstantInt::get(IndexTy, 1)});
2363       ElemPtr =
2364           Builder.CreateGEP(IntType, ElemPtr, {ConstantInt::get(IndexTy, 1)});
2365     }
2366     Size = Size % IntSize;
2367   }
2368 }
2369 
2370 void OpenMPIRBuilder::emitReductionListCopy(
2371     InsertPointTy AllocaIP, CopyAction Action, Type *ReductionArrayTy,
2372     ArrayRef<ReductionInfo> ReductionInfos, Value *SrcBase, Value *DestBase,
2373     CopyOptionsTy CopyOptions) {
2374   Type *IndexTy = Builder.getIndexTy(
2375       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
2376   Value *RemoteLaneOffset = CopyOptions.RemoteLaneOffset;
2377 
2378   // Iterates, element by element, through the source Reduce list and
2379   // makes a copy.
2380   for (auto En : enumerate(ReductionInfos)) {
2381     const ReductionInfo &RI = En.value();
2382     Value *SrcElementAddr = nullptr;
2383     Value *DestElementAddr = nullptr;
2384     Value *DestElementPtrAddr = nullptr;
2385     // Should we shuffle in an element from a remote lane?
2386     bool ShuffleInElement = false;
2387     // Set to true to update the pointer in the dest Reduce list to a
2388     // newly created element.
2389     bool UpdateDestListPtr = false;
2390 
2391     // Step 1.1: Get the address for the src element in the Reduce list.
2392     Value *SrcElementPtrAddr = Builder.CreateInBoundsGEP(
2393         ReductionArrayTy, SrcBase,
2394         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
2395     SrcElementAddr = Builder.CreateLoad(Builder.getPtrTy(), SrcElementPtrAddr);
2396 
2397     // Step 1.2: Create a temporary to store the element in the destination
2398     // Reduce list.
2399     DestElementPtrAddr = Builder.CreateInBoundsGEP(
2400         ReductionArrayTy, DestBase,
2401         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
2402     switch (Action) {
2403     case CopyAction::RemoteLaneToThread: {
2404       InsertPointTy CurIP = Builder.saveIP();
2405       Builder.restoreIP(AllocaIP);
2406       AllocaInst *DestAlloca = Builder.CreateAlloca(RI.ElementType, nullptr,
2407                                                     ".omp.reduction.element");
2408       DestAlloca->setAlignment(
2409           M.getDataLayout().getPrefTypeAlign(RI.ElementType));
2410       DestElementAddr = DestAlloca;
2411       DestElementAddr =
2412           Builder.CreateAddrSpaceCast(DestElementAddr, Builder.getPtrTy(),
2413                                       DestElementAddr->getName() + ".ascast");
2414       Builder.restoreIP(CurIP);
2415       ShuffleInElement = true;
2416       UpdateDestListPtr = true;
2417       break;
2418     }
2419     case CopyAction::ThreadCopy: {
2420       DestElementAddr =
2421           Builder.CreateLoad(Builder.getPtrTy(), DestElementPtrAddr);
2422       break;
2423     }
2424     }
2425 
2426     // Now that all active lanes have read the element in the
2427     // Reduce list, shuffle over the value from the remote lane.
2428     if (ShuffleInElement) {
2429       shuffleAndStore(AllocaIP, SrcElementAddr, DestElementAddr, RI.ElementType,
2430                       RemoteLaneOffset, ReductionArrayTy);
2431     } else {
2432       switch (RI.EvaluationKind) {
2433       case EvalKind::Scalar: {
2434         Value *Elem = Builder.CreateLoad(RI.ElementType, SrcElementAddr);
2435         // Store the source element value to the dest element address.
2436         Builder.CreateStore(Elem, DestElementAddr);
2437         break;
2438       }
2439       case EvalKind::Complex: {
2440         Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
2441             RI.ElementType, SrcElementAddr, 0, 0, ".realp");
2442         Value *SrcReal = Builder.CreateLoad(
2443             RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
2444         Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
2445             RI.ElementType, SrcElementAddr, 0, 1, ".imagp");
2446         Value *SrcImg = Builder.CreateLoad(
2447             RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");
2448 
2449         Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
2450             RI.ElementType, DestElementAddr, 0, 0, ".realp");
2451         Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
2452             RI.ElementType, DestElementAddr, 0, 1, ".imagp");
2453         Builder.CreateStore(SrcReal, DestRealPtr);
2454         Builder.CreateStore(SrcImg, DestImgPtr);
2455         break;
2456       }
2457       case EvalKind::Aggregate: {
2458         Value *SizeVal = Builder.getInt64(
2459             M.getDataLayout().getTypeStoreSize(RI.ElementType));
2460         Builder.CreateMemCpy(
2461             DestElementAddr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
2462             SrcElementAddr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
2463             SizeVal, false);
2464         break;
2465       }
2466       }
2467     }
2468 
2469     // Step 3.1: Modify reference in dest Reduce list as needed.
2470     // Modifying the reference in Reduce list to point to the newly
2471     // created element.  The element is live in the current function
2472     // scope and that of functions it invokes (i.e., reduce_function).
2473     // RemoteReduceData[i] = (void*)&RemoteElem
2474     if (UpdateDestListPtr) {
2475       Value *CastDestAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2476           DestElementAddr, Builder.getPtrTy(),
2477           DestElementAddr->getName() + ".ascast");
2478       Builder.CreateStore(CastDestAddr, DestElementPtrAddr);
2479     }
2480   }
2481 }
2482 
2483 Function *OpenMPIRBuilder::emitInterWarpCopyFunction(
2484     const LocationDescription &Loc, ArrayRef<ReductionInfo> ReductionInfos,
2485     AttributeList FuncAttrs) {
2486   InsertPointTy SavedIP = Builder.saveIP();
2487   LLVMContext &Ctx = M.getContext();
2488   FunctionType *FuncTy = FunctionType::get(
2489       Builder.getVoidTy(), {Builder.getPtrTy(), Builder.getInt32Ty()},
2490       /* IsVarArg */ false);
2491   Function *WcFunc =
2492       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
2493                        "_omp_reduction_inter_warp_copy_func", &M);
2494   WcFunc->setAttributes(FuncAttrs);
2495   WcFunc->addParamAttr(0, Attribute::NoUndef);
2496   WcFunc->addParamAttr(1, Attribute::NoUndef);
2497   BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", WcFunc);
2498   Builder.SetInsertPoint(EntryBB);
2499 
2500   // ReduceList: thread local Reduce list.
2501   // At the stage of the computation when this function is called, partially
2502   // aggregated values reside in the first lane of every active warp.
2503   Argument *ReduceListArg = WcFunc->getArg(0);
2504   // NumWarps: number of warps active in the parallel region.  This could
2505   // be smaller than 32 (max warps in a CTA) for partial block reduction.
2506   Argument *NumWarpsArg = WcFunc->getArg(1);
2507 
2508   // This array is used as a medium to transfer, one reduce element at a time,
2509   // the data from the first lane of every warp to lanes in the first warp
2510   // in order to perform the final step of a reduction in a parallel region
2511   // (reduction across warps).  The array is placed in NVPTX __shared__ memory
2512   // for reduced latency, as well as to have a distinct copy for concurrently
2513   // executing target regions.  The array is declared with weak linkage so
2514   // as to be shared across compilation units.
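  // A rough, hedged sketch of the per-chunk protocol emitted below (each
  // element is copied in 4-, then 2-, then 1-byte pieces):
  //
  //   barrier();
  //   if (lane_id == 0)                    // warp master
  //     medium[warp_id] = *chunk;          // volatile store to __shared__
  //   barrier();
  //   if (thread_id < num_warps)           // lanes of warp 0
  //     *chunk = medium[thread_id];        // volatile load from __shared__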
2515   StringRef TransferMediumName =
2516       "__openmp_nvptx_data_transfer_temporary_storage";
2517   GlobalVariable *TransferMedium = M.getGlobalVariable(TransferMediumName);
2518   unsigned WarpSize = Config.getGridValue().GV_Warp_Size;
2519   ArrayType *ArrayTy = ArrayType::get(Builder.getInt32Ty(), WarpSize);
2520   if (!TransferMedium) {
2521     TransferMedium = new GlobalVariable(
2522         M, ArrayTy, /*isConstant=*/false, GlobalVariable::WeakAnyLinkage,
2523         UndefValue::get(ArrayTy), TransferMediumName,
2524         /*InsertBefore=*/nullptr, GlobalVariable::NotThreadLocal,
2525         /*AddressSpace=*/3);
2526   }
2527 
2528   // Get the CUDA thread id of the current OpenMP thread on the GPU.
2529   Value *GPUThreadID = getGPUThreadID();
2530   // nvptx_lane_id = nvptx_id % warpsize
2531   Value *LaneID = getNVPTXLaneID();
2532   // nvptx_warp_id = nvptx_id / warpsize
2533   Value *WarpID = getNVPTXWarpID();
2534 
2535   InsertPointTy AllocaIP =
2536       InsertPointTy(Builder.GetInsertBlock(),
2537                     Builder.GetInsertBlock()->getFirstInsertionPt());
2538   Type *Arg0Type = ReduceListArg->getType();
2539   Type *Arg1Type = NumWarpsArg->getType();
2540   Builder.restoreIP(AllocaIP);
2541   AllocaInst *ReduceListAlloca = Builder.CreateAlloca(
2542       Arg0Type, nullptr, ReduceListArg->getName() + ".addr");
2543   AllocaInst *NumWarpsAlloca =
2544       Builder.CreateAlloca(Arg1Type, nullptr, NumWarpsArg->getName() + ".addr");
2545   Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2546       ReduceListAlloca, Arg0Type, ReduceListAlloca->getName() + ".ascast");
2547   Value *NumWarpsAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2548       NumWarpsAlloca, Arg1Type->getPointerTo(),
2549       NumWarpsAlloca->getName() + ".ascast");
2550   Builder.CreateStore(ReduceListArg, ReduceListAddrCast);
2551   Builder.CreateStore(NumWarpsArg, NumWarpsAddrCast);
2552   AllocaIP = getInsertPointAfterInstr(NumWarpsAlloca);
2553   InsertPointTy CodeGenIP =
2554       getInsertPointAfterInstr(&Builder.GetInsertBlock()->back());
2555   Builder.restoreIP(CodeGenIP);
2556 
2557   Value *ReduceList =
2558       Builder.CreateLoad(Builder.getPtrTy(), ReduceListAddrCast);
2559 
2560   for (auto En : enumerate(ReductionInfos)) {
2561     //
2562     // Warp master copies reduce element to transfer medium in __shared__
2563     // memory.
2564     //
2565     const ReductionInfo &RI = En.value();
2566     unsigned RealTySize = M.getDataLayout().getTypeAllocSize(RI.ElementType);
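    // Copy the element in power-of-two chunks: first as many 4-byte pieces
    // as fit (looped via Cnt when there is more than one), then a 2-byte
    // piece, then a 1-byte piece, so arbitrary element sizes can traverse
    // the 32-bit-wide transfer medium.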
2567     for (unsigned TySize = 4; TySize > 0 && RealTySize > 0; TySize /= 2) {
2568       Type *CType = Builder.getIntNTy(TySize * 8);
2569 
2570       unsigned NumIters = RealTySize / TySize;
2571       if (NumIters == 0)
2572         continue;
2573       Value *Cnt = nullptr;
2574       Value *CntAddr = nullptr;
2575       BasicBlock *PrecondBB = nullptr;
2576       BasicBlock *ExitBB = nullptr;
2577       if (NumIters > 1) {
2578         CodeGenIP = Builder.saveIP();
2579         Builder.restoreIP(AllocaIP);
2580         CntAddr =
2581             Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, ".cnt.addr");
2582 
2583         CntAddr = Builder.CreateAddrSpaceCast(CntAddr, Builder.getPtrTy(),
2584                                               CntAddr->getName() + ".ascast");
2585         Builder.restoreIP(CodeGenIP);
2586         Builder.CreateStore(Constant::getNullValue(Builder.getInt32Ty()),
2587                             CntAddr,
2588                             /*Volatile=*/false);
2589         PrecondBB = BasicBlock::Create(Ctx, "precond");
2590         ExitBB = BasicBlock::Create(Ctx, "exit");
2591         BasicBlock *BodyBB = BasicBlock::Create(Ctx, "body");
2592         emitBlock(PrecondBB, Builder.GetInsertBlock()->getParent());
2593         Cnt = Builder.CreateLoad(Builder.getInt32Ty(), CntAddr,
2594                                  /*Volatile=*/false);
2595         Value *Cmp = Builder.CreateICmpULT(
2596             Cnt, ConstantInt::get(Builder.getInt32Ty(), NumIters));
2597         Builder.CreateCondBr(Cmp, BodyBB, ExitBB);
2598         emitBlock(BodyBB, Builder.GetInsertBlock()->getParent());
2599       }
2600 
2601       // kmpc_barrier.
2602       createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
2603                     omp::Directive::OMPD_unknown,
2604                     /* ForceSimpleCall */ false,
2605                     /* CheckCancelFlag */ true);
2606       BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then");
2607       BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else");
2608       BasicBlock *MergeBB = BasicBlock::Create(Ctx, "ifcont");
2609 
2610       // if (lane_id == 0)
2611       Value *IsWarpMaster = Builder.CreateIsNull(LaneID, "warp_master");
2612       Builder.CreateCondBr(IsWarpMaster, ThenBB, ElseBB);
2613       emitBlock(ThenBB, Builder.GetInsertBlock()->getParent());
2614 
2615       // Reduce element = LocalReduceList[i]
2616       auto *RedListArrayTy =
2617           ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
2618       Type *IndexTy = Builder.getIndexTy(
2619           M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
2620       Value *ElemPtrPtr =
2621           Builder.CreateInBoundsGEP(RedListArrayTy, ReduceList,
2622                                     {ConstantInt::get(IndexTy, 0),
2623                                      ConstantInt::get(IndexTy, En.index())});
2624       // elemptr = ((CopyType*)(elemptrptr)) + I
2625       Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
2626       if (NumIters > 1)
2627         ElemPtr = Builder.CreateGEP(Builder.getInt32Ty(), ElemPtr, Cnt);
2628 
2629       // Get pointer to location in transfer medium.
2630       // MediumPtr = &medium[warp_id]
2631       Value *MediumPtr = Builder.CreateInBoundsGEP(
2632           ArrayTy, TransferMedium, {Builder.getInt64(0), WarpID});
2633       // elem = *elemptr
2634       // *MediumPtr = elem
2635       Value *Elem = Builder.CreateLoad(CType, ElemPtr);
2636       // Store the source element value to the dest element address.
2637       Builder.CreateStore(Elem, MediumPtr,
2638                           /*IsVolatile*/ true);
2639       Builder.CreateBr(MergeBB);
2640 
2641       // else
2642       emitBlock(ElseBB, Builder.GetInsertBlock()->getParent());
2643       Builder.CreateBr(MergeBB);
2644 
2645       // endif
2646       emitBlock(MergeBB, Builder.GetInsertBlock()->getParent());
2647       createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
2648                     omp::Directive::OMPD_unknown,
2649                     /* ForceSimpleCall */ false,
2650                     /* CheckCancelFlag */ true);
2651 
2652       // Warp 0 copies reduce element from transfer medium
2653       BasicBlock *W0ThenBB = BasicBlock::Create(Ctx, "then");
2654       BasicBlock *W0ElseBB = BasicBlock::Create(Ctx, "else");
2655       BasicBlock *W0MergeBB = BasicBlock::Create(Ctx, "ifcont");
2656 
2657       Value *NumWarpsVal =
2658           Builder.CreateLoad(Builder.getInt32Ty(), NumWarpsAddrCast);
2659       // Up to 32 threads in warp 0 are active.
2660       Value *IsActiveThread =
2661           Builder.CreateICmpULT(GPUThreadID, NumWarpsVal, "is_active_thread");
2662       Builder.CreateCondBr(IsActiveThread, W0ThenBB, W0ElseBB);
2663 
2664       emitBlock(W0ThenBB, Builder.GetInsertBlock()->getParent());
2665 
2666       // SrcMediumPtr = &medium[tid]
2667       // SrcMediumVal = *SrcMediumPtr
2668       Value *SrcMediumPtrVal = Builder.CreateInBoundsGEP(
2669           ArrayTy, TransferMedium, {Builder.getInt64(0), GPUThreadID});
2670       // TargetElemPtr = (CopyType*)(SrcDataAddr[i]) + I
2671       Value *TargetElemPtrPtr =
2672           Builder.CreateInBoundsGEP(RedListArrayTy, ReduceList,
2673                                     {ConstantInt::get(IndexTy, 0),
2674                                      ConstantInt::get(IndexTy, En.index())});
2675       Value *TargetElemPtrVal =
2676           Builder.CreateLoad(Builder.getPtrTy(), TargetElemPtrPtr);
2677       Value *TargetElemPtr = TargetElemPtrVal;
2678       if (NumIters > 1)
2679         TargetElemPtr =
2680             Builder.CreateGEP(Builder.getInt32Ty(), TargetElemPtr, Cnt);
2681 
2682       // *TargetElemPtr = SrcMediumVal;
2683       Value *SrcMediumValue =
2684           Builder.CreateLoad(CType, SrcMediumPtrVal, /*IsVolatile*/ true);
2685       Builder.CreateStore(SrcMediumValue, TargetElemPtr);
2686       Builder.CreateBr(W0MergeBB);
2687 
2688       emitBlock(W0ElseBB, Builder.GetInsertBlock()->getParent());
2689       Builder.CreateBr(W0MergeBB);
2690 
2691       emitBlock(W0MergeBB, Builder.GetInsertBlock()->getParent());
2692 
2693       if (NumIters > 1) {
2694         Cnt = Builder.CreateNSWAdd(
2695             Cnt, ConstantInt::get(Builder.getInt32Ty(), /*V=*/1));
2696         Builder.CreateStore(Cnt, CntAddr, /*Volatile=*/false);
2697 
2698         auto *CurFn = Builder.GetInsertBlock()->getParent();
2699         emitBranch(PrecondBB);
2700         emitBlock(ExitBB, CurFn);
2701       }
2702       RealTySize %= TySize;
2703     }
2704   }
2705 
2706   Builder.CreateRetVoid();
2707   Builder.restoreIP(SavedIP);
2708 
2709   return WcFunc;
2710 }
2711 
2712 Function *OpenMPIRBuilder::emitShuffleAndReduceFunction(
2713     ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
2714     AttributeList FuncAttrs) {
2715   LLVMContext &Ctx = M.getContext();
2716   FunctionType *FuncTy =
2717       FunctionType::get(Builder.getVoidTy(),
2718                         {Builder.getPtrTy(), Builder.getInt16Ty(),
2719                          Builder.getInt16Ty(), Builder.getInt16Ty()},
2720                         /* IsVarArg */ false);
2721   Function *SarFunc =
2722       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
2723                        "_omp_reduction_shuffle_and_reduce_func", &M);
2724   SarFunc->setAttributes(FuncAttrs);
2725   SarFunc->addParamAttr(0, Attribute::NoUndef);
2726   SarFunc->addParamAttr(1, Attribute::NoUndef);
2727   SarFunc->addParamAttr(2, Attribute::NoUndef);
2728   SarFunc->addParamAttr(3, Attribute::NoUndef);
2729   SarFunc->addParamAttr(1, Attribute::SExt);
2730   SarFunc->addParamAttr(2, Attribute::SExt);
2731   SarFunc->addParamAttr(3, Attribute::SExt);
2732   BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", SarFunc);
2733   Builder.SetInsertPoint(EntryBB);
2734 
2735   // Thread local Reduce list used to host the values of data to be reduced.
2736   Argument *ReduceListArg = SarFunc->getArg(0);
2737   // Current lane id; could be logical.
2738   Argument *LaneIDArg = SarFunc->getArg(1);
2739   // Offset of the remote source lane relative to the current lane.
2740   Argument *RemoteLaneOffsetArg = SarFunc->getArg(2);
2741   // Algorithm version.  This is expected to be known at compile time.
2742   Argument *AlgoVerArg = SarFunc->getArg(3);
2743 
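  // A hedged sketch of the function being built here (pseudocode; names are
  // illustrative):
  //
  //   void shuffle_and_reduce(void **reduce_list, short lane_id,
  //                           short remote_offset, short algo_ver) {
  //     void *remote_list[n];
  //     copy_from_remote_lane(reduce_list, remote_list, remote_offset);
  //     if (should_reduce(algo_ver, lane_id, remote_offset))
  //       reduce_fn(reduce_list, remote_list);   // aggregate local + remote
  //     if (algo_ver == 1 && lane_id >= remote_offset)
  //       copy(remote_list, reduce_list);        // remote replaces local
  //   }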
2744   Type *ReduceListArgType = ReduceListArg->getType();
2745   Type *LaneIDArgType = LaneIDArg->getType();
2746   Type *LaneIDArgPtrType = LaneIDArg->getType()->getPointerTo();
2747   Value *ReduceListAlloca = Builder.CreateAlloca(
2748       ReduceListArgType, nullptr, ReduceListArg->getName() + ".addr");
2749   Value *LaneIdAlloca = Builder.CreateAlloca(LaneIDArgType, nullptr,
2750                                              LaneIDArg->getName() + ".addr");
2751   Value *RemoteLaneOffsetAlloca = Builder.CreateAlloca(
2752       LaneIDArgType, nullptr, RemoteLaneOffsetArg->getName() + ".addr");
2753   Value *AlgoVerAlloca = Builder.CreateAlloca(LaneIDArgType, nullptr,
2754                                               AlgoVerArg->getName() + ".addr");
2755   ArrayType *RedListArrayTy =
2756       ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
2757 
2758   // Create a local thread-private variable to host the Reduce list
2759   // from a remote lane.
2760   Instruction *RemoteReductionListAlloca = Builder.CreateAlloca(
2761       RedListArrayTy, nullptr, ".omp.reduction.remote_reduce_list");
2762 
2763   Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2764       ReduceListAlloca, ReduceListArgType,
2765       ReduceListAlloca->getName() + ".ascast");
2766   Value *LaneIdAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2767       LaneIdAlloca, LaneIDArgPtrType, LaneIdAlloca->getName() + ".ascast");
2768   Value *RemoteLaneOffsetAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2769       RemoteLaneOffsetAlloca, LaneIDArgPtrType,
2770       RemoteLaneOffsetAlloca->getName() + ".ascast");
2771   Value *AlgoVerAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2772       AlgoVerAlloca, LaneIDArgPtrType, AlgoVerAlloca->getName() + ".ascast");
2773   Value *RemoteListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2774       RemoteReductionListAlloca, Builder.getPtrTy(),
2775       RemoteReductionListAlloca->getName() + ".ascast");
2776 
2777   Builder.CreateStore(ReduceListArg, ReduceListAddrCast);
2778   Builder.CreateStore(LaneIDArg, LaneIdAddrCast);
2779   Builder.CreateStore(RemoteLaneOffsetArg, RemoteLaneOffsetAddrCast);
2780   Builder.CreateStore(AlgoVerArg, AlgoVerAddrCast);
2781 
2782   Value *ReduceList = Builder.CreateLoad(ReduceListArgType, ReduceListAddrCast);
2783   Value *LaneId = Builder.CreateLoad(LaneIDArgType, LaneIdAddrCast);
2784   Value *RemoteLaneOffset =
2785       Builder.CreateLoad(LaneIDArgType, RemoteLaneOffsetAddrCast);
2786   Value *AlgoVer = Builder.CreateLoad(LaneIDArgType, AlgoVerAddrCast);
2787 
2788   InsertPointTy AllocaIP = getInsertPointAfterInstr(RemoteReductionListAlloca);
2789 
2790   // This loop iterates through the list of reduce elements and copies,
2791   // element by element, from a remote lane in the warp to RemoteReduceList,
2792   // hosted on the thread's stack.
2793   emitReductionListCopy(
2794       AllocaIP, CopyAction::RemoteLaneToThread, RedListArrayTy, ReductionInfos,
2795       ReduceList, RemoteListAddrCast, {RemoteLaneOffset, nullptr, nullptr});
2796 
2797   // The actions to be performed on the Remote Reduce list depend on the
2798   // algorithm version.
2799   //
2800   //  if (AlgoVer==0) || (AlgoVer==1 && (LaneId < Offset)) || (AlgoVer==2 &&
2801   //  LaneId % 2 == 0 && Offset > 0):
2802   //    do the reduction value aggregation
2803   //
2804   //  The thread local variable Reduce list is mutated in place to host the
2805   //  reduced data, which is the aggregated value produced from local and
2806   //  remote lanes.
2807   //
2808   //  Note that AlgoVer is expected to be a constant integer known at compile
2809   //  time.
2810   //  When AlgoVer==0, the first conjunction evaluates to true, making
2811   //    the entire predicate true at compile time.
2812   //  When AlgoVer==1, only the second part of the second conjunction needs
2813   //    to be evaluated at runtime.  The other conjunctions evaluate to
2814   //    false at compile time.
2815   //  When AlgoVer==2, only the second part of the third conjunction needs
2816   //    to be evaluated at runtime.  The other conjunctions evaluate to
2817   //    false at compile time.
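  //  As a concrete, hedged illustration: with AlgoVer==2 the whole predicate
  //  below folds at compile time to ((LaneId & 1) == 0) && (RemoteLaneOffset > 0).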
2818   Value *CondAlgo0 = Builder.CreateIsNull(AlgoVer);
2819   Value *Algo1 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(1));
2820   Value *LaneComp = Builder.CreateICmpULT(LaneId, RemoteLaneOffset);
2821   Value *CondAlgo1 = Builder.CreateAnd(Algo1, LaneComp);
2822   Value *Algo2 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(2));
2823   Value *LaneIdAnd1 = Builder.CreateAnd(LaneId, Builder.getInt16(1));
2824   Value *LaneIdComp = Builder.CreateIsNull(LaneIdAnd1);
2825   Value *Algo2AndLaneIdComp = Builder.CreateAnd(Algo2, LaneIdComp);
2826   Value *RemoteOffsetComp =
2827       Builder.CreateICmpSGT(RemoteLaneOffset, Builder.getInt16(0));
2828   Value *CondAlgo2 = Builder.CreateAnd(Algo2AndLaneIdComp, RemoteOffsetComp);
2829   Value *CA0OrCA1 = Builder.CreateOr(CondAlgo0, CondAlgo1);
2830   Value *CondReduce = Builder.CreateOr(CA0OrCA1, CondAlgo2);
2831 
2832   BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then");
2833   BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else");
2834   BasicBlock *MergeBB = BasicBlock::Create(Ctx, "ifcont");
2835 
2836   Builder.CreateCondBr(CondReduce, ThenBB, ElseBB);
2837   emitBlock(ThenBB, Builder.GetInsertBlock()->getParent());
2838   Value *LocalReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2839       ReduceList, Builder.getPtrTy());
2840   Value *RemoteReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2841       RemoteListAddrCast, Builder.getPtrTy());
2842   Builder.CreateCall(ReduceFn, {LocalReduceListPtr, RemoteReduceListPtr})
2843       ->addFnAttr(Attribute::NoUnwind);
2844   Builder.CreateBr(MergeBB);
2845 
2846   emitBlock(ElseBB, Builder.GetInsertBlock()->getParent());
2847   Builder.CreateBr(MergeBB);
2848 
2849   emitBlock(MergeBB, Builder.GetInsertBlock()->getParent());
2850 
2851   // if (AlgoVer==1 && (LaneId >= Offset)) copy Remote Reduce list to local
2852   // Reduce list.
2853   Algo1 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(1));
2854   Value *LaneIdGtOffset = Builder.CreateICmpUGE(LaneId, RemoteLaneOffset);
2855   Value *CondCopy = Builder.CreateAnd(Algo1, LaneIdGtOffset);
2856 
2857   BasicBlock *CpyThenBB = BasicBlock::Create(Ctx, "then");
2858   BasicBlock *CpyElseBB = BasicBlock::Create(Ctx, "else");
2859   BasicBlock *CpyMergeBB = BasicBlock::Create(Ctx, "ifcont");
2860   Builder.CreateCondBr(CondCopy, CpyThenBB, CpyElseBB);
2861 
2862   emitBlock(CpyThenBB, Builder.GetInsertBlock()->getParent());
2863   emitReductionListCopy(AllocaIP, CopyAction::ThreadCopy, RedListArrayTy,
2864                         ReductionInfos, RemoteListAddrCast, ReduceList);
2865   Builder.CreateBr(CpyMergeBB);
2866 
2867   emitBlock(CpyElseBB, Builder.GetInsertBlock()->getParent());
2868   Builder.CreateBr(CpyMergeBB);
2869 
2870   emitBlock(CpyMergeBB, Builder.GetInsertBlock()->getParent());
2871 
2872   Builder.CreateRetVoid();
2873 
2874   return SarFunc;
2875 }
2876 
2877 Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
2878     ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
2879     AttributeList FuncAttrs) {
2880   OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
2881   LLVMContext &Ctx = M.getContext();
2882   FunctionType *FuncTy = FunctionType::get(
2883       Builder.getVoidTy(),
2884       {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
2885       /* IsVarArg */ false);
2886   Function *LtGCFunc =
2887       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
2888                        "_omp_reduction_list_to_global_copy_func", &M);
2889   LtGCFunc->setAttributes(FuncAttrs);
2890   LtGCFunc->addParamAttr(0, Attribute::NoUndef);
2891   LtGCFunc->addParamAttr(1, Attribute::NoUndef);
2892   LtGCFunc->addParamAttr(2, Attribute::NoUndef);
2893 
2894   BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
2895   Builder.SetInsertPoint(EntryBlock);
2896 
2897   // Buffer: global reduction buffer.
2898   Argument *BufferArg = LtGCFunc->getArg(0);
2899   // Idx: index of the buffer.
2900   Argument *IdxArg = LtGCFunc->getArg(1);
2901   // ReduceList: thread local Reduce list.
2902   Argument *ReduceListArg = LtGCFunc->getArg(2);
2903 
2904   Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
2905                                                 BufferArg->getName() + ".addr");
2906   Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
2907                                              IdxArg->getName() + ".addr");
2908   Value *ReduceListArgAlloca = Builder.CreateAlloca(
2909       Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
2910   Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2911       BufferArgAlloca, Builder.getPtrTy(),
2912       BufferArgAlloca->getName() + ".ascast");
2913   Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2914       IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
2915   Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2916       ReduceListArgAlloca, Builder.getPtrTy(),
2917       ReduceListArgAlloca->getName() + ".ascast");
2918 
2919   Builder.CreateStore(BufferArg, BufferArgAddrCast);
2920   Builder.CreateStore(IdxArg, IdxArgAddrCast);
2921   Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
2922 
2923   Value *LocalReduceList =
2924       Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
2925   Value *BufferArgVal =
2926       Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
2927   Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
2928   Type *IndexTy = Builder.getIndexTy(
2929       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
2930   for (auto En : enumerate(ReductionInfos)) {
2931     const ReductionInfo &RI = En.value();
2932     auto *RedListArrayTy =
2933         ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
2934     // Reduce element = LocalReduceList[i]
2935     Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
2936         RedListArrayTy, LocalReduceList,
2937         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
2938     // elemptr = ((CopyType*)(elemptrptr)) + I
2939     Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
2940 
2941     // Global = Buffer.VD[Idx];
2942     Value *BufferVD =
2943         Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferArgVal, Idxs);
2944     Value *GlobVal = Builder.CreateConstInBoundsGEP2_32(
2945         ReductionsBufferTy, BufferVD, 0, En.index());
2946 
2947     switch (RI.EvaluationKind) {
2948     case EvalKind::Scalar: {
2949       Value *TargetElement = Builder.CreateLoad(RI.ElementType, ElemPtr);
2950       Builder.CreateStore(TargetElement, GlobVal);
2951       break;
2952     }
2953     case EvalKind::Complex: {
2954       Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
2955           RI.ElementType, ElemPtr, 0, 0, ".realp");
2956       Value *SrcReal = Builder.CreateLoad(
2957           RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
2958       Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
2959           RI.ElementType, ElemPtr, 0, 1, ".imagp");
2960       Value *SrcImg = Builder.CreateLoad(
2961           RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");
2962 
2963       Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
2964           RI.ElementType, GlobVal, 0, 0, ".realp");
2965       Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
2966           RI.ElementType, GlobVal, 0, 1, ".imagp");
2967       Builder.CreateStore(SrcReal, DestRealPtr);
2968       Builder.CreateStore(SrcImg, DestImgPtr);
2969       break;
2970     }
2971     case EvalKind::Aggregate: {
2972       Value *SizeVal =
2973           Builder.getInt64(M.getDataLayout().getTypeStoreSize(RI.ElementType));
2974       Builder.CreateMemCpy(
2975           GlobVal, M.getDataLayout().getPrefTypeAlign(RI.ElementType), ElemPtr,
2976           M.getDataLayout().getPrefTypeAlign(RI.ElementType), SizeVal, false);
2977       break;
2978     }
2979     }
2980   }
2981 
2982   Builder.CreateRetVoid();
2983   Builder.restoreIP(OldIP);
2984   return LtGCFunc;
2985 }
2986 
2987 Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
2988     ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
2989     Type *ReductionsBufferTy, AttributeList FuncAttrs) {
2990   OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
2991   LLVMContext &Ctx = M.getContext();
2992   FunctionType *FuncTy = FunctionType::get(
2993       Builder.getVoidTy(),
2994       {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
2995       /* IsVarArg */ false);
2996   Function *LtGRFunc =
2997       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
2998                        "_omp_reduction_list_to_global_reduce_func", &M);
2999   LtGRFunc->setAttributes(FuncAttrs);
3000   LtGRFunc->addParamAttr(0, Attribute::NoUndef);
3001   LtGRFunc->addParamAttr(1, Attribute::NoUndef);
3002   LtGRFunc->addParamAttr(2, Attribute::NoUndef);
3003 
3004   BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGRFunc);
3005   Builder.SetInsertPoint(EntryBlock);
3006 
3007   // Buffer: global reduction buffer.
3008   Argument *BufferArg = LtGRFunc->getArg(0);
3009   // Idx: index of the buffer.
3010   Argument *IdxArg = LtGRFunc->getArg(1);
3011   // ReduceList: thread local Reduce list.
3012   Argument *ReduceListArg = LtGRFunc->getArg(2);
3013 
3014   Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
3015                                                 BufferArg->getName() + ".addr");
3016   Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
3017                                              IdxArg->getName() + ".addr");
3018   Value *ReduceListArgAlloca = Builder.CreateAlloca(
3019       Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
3020   auto *RedListArrayTy =
3021       ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
3022 
3023   // 1. Build a list of reduction variables.
3024   // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
3025   Value *LocalReduceList =
3026       Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");
3027 
3028   Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3029       BufferArgAlloca, Builder.getPtrTy(),
3030       BufferArgAlloca->getName() + ".ascast");
3031   Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3032       IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
3033   Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3034       ReduceListArgAlloca, Builder.getPtrTy(),
3035       ReduceListArgAlloca->getName() + ".ascast");
3036   Value *LocalReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3037       LocalReduceList, Builder.getPtrTy(),
3038       LocalReduceList->getName() + ".ascast");
3039 
3040   Builder.CreateStore(BufferArg, BufferArgAddrCast);
3041   Builder.CreateStore(IdxArg, IdxArgAddrCast);
3042   Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
3043 
3044   Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
3045   Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
3046   Type *IndexTy = Builder.getIndexTy(
3047       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3048   for (auto En : enumerate(ReductionInfos)) {
3049     Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
3050         RedListArrayTy, LocalReduceListAddrCast,
3051         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3052     Value *BufferVD =
3053         Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
3054     // Global = Buffer.VD[Idx];
3055     Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
3056         ReductionsBufferTy, BufferVD, 0, En.index());
3057     Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
3058   }
3059 
3060   // Call reduce_function(GlobalReduceList, ReduceList)
3061   Value *ReduceList =
3062       Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
3063   Builder.CreateCall(ReduceFn, {LocalReduceListAddrCast, ReduceList})
3064       ->addFnAttr(Attribute::NoUnwind);
3065   Builder.CreateRetVoid();
3066   Builder.restoreIP(OldIP);
3067   return LtGRFunc;
3068 }
3069 
3070 Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
3071     ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
3072     AttributeList FuncAttrs) {
3073   OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
3074   LLVMContext &Ctx = M.getContext();
3075   FunctionType *FuncTy = FunctionType::get(
3076       Builder.getVoidTy(),
3077       {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
3078       /* IsVarArg */ false);
3079   Function *LtGCFunc =
3080       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
3081                        "_omp_reduction_global_to_list_copy_func", &M);
3082   LtGCFunc->setAttributes(FuncAttrs);
3083   LtGCFunc->addParamAttr(0, Attribute::NoUndef);
3084   LtGCFunc->addParamAttr(1, Attribute::NoUndef);
3085   LtGCFunc->addParamAttr(2, Attribute::NoUndef);
3086 
3087   BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
3088   Builder.SetInsertPoint(EntryBlock);
3089 
3090   // Buffer: global reduction buffer.
3091   Argument *BufferArg = LtGCFunc->getArg(0);
3092   // Idx: index of the buffer.
3093   Argument *IdxArg = LtGCFunc->getArg(1);
3094   // ReduceList: thread local Reduce list.
3095   Argument *ReduceListArg = LtGCFunc->getArg(2);
3096 
3097   Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
3098                                                 BufferArg->getName() + ".addr");
3099   Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
3100                                              IdxArg->getName() + ".addr");
3101   Value *ReduceListArgAlloca = Builder.CreateAlloca(
3102       Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
3103   Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3104       BufferArgAlloca, Builder.getPtrTy(),
3105       BufferArgAlloca->getName() + ".ascast");
3106   Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3107       IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
3108   Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3109       ReduceListArgAlloca, Builder.getPtrTy(),
3110       ReduceListArgAlloca->getName() + ".ascast");
3111   Builder.CreateStore(BufferArg, BufferArgAddrCast);
3112   Builder.CreateStore(IdxArg, IdxArgAddrCast);
3113   Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
3114 
3115   Value *LocalReduceList =
3116       Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
3117   Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
3118   Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
3119   Type *IndexTy = Builder.getIndexTy(
3120       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3121   for (auto En : enumerate(ReductionInfos)) {
3122     const OpenMPIRBuilder::ReductionInfo &RI = En.value();
3123     auto *RedListArrayTy =
3124         ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
3125     // Reduce element = LocalReduceList[i]
3126     Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
3127         RedListArrayTy, LocalReduceList,
3128         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3129     // elemptr = ((CopyType*)(elemptrptr)) + I
3130     Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
3131     // Global = Buffer.VD[Idx];
3132     Value *BufferVD =
3133         Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
3134     Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
3135         ReductionsBufferTy, BufferVD, 0, En.index());
3136 
3137     switch (RI.EvaluationKind) {
3138     case EvalKind::Scalar: {
3139       Value *TargetElement = Builder.CreateLoad(RI.ElementType, GlobValPtr);
3140       Builder.CreateStore(TargetElement, ElemPtr);
3141       break;
3142     }
3143     case EvalKind::Complex: {
3144       Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
3145           RI.ElementType, GlobValPtr, 0, 0, ".realp");
3146       Value *SrcReal = Builder.CreateLoad(
3147           RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
3148       Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
3149           RI.ElementType, GlobValPtr, 0, 1, ".imagp");
3150       Value *SrcImg = Builder.CreateLoad(
3151           RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");
3152 
3153       Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
3154           RI.ElementType, ElemPtr, 0, 0, ".realp");
3155       Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
3156           RI.ElementType, ElemPtr, 0, 1, ".imagp");
3157       Builder.CreateStore(SrcReal, DestRealPtr);
3158       Builder.CreateStore(SrcImg, DestImgPtr);
3159       break;
3160     }
3161     case EvalKind::Aggregate: {
3162       Value *SizeVal =
3163           Builder.getInt64(M.getDataLayout().getTypeStoreSize(RI.ElementType));
3164       Builder.CreateMemCpy(
3165           ElemPtr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
3166           GlobValPtr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
3167           SizeVal, false);
3168       break;
3169     }
3170     }
3171   }
3172 
3173   Builder.CreateRetVoid();
3174   Builder.restoreIP(OldIP);
3175   return LtGCFunc;
3176 }
3177 
3178 Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
3179     ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
3180     Type *ReductionsBufferTy, AttributeList FuncAttrs) {
3181   OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
3182   LLVMContext &Ctx = M.getContext();
3183   auto *FuncTy = FunctionType::get(
3184       Builder.getVoidTy(),
3185       {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
3186       /* IsVarArg */ false);
3187   Function *LtGRFunc =
3188       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
3189                        "_omp_reduction_global_to_list_reduce_func", &M);
3190   LtGRFunc->setAttributes(FuncAttrs);
3191   LtGRFunc->addParamAttr(0, Attribute::NoUndef);
3192   LtGRFunc->addParamAttr(1, Attribute::NoUndef);
3193   LtGRFunc->addParamAttr(2, Attribute::NoUndef);
3194 
3195   BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGRFunc);
3196   Builder.SetInsertPoint(EntryBlock);
3197 
3198   // Buffer: global reduction buffer.
3199   Argument *BufferArg = LtGRFunc->getArg(0);
3200   // Idx: index of the buffer.
3201   Argument *IdxArg = LtGRFunc->getArg(1);
3202   // ReduceList: thread local Reduce list.
3203   Argument *ReduceListArg = LtGRFunc->getArg(2);
3204 
3205   Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
3206                                                 BufferArg->getName() + ".addr");
3207   Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
3208                                              IdxArg->getName() + ".addr");
3209   Value *ReduceListArgAlloca = Builder.CreateAlloca(
3210       Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
3211   ArrayType *RedListArrayTy =
3212       ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
3213 
3214   // 1. Build a list of reduction variables.
3215   // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
3216   Value *LocalReduceList =
3217       Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");
3218 
3219   Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3220       BufferArgAlloca, Builder.getPtrTy(),
3221       BufferArgAlloca->getName() + ".ascast");
3222   Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3223       IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
3224   Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3225       ReduceListArgAlloca, Builder.getPtrTy(),
3226       ReduceListArgAlloca->getName() + ".ascast");
3227   Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
3228       LocalReduceList, Builder.getPtrTy(),
3229       LocalReduceList->getName() + ".ascast");
3230 
3231   Builder.CreateStore(BufferArg, BufferArgAddrCast);
3232   Builder.CreateStore(IdxArg, IdxArgAddrCast);
3233   Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
3234 
3235   Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
3236   Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
3237   Type *IndexTy = Builder.getIndexTy(
3238       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3239   for (auto En : enumerate(ReductionInfos)) {
3240     Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
3241         RedListArrayTy, ReductionList,
3242         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3243     // Global = Buffer.VD[Idx];
3244     Value *BufferVD =
3245         Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
3246     Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
3247         ReductionsBufferTy, BufferVD, 0, En.index());
3248     Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
3249   }
3250 
3251   // Call reduce_function(ReduceList, GlobalReduceList)
3252   Value *ReduceList =
3253       Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
3254   Builder.CreateCall(ReduceFn, {ReduceList, ReductionList})
3255       ->addFnAttr(Attribute::NoUnwind);
3256   Builder.CreateRetVoid();
3257   Builder.restoreIP(OldIP);
3258   return LtGRFunc;
3259 }
3260 
3261 std::string OpenMPIRBuilder::getReductionFuncName(StringRef Name) const {
3262   std::string Suffix =
3263       createPlatformSpecificName({"omp", "reduction", "reduction_func"});
3264   return (Name + Suffix).str();
3265 }
3266 
3267 Function *OpenMPIRBuilder::createReductionFunction(
3268     StringRef ReducerName, ArrayRef<ReductionInfo> ReductionInfos,
3269     ReductionGenCBKind ReductionGenCBKind, AttributeList FuncAttrs) {
3270   auto *FuncTy = FunctionType::get(Builder.getVoidTy(),
3271                                    {Builder.getPtrTy(), Builder.getPtrTy()},
3272                                    /* IsVarArg */ false);
3273   std::string Name = getReductionFuncName(ReducerName);
3274   Function *ReductionFunc =
3275       Function::Create(FuncTy, GlobalVariable::InternalLinkage, Name, &M);
3276   ReductionFunc->setAttributes(FuncAttrs);
3277   ReductionFunc->addParamAttr(0, Attribute::NoUndef);
3278   ReductionFunc->addParamAttr(1, Attribute::NoUndef);
3279   BasicBlock *EntryBB =
3280       BasicBlock::Create(M.getContext(), "entry", ReductionFunc);
3281   Builder.SetInsertPoint(EntryBB);
3282 
3283   // Allocate memory here and handle the pointer casts before extracting
3284   // the LHS/RHS pointers.
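  // Hedged sketch of the reducer being built (pseudocode; 'reduce_op' stands
  // for the frontend-provided combiner):
  //
  //   void reduction_func(void **lhs_list, void **rhs_list) {
  //     for (i = 0; i < n; ++i)
  //       *lhs_list[i] = reduce_op(*lhs_list[i], *rhs_list[i]);
  //   }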
3285   Value *LHSArrayPtr = nullptr;
3286   Value *RHSArrayPtr = nullptr;
3287   Argument *Arg0 = ReductionFunc->getArg(0);
3288   Argument *Arg1 = ReductionFunc->getArg(1);
3289   Type *Arg0Type = Arg0->getType();
3290   Type *Arg1Type = Arg1->getType();
3291 
3292   Value *LHSAlloca =
3293       Builder.CreateAlloca(Arg0Type, nullptr, Arg0->getName() + ".addr");
3294   Value *RHSAlloca =
3295       Builder.CreateAlloca(Arg1Type, nullptr, Arg1->getName() + ".addr");
3296   Value *LHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3297       LHSAlloca, Arg0Type, LHSAlloca->getName() + ".ascast");
3298   Value *RHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3299       RHSAlloca, Arg1Type, RHSAlloca->getName() + ".ascast");
3300   Builder.CreateStore(Arg0, LHSAddrCast);
3301   Builder.CreateStore(Arg1, RHSAddrCast);
3302   LHSArrayPtr = Builder.CreateLoad(Arg0Type, LHSAddrCast);
3303   RHSArrayPtr = Builder.CreateLoad(Arg1Type, RHSAddrCast);
3304 
3305   Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
3306   Type *IndexTy = Builder.getIndexTy(
3307       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3308   SmallVector<Value *> LHSPtrs, RHSPtrs;
3309   for (auto En : enumerate(ReductionInfos)) {
3310     const ReductionInfo &RI = En.value();
3311     Value *RHSI8PtrPtr = Builder.CreateInBoundsGEP(
3312         RedArrayTy, RHSArrayPtr,
3313         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3314     Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
3315     Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3316         RHSI8Ptr, RI.PrivateVariable->getType(),
3317         RHSI8Ptr->getName() + ".ascast");
3318 
3319     Value *LHSI8PtrPtr = Builder.CreateInBoundsGEP(
3320         RedArrayTy, LHSArrayPtr,
3321         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3322     Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
3323     Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3324         LHSI8Ptr, RI.Variable->getType(), LHSI8Ptr->getName() + ".ascast");
3325 
3326     if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
3327       LHSPtrs.emplace_back(LHSPtr);
3328       RHSPtrs.emplace_back(RHSPtr);
3329     } else {
3330       Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
3331       Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
3332       Value *Reduced;
3333       RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
3334       if (!Builder.GetInsertBlock())
3335         return ReductionFunc;
3336       Builder.CreateStore(Reduced, LHSPtr);
3337     }
3338   }
3339 
3340   if (ReductionGenCBKind == ReductionGenCBKind::Clang)
3341     for (auto En : enumerate(ReductionInfos)) {
3342       unsigned Index = En.index();
3343       const ReductionInfo &RI = En.value();
3344       Value *LHSFixupPtr, *RHSFixupPtr;
3345       Builder.restoreIP(RI.ReductionGenClang(
3346           Builder.saveIP(), Index, &LHSFixupPtr, &RHSFixupPtr, ReductionFunc));
3347 
3348       // Fix the callback-generated code to use the correct Values for the
3349       // LHS and RHS.
3350       LHSFixupPtr->replaceUsesWithIf(
3351           LHSPtrs[Index], [ReductionFunc](const Use &U) {
3352             return cast<Instruction>(U.getUser())->getParent()->getParent() ==
3353                    ReductionFunc;
3354           });
3355       RHSFixupPtr->replaceUsesWithIf(
3356           RHSPtrs[Index], [ReductionFunc](const Use &U) {
3357             return cast<Instruction>(U.getUser())->getParent()->getParent() ==
3358                    ReductionFunc;
3359           });
3360     }
3361 
3362   Builder.CreateRetVoid();
3363   return ReductionFunc;
3364 }
3365 
3366 static void
3367 checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
3368                     bool IsGPU) {
3369   for (const OpenMPIRBuilder::ReductionInfo &RI : ReductionInfos) {
3370     (void)RI;
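    // The (void)RI cast silences unused-variable warnings when the asserts
    // below are compiled out (NDEBUG builds).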
3371     assert(RI.Variable && "expected non-null variable");
3372     assert(RI.PrivateVariable && "expected non-null private variable");
3373     assert((RI.ReductionGen || RI.ReductionGenClang) &&
3374            "expected non-null reduction generator callback");
3375     if (!IsGPU) {
3376       assert(
3377           RI.Variable->getType() == RI.PrivateVariable->getType() &&
3378           "expected variables and their private equivalents to have the same "
3379           "type");
3380     }
3381     assert(RI.Variable->getType()->isPointerTy() &&
3382            "expected variables to be pointers");
3383   }
3384 }
3385 
3386 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductionsGPU(
3387     const LocationDescription &Loc, InsertPointTy AllocaIP,
3388     InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
3389     bool IsNoWait, bool IsTeamsReduction, bool HasDistribute,
3390     ReductionGenCBKind ReductionGenCBKind, std::optional<omp::GV> GridValue,
3391     unsigned ReductionBufNum, Value *SrcLocInfo) {
3392   if (!updateToLocation(Loc))
3393     return InsertPointTy();
3394   Builder.restoreIP(CodeGenIP);
3395   checkReductionInfos(ReductionInfos, /*IsGPU*/ true);
3396   LLVMContext &Ctx = M.getContext();
3397 
3398   // Source location for the ident struct
3399   if (!SrcLocInfo) {
3400     uint32_t SrcLocStrSize;
3401     Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3402     SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3403   }
3404 
3405   if (ReductionInfos.empty())
3406     return Builder.saveIP();
3407 
3408   Function *CurFunc = Builder.GetInsertBlock()->getParent();
3409   AttributeList FuncAttrs;
3410   AttrBuilder AttrBldr(Ctx);
3411   for (auto Attr : CurFunc->getAttributes().getFnAttrs())
3412     AttrBldr.addAttribute(Attr);
3413   AttrBldr.removeAttribute(Attribute::OptimizeNone);
3414   FuncAttrs = FuncAttrs.addFnAttributes(Ctx, AttrBldr);
3415 
3416   Function *ReductionFunc = nullptr;
3417   CodeGenIP = Builder.saveIP();
3418   ReductionFunc =
3419       createReductionFunction(Builder.GetInsertBlock()->getParent()->getName(),
3420                               ReductionInfos, ReductionGenCBKind, FuncAttrs);
3421   Builder.restoreIP(CodeGenIP);
3422 
3423   // Set the grid value in the config; it is needed for lowering later on.
3424   if (GridValue.has_value())
3425     Config.setGridValue(GridValue.value());
3426   else
3427     Config.setGridValue(getGridValue(T, ReductionFunc));
3428 
3429   uint32_t SrcLocStrSize;
3430   Constant *SrcLocStr = getOrCreateDefaultSrcLocStr(SrcLocStrSize);
3431   Value *RTLoc =
3432       getOrCreateIdent(SrcLocStr, SrcLocStrSize, omp::IdentFlag(0), 0);
3433 
3434   // Build res = __kmpc_reduce{_nowait}(<gtid>, <n>, sizeof(RedList),
3435   // RedList, shuffle_reduce_func, interwarp_copy_func);
3436   // or
3437   // Build res = __kmpc_reduce_teams_nowait_simple(<loc>, <gtid>, <lck>);
3438   Value *Res;
3439 
3440   // 1. Build a list of reduction variables.
3441   // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
3442   auto Size = ReductionInfos.size();
3443   Type *PtrTy = PointerType::getUnqual(Ctx);
3444   Type *RedArrayTy = ArrayType::get(PtrTy, Size);
3445   CodeGenIP = Builder.saveIP();
3446   Builder.restoreIP(AllocaIP);
3447   Value *ReductionListAlloca =
3448       Builder.CreateAlloca(RedArrayTy, nullptr, ".omp.reduction.red_list");
3449   Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
3450       ReductionListAlloca, PtrTy, ReductionListAlloca->getName() + ".ascast");
3451   Builder.restoreIP(CodeGenIP);
3452   Type *IndexTy = Builder.getIndexTy(
3453       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3454   for (auto En : enumerate(ReductionInfos)) {
3455     const ReductionInfo &RI = En.value();
3456     Value *ElemPtr = Builder.CreateInBoundsGEP(
3457         RedArrayTy, ReductionList,
3458         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3459     Value *CastElem =
3460         Builder.CreatePointerBitCastOrAddrSpaceCast(RI.PrivateVariable, PtrTy);
3461     Builder.CreateStore(CastElem, ElemPtr);
3462   }
3463   CodeGenIP = Builder.saveIP();
3464   Function *SarFunc =
3465       emitShuffleAndReduceFunction(ReductionInfos, ReductionFunc, FuncAttrs);
3466   Function *WcFunc = emitInterWarpCopyFunction(Loc, ReductionInfos, FuncAttrs);
3467   Builder.restoreIP(CodeGenIP);
3468 
3469   Value *RL = Builder.CreatePointerBitCastOrAddrSpaceCast(ReductionList, PtrTy);
3470 
3471   unsigned MaxDataSize = 0;
3472   SmallVector<Type *> ReductionTypeArgs;
3473   for (auto En : enumerate(ReductionInfos)) {
3474     auto Size = M.getDataLayout().getTypeStoreSize(En.value().ElementType);
3475     if (Size > MaxDataSize)
3476       MaxDataSize = Size;
3477     ReductionTypeArgs.emplace_back(En.value().ElementType);
3478   }
3479   Value *ReductionDataSize =
3480       Builder.getInt64(MaxDataSize * ReductionInfos.size());
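  // A hedged note: this sizes the payload as if every element were as wide
  // as the widest one (MaxDataSize * n), i.e. a conservative upper bound on
  // the total reduction data passed to the runtime.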
3481   if (!IsTeamsReduction) {
3482     Value *SarFuncCast =
3483         Builder.CreatePointerBitCastOrAddrSpaceCast(SarFunc, PtrTy);
3484     Value *WcFuncCast =
3485         Builder.CreatePointerBitCastOrAddrSpaceCast(WcFunc, PtrTy);
3486     Value *Args[] = {RTLoc, ReductionDataSize, RL, SarFuncCast, WcFuncCast};
3487     Function *Pv2Ptr = getOrCreateRuntimeFunctionPtr(
3488         RuntimeFunction::OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2);
3489     Res = Builder.CreateCall(Pv2Ptr, Args);
3490   } else {
3491     CodeGenIP = Builder.saveIP();
3492     StructType *ReductionsBufferTy = StructType::create(
3493         Ctx, ReductionTypeArgs, "struct._globalized_locals_ty");
3494     Function *RedFixedBufferFn = getOrCreateRuntimeFunctionPtr(
3495         RuntimeFunction::OMPRTL___kmpc_reduction_get_fixed_buffer);
3496     Function *LtGCFunc = emitListToGlobalCopyFunction(
3497         ReductionInfos, ReductionsBufferTy, FuncAttrs);
3498     Function *LtGRFunc = emitListToGlobalReduceFunction(
3499         ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
3500     Function *GtLCFunc = emitGlobalToListCopyFunction(
3501         ReductionInfos, ReductionsBufferTy, FuncAttrs);
3502     Function *GtLRFunc = emitGlobalToListReduceFunction(
3503         ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
3504     Builder.restoreIP(CodeGenIP);
3505 
3506     Value *KernelTeamsReductionPtr = Builder.CreateCall(
3507         RedFixedBufferFn, {}, "_openmp_teams_reductions_buffer_$_$ptr");
3508 
3509     Value *Args3[] = {RTLoc,
3510                       KernelTeamsReductionPtr,
3511                       Builder.getInt32(ReductionBufNum),
3512                       ReductionDataSize,
3513                       RL,
3514                       SarFunc,
3515                       WcFunc,
3516                       LtGCFunc,
3517                       LtGRFunc,
3518                       GtLCFunc,
3519                       GtLRFunc};
3520 
3521     Function *TeamsReduceFn = getOrCreateRuntimeFunctionPtr(
3522         RuntimeFunction::OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2);
3523     Res = Builder.CreateCall(TeamsReduceFn, Args3);
3524   }
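  // For reference, the two emitted call shapes above are conceptually (the
  // value names are illustrative, not the exact autogenerated ones):
  //   res = __kmpc_nvptx_parallel_reduce_nowait_v2(loc, size, rl,
  //             shuffle_and_reduce_fn, inter_warp_copy_fn);
  //   res = __kmpc_nvptx_teams_reduce_nowait_v2(loc, buffer, buf_num, size,
  //             rl, shuffle_and_reduce_fn, inter_warp_copy_fn,
  //             list_to_global_copy_fn, list_to_global_reduce_fn,
  //             global_to_list_copy_fn, global_to_list_reduce_fn);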
3525 
3526   // 5. Build if (res == 1)
3527   BasicBlock *ExitBB = BasicBlock::Create(Ctx, ".omp.reduction.done");
3528   BasicBlock *ThenBB = BasicBlock::Create(Ctx, ".omp.reduction.then");
3529   Value *Cond = Builder.CreateICmpEQ(Res, Builder.getInt32(1));
3530   Builder.CreateCondBr(Cond, ThenBB, ExitBB);
3531 
3532   // 6. Build then branch: where we have reduced values in the master
3533   //    thread in each team.
3534   //    __kmpc_end_reduce{_nowait}(<gtid>);
3535   //    break;
3536   emitBlock(ThenBB, CurFunc);
3537 
3538   // Combine the per-team partial values using the reduction callbacks.
3539   for (auto En : enumerate(ReductionInfos)) {
3540     const ReductionInfo &RI = En.value();
3541     Value *LHS = RI.Variable;
3542     Value *RHS =
3543         Builder.CreatePointerBitCastOrAddrSpaceCast(RI.PrivateVariable, PtrTy);
3544 
3545     if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
3546       Value *LHSPtr, *RHSPtr;
3547       Builder.restoreIP(RI.ReductionGenClang(Builder.saveIP(), En.index(),
3548                                              &LHSPtr, &RHSPtr, CurFunc));
3549 
3550       // Fix the callback code generated to use the correct Values for the
3551       // LHS and RHS.
3552       LHSPtr->replaceUsesWithIf(LHS, [ReductionFunc](const Use &U) {
3553         return cast<Instruction>(U.getUser())->getParent()->getParent() ==
3554                ReductionFunc;
3555       });
3556       RHSPtr->replaceUsesWithIf(RHS, [ReductionFunc](const Use &U) {
3557         return cast<Instruction>(U.getUser())->getParent()->getParent() ==
3558                ReductionFunc;
3559       });
3560     } else {
3561       assert(false && "Unhandled ReductionGenCBKind");
3562     }
3563   }
3564   emitBlock(ExitBB, CurFunc);
3565 
3566   Config.setEmitLLVMUsed();
3567 
3568   return Builder.saveIP();
3569 }
3570 
3571 static Function *getFreshReductionFunc(Module &M) {
3572   Type *VoidTy = Type::getVoidTy(M.getContext());
3573   Type *Int8PtrTy = PointerType::getUnqual(M.getContext());
3574   auto *FuncTy =
3575       FunctionType::get(VoidTy, {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ false);
3576   return Function::Create(FuncTy, GlobalVariable::InternalLinkage,
3577                           ".omp.reduction.func", &M);
3578 }
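// Conceptually, the freshly created function has the C signature
//   void .omp.reduction.func(void *lhs_array, void *rhs_array);
// where both arguments are type-erased arrays of pointers to the reduction
// variables. Its body is populated later by createReductions below.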
3579 
3580 OpenMPIRBuilder::InsertPointTy
3581 OpenMPIRBuilder::createReductions(const LocationDescription &Loc,
3582                                   InsertPointTy AllocaIP,
3583                                   ArrayRef<ReductionInfo> ReductionInfos,
3584                                   ArrayRef<bool> IsByRef, bool IsNoWait) {
3585   assert(ReductionInfos.size() == IsByRef.size());
3586   for (const ReductionInfo &RI : ReductionInfos) {
3587     (void)RI;
3588     assert(RI.Variable && "expected non-null variable");
3589     assert(RI.PrivateVariable && "expected non-null private variable");
3590     assert(RI.ReductionGen && "expected non-null reduction generator callback");
3591     assert(RI.Variable->getType() == RI.PrivateVariable->getType() &&
3592            "expected variables and their private equivalents to have the same "
3593            "type");
3594     assert(RI.Variable->getType()->isPointerTy() &&
3595            "expected variables to be pointers");
3596   }
3597 
3598   if (!updateToLocation(Loc))
3599     return InsertPointTy();
3600 
3601   BasicBlock *InsertBlock = Loc.IP.getBlock();
3602   BasicBlock *ContinuationBlock =
3603       InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
3604   InsertBlock->getTerminator()->eraseFromParent();
3605 
3606   // Create and populate array of type-erased pointers to private reduction
3607   // values.
3608   unsigned NumReductions = ReductionInfos.size();
3609   Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), NumReductions);
3610   Builder.SetInsertPoint(AllocaIP.getBlock()->getTerminator());
3611   Value *RedArray = Builder.CreateAlloca(RedArrayTy, nullptr, "red.array");
3612 
3613   Builder.SetInsertPoint(InsertBlock, InsertBlock->end());
3614 
3615   for (auto En : enumerate(ReductionInfos)) {
3616     unsigned Index = En.index();
3617     const ReductionInfo &RI = En.value();
3618     Value *RedArrayElemPtr = Builder.CreateConstInBoundsGEP2_64(
3619         RedArrayTy, RedArray, 0, Index, "red.array.elem." + Twine(Index));
3620     Builder.CreateStore(RI.PrivateVariable, RedArrayElemPtr);
3621   }
3622 
3623   // Emit a call to the runtime function that orchestrates the reduction.
3624   // Declare the reduction function in the process.
3625   Function *Func = Builder.GetInsertBlock()->getParent();
3626   Module *Module = Func->getParent();
3627   uint32_t SrcLocStrSize;
3628   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3629   bool CanGenerateAtomic = all_of(ReductionInfos, [](const ReductionInfo &RI) {
3630     return RI.AtomicReductionGen;
3631   });
3632   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize,
3633                                   CanGenerateAtomic
3634                                       ? IdentFlag::OMP_IDENT_FLAG_ATOMIC_REDUCE
3635                                       : IdentFlag(0));
3636   Value *ThreadId = getOrCreateThreadID(Ident);
3637   Constant *NumVariables = Builder.getInt32(NumReductions);
3638   const DataLayout &DL = Module->getDataLayout();
3639   unsigned RedArrayByteSize = DL.getTypeStoreSize(RedArrayTy);
3640   Constant *RedArraySize = Builder.getInt64(RedArrayByteSize);
3641   Function *ReductionFunc = getFreshReductionFunc(*Module);
3642   Value *Lock = getOMPCriticalRegionLock(".reduction");
3643   Function *ReduceFunc = getOrCreateRuntimeFunctionPtr(
3644       IsNoWait ? RuntimeFunction::OMPRTL___kmpc_reduce_nowait
3645                : RuntimeFunction::OMPRTL___kmpc_reduce);
3646   CallInst *ReduceCall =
3647       Builder.CreateCall(ReduceFunc,
3648                          {Ident, ThreadId, NumVariables, RedArraySize, RedArray,
3649                           ReductionFunc, Lock},
3650                          "reduce");
3651 
3652   // Create final reduction entry blocks for the atomic and non-atomic case.
3653   // Emit IR that dispatches control flow to one of the blocks based on the
3654   // reduction supporting the atomic mode.
3655   BasicBlock *NonAtomicRedBlock =
3656       BasicBlock::Create(Module->getContext(), "reduce.switch.nonatomic", Func);
3657   BasicBlock *AtomicRedBlock =
3658       BasicBlock::Create(Module->getContext(), "reduce.switch.atomic", Func);
3659   SwitchInst *Switch =
3660       Builder.CreateSwitch(ReduceCall, ContinuationBlock, /* NumCases */ 2);
3661   Switch->addCase(Builder.getInt32(1), NonAtomicRedBlock);
3662   Switch->addCase(Builder.getInt32(2), AtomicRedBlock);
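  // The emitted dispatch is conceptually (value names are illustrative):
  //   %reduce = call i32 @__kmpc_reduce{_nowait}(ptr @loc, i32 %tid, i32 <n>,
  //                 i64 <size>, ptr %red.array, ptr @.omp.reduction.func,
  //                 ptr %lock)
  //   switch i32 %reduce, label %reduce.finalize [
  //     i32 1, label %reduce.switch.nonatomic
  //     i32 2, label %reduce.switch.atomic
  //   ]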
3663 
3664   // Populate the non-atomic reduction using the elementwise reduction function.
3665   // This loads the elements from the global and private variables and reduces
3666   // them before storing back the result to the global variable.
3667   Builder.SetInsertPoint(NonAtomicRedBlock);
3668   for (auto En : enumerate(ReductionInfos)) {
3669     const ReductionInfo &RI = En.value();
3670     Type *ValueType = RI.ElementType;
3671     // We have one less load for the by-ref case because that load is now
3672     // inside the reduction region.
3673     Value *RedValue = nullptr;
3674     if (!IsByRef[En.index()]) {
3675       RedValue = Builder.CreateLoad(ValueType, RI.Variable,
3676                                     "red.value." + Twine(En.index()));
3677     }
3678     Value *PrivateRedValue =
3679         Builder.CreateLoad(ValueType, RI.PrivateVariable,
3680                            "red.private.value." + Twine(En.index()));
3681     Value *Reduced;
3682     if (IsByRef[En.index()]) {
3683       Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), RI.Variable,
3684                                         PrivateRedValue, Reduced));
3685     } else {
3686       Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), RedValue,
3687                                         PrivateRedValue, Reduced));
3688     }
3689     if (!Builder.GetInsertBlock())
3690       return InsertPointTy();
3691     // For the by-ref case, the store is inside the reduction region.
3692     if (!IsByRef[En.index()])
3693       Builder.CreateStore(Reduced, RI.Variable);
3694   }
3695   Function *EndReduceFunc = getOrCreateRuntimeFunctionPtr(
3696       IsNoWait ? RuntimeFunction::OMPRTL___kmpc_end_reduce_nowait
3697                : RuntimeFunction::OMPRTL___kmpc_end_reduce);
3698   Builder.CreateCall(EndReduceFunc, {Ident, ThreadId, Lock});
3699   Builder.CreateBr(ContinuationBlock);
3700 
3701   // Populate the atomic reduction using the atomic elementwise reduction
3702   // function. There are no loads/stores here because they will be happening
3703   // inside the atomic elementwise reduction.
3704   Builder.SetInsertPoint(AtomicRedBlock);
3705   if (CanGenerateAtomic && llvm::none_of(IsByRef, [](bool P) { return P; })) {
3706     for (const ReductionInfo &RI : ReductionInfos) {
3707       Builder.restoreIP(RI.AtomicReductionGen(Builder.saveIP(), RI.ElementType,
3708                                               RI.Variable, RI.PrivateVariable));
3709       if (!Builder.GetInsertBlock())
3710         return InsertPointTy();
3711     }
3712     Builder.CreateBr(ContinuationBlock);
3713   } else {
3714     Builder.CreateUnreachable();
3715   }
3716 
3717   // Populate the outlined reduction function using the elementwise reduction
3718   // function. Partial values are extracted from the type-erased array of
3719   // pointers to private variables.
3720   BasicBlock *ReductionFuncBlock =
3721       BasicBlock::Create(Module->getContext(), "", ReductionFunc);
3722   Builder.SetInsertPoint(ReductionFuncBlock);
3723   Value *LHSArrayPtr = ReductionFunc->getArg(0);
3724   Value *RHSArrayPtr = ReductionFunc->getArg(1);
3725 
3726   for (auto En : enumerate(ReductionInfos)) {
3727     const ReductionInfo &RI = En.value();
3728     Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
3729         RedArrayTy, LHSArrayPtr, 0, En.index());
3730     Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
3731     Value *LHSPtr = Builder.CreateBitCast(LHSI8Ptr, RI.Variable->getType());
3732     Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
3733     Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
3734         RedArrayTy, RHSArrayPtr, 0, En.index());
3735     Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
3736     Value *RHSPtr =
3737         Builder.CreateBitCast(RHSI8Ptr, RI.PrivateVariable->getType());
3738     Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
3739     Value *Reduced;
3740     Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced));
3741     if (!Builder.GetInsertBlock())
3742       return InsertPointTy();
3743     // The store is inside the reduction region when using by-ref.
3744     if (!IsByRef[En.index()])
3745       Builder.CreateStore(Reduced, LHSPtr);
3746   }
3747   Builder.CreateRetVoid();
3748 
3749   Builder.SetInsertPoint(ContinuationBlock);
3750   return Builder.saveIP();
3751 }
3752 
3753 OpenMPIRBuilder::InsertPointTy
3754 OpenMPIRBuilder::createMaster(const LocationDescription &Loc,
3755                               BodyGenCallbackTy BodyGenCB,
3756                               FinalizeCallbackTy FiniCB) {
3757 
3758   if (!updateToLocation(Loc))
3759     return Loc.IP;
3760 
3761   Directive OMPD = Directive::OMPD_master;
3762   uint32_t SrcLocStrSize;
3763   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3764   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3765   Value *ThreadId = getOrCreateThreadID(Ident);
3766   Value *Args[] = {Ident, ThreadId};
3767 
3768   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_master);
3769   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
3770 
3771   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_master);
3772   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
3773 
3774   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
3775                               /*Conditional*/ true, /*hasFinalize*/ true);
3776 }
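// Since the region is emitted with Conditional=true, the result is
// conceptually:
//   if (__kmpc_master(&loc, tid)) {
//     <body>
//     __kmpc_end_master(&loc, tid);
//   }
// createMasked below follows the same pattern, additionally passing the
// filter value to __kmpc_masked.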
3777 
3778 OpenMPIRBuilder::InsertPointTy
3779 OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
3780                               BodyGenCallbackTy BodyGenCB,
3781                               FinalizeCallbackTy FiniCB, Value *Filter) {
3782   if (!updateToLocation(Loc))
3783     return Loc.IP;
3784 
3785   Directive OMPD = Directive::OMPD_masked;
3786   uint32_t SrcLocStrSize;
3787   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3788   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3789   Value *ThreadId = getOrCreateThreadID(Ident);
3790   Value *Args[] = {Ident, ThreadId, Filter};
3791   Value *ArgsEnd[] = {Ident, ThreadId};
3792 
3793   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_masked);
3794   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
3795 
3796   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_masked);
3797   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, ArgsEnd);
3798 
3799   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
3800                               /*Conditional*/ true, /*hasFinalize*/ true);
3801 }
3802 
3803 CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
3804     DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
3805     BasicBlock *PostInsertBefore, const Twine &Name) {
3806   Module *M = F->getParent();
3807   LLVMContext &Ctx = M->getContext();
3808   Type *IndVarTy = TripCount->getType();
3809 
3810   // Create the basic block structure.
3811   BasicBlock *Preheader =
3812       BasicBlock::Create(Ctx, "omp_" + Name + ".preheader", F, PreInsertBefore);
3813   BasicBlock *Header =
3814       BasicBlock::Create(Ctx, "omp_" + Name + ".header", F, PreInsertBefore);
3815   BasicBlock *Cond =
3816       BasicBlock::Create(Ctx, "omp_" + Name + ".cond", F, PreInsertBefore);
3817   BasicBlock *Body =
3818       BasicBlock::Create(Ctx, "omp_" + Name + ".body", F, PreInsertBefore);
3819   BasicBlock *Latch =
3820       BasicBlock::Create(Ctx, "omp_" + Name + ".inc", F, PostInsertBefore);
3821   BasicBlock *Exit =
3822       BasicBlock::Create(Ctx, "omp_" + Name + ".exit", F, PostInsertBefore);
3823   BasicBlock *After =
3824       BasicBlock::Create(Ctx, "omp_" + Name + ".after", F, PostInsertBefore);
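  // Once the branches below are in place, the skeleton's control flow is:
  //   preheader -> header -> cond -> body -> latch -> header (backedge)
  // with cond branching to exit -> after once the induction variable reaches
  // the trip count.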
3825 
3826   // Use specified DebugLoc for new instructions.
3827   Builder.SetCurrentDebugLocation(DL);
3828 
3829   Builder.SetInsertPoint(Preheader);
3830   Builder.CreateBr(Header);
3831 
3832   Builder.SetInsertPoint(Header);
3833   PHINode *IndVarPHI = Builder.CreatePHI(IndVarTy, 2, "omp_" + Name + ".iv");
3834   IndVarPHI->addIncoming(ConstantInt::get(IndVarTy, 0), Preheader);
3835   Builder.CreateBr(Cond);
3836 
3837   Builder.SetInsertPoint(Cond);
3838   Value *Cmp =
3839       Builder.CreateICmpULT(IndVarPHI, TripCount, "omp_" + Name + ".cmp");
3840   Builder.CreateCondBr(Cmp, Body, Exit);
3841 
3842   Builder.SetInsertPoint(Body);
3843   Builder.CreateBr(Latch);
3844 
3845   Builder.SetInsertPoint(Latch);
3846   Value *Next = Builder.CreateAdd(IndVarPHI, ConstantInt::get(IndVarTy, 1),
3847                                   "omp_" + Name + ".next", /*HasNUW=*/true);
3848   Builder.CreateBr(Header);
3849   IndVarPHI->addIncoming(Next, Latch);
3850 
3851   Builder.SetInsertPoint(Exit);
3852   Builder.CreateBr(After);
3853 
3854   // Remember and return the canonical control flow.
3855   LoopInfos.emplace_front();
3856   CanonicalLoopInfo *CL = &LoopInfos.front();
3857 
3858   CL->Header = Header;
3859   CL->Cond = Cond;
3860   CL->Latch = Latch;
3861   CL->Exit = Exit;
3862 
3863 #ifndef NDEBUG
3864   CL->assertOK();
3865 #endif
3866   return CL;
3867 }
3868 
3869 CanonicalLoopInfo *
3870 OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
3871                                      LoopBodyGenCallbackTy BodyGenCB,
3872                                      Value *TripCount, const Twine &Name) {
3873   BasicBlock *BB = Loc.IP.getBlock();
3874   BasicBlock *NextBB = BB->getNextNode();
3875 
3876   CanonicalLoopInfo *CL = createLoopSkeleton(Loc.DL, TripCount, BB->getParent(),
3877                                              NextBB, NextBB, Name);
3878   BasicBlock *After = CL->getAfter();
3879 
3880   // If location is not set, don't connect the loop.
3881   if (updateToLocation(Loc)) {
3882     // Split the loop at the insertion point: Branch to the preheader and move
3883     // every following instruction to after the loop (the After BB). Also, the
3884     // new successor is the loop's after block.
3885     spliceBB(Builder, After, /*CreateBranch=*/false);
3886     Builder.CreateBr(CL->getPreheader());
3887   }
3888 
3889   // Emit the body content. We do it after connecting the loop to the CFG so
3890   // that the callback does not encounter degenerate BBs.
3891   BodyGenCB(CL->getBodyIP(), CL->getIndVar());
3892 
3893 #ifndef NDEBUG
3894   CL->assertOK();
3895 #endif
3896   return CL;
3897 }
3898 
3899 CanonicalLoopInfo *OpenMPIRBuilder::createCanonicalLoop(
3900     const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
3901     Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
3902     InsertPointTy ComputeIP, const Twine &Name) {
3903 
3904   // Consider the following difficulties (assuming 8-bit signed integers):
3905   //  * Adding \p Step to the loop counter which passes \p Stop may overflow:
3906   //      DO I = 1, 100, 50
3907   //  * A \p Step of INT_MIN cannot be normalized to a positive direction:
3908   //      DO I = 100, 0, -128
3909 
3910   // Start, Stop and Step must be of the same integer type.
3911   auto *IndVarTy = cast<IntegerType>(Start->getType());
3912   assert(IndVarTy == Stop->getType() && "Stop type mismatch");
3913   assert(IndVarTy == Step->getType() && "Step type mismatch");
3914 
3915   LocationDescription ComputeLoc =
3916       ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
3917   updateToLocation(ComputeLoc);
3918 
3919   ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
3920   ConstantInt *One = ConstantInt::get(IndVarTy, 1);
3921 
3922   // Like Step, but always positive.
3923   Value *Incr = Step;
3924 
3925   // Distance between Start and Stop; always positive.
3926   Value *Span;
3927 
3928   // Condition under which no iterations are executed at all, e.g. because
3929   // UB < LB.
3930   Value *ZeroCmp;
3931 
3932   if (IsSigned) {
3933     // Ensure that increment is positive. If not, negate and invert LB and UB.
3934     Value *IsNeg = Builder.CreateICmpSLT(Step, Zero);
3935     Incr = Builder.CreateSelect(IsNeg, Builder.CreateNeg(Step), Step);
3936     Value *LB = Builder.CreateSelect(IsNeg, Stop, Start);
3937     Value *UB = Builder.CreateSelect(IsNeg, Start, Stop);
3938     Span = Builder.CreateSub(UB, LB, "", false, true);
3939     ZeroCmp = Builder.CreateICmp(
3940         InclusiveStop ? CmpInst::ICMP_SLT : CmpInst::ICMP_SLE, UB, LB);
3941   } else {
3942     Span = Builder.CreateSub(Stop, Start, "", true);
3943     ZeroCmp = Builder.CreateICmp(
3944         InclusiveStop ? CmpInst::ICMP_ULT : CmpInst::ICMP_ULE, Stop, Start);
3945   }
3946 
3947   Value *CountIfLooping;
3948   if (InclusiveStop) {
3949     CountIfLooping = Builder.CreateAdd(Builder.CreateUDiv(Span, Incr), One);
3950   } else {
3951     // Avoid incrementing past stop since it could overflow.
3952     Value *CountIfTwo = Builder.CreateAdd(
3953         Builder.CreateUDiv(Builder.CreateSub(Span, One), Incr), One);
3954     Value *OneCmp = Builder.CreateICmp(CmpInst::ICMP_ULE, Span, Incr);
3955     CountIfLooping = Builder.CreateSelect(OneCmp, One, CountIfTwo);
3956   }
3957   Value *TripCount = Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
3958                                           "omp_" + Name + ".tripcount");
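  // A small worked example (unsigned, exclusive stop): Start=0, Stop=10,
  // Step=3 gives Span=10, CountIfTwo=(10-1)/3+1=4 and TripCount=4, i.e. the
  // body runs for I = 0, 3, 6, 9.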
3959 
3960   auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
3961     Builder.restoreIP(CodeGenIP);
3962     Value *Span = Builder.CreateMul(IV, Step);
3963     Value *IndVar = Builder.CreateAdd(Span, Start);
3964     BodyGenCB(Builder.saveIP(), IndVar);
3965   };
3966   LocationDescription LoopLoc = ComputeIP.isSet() ? Loc.IP : Builder.saveIP();
3967   return createCanonicalLoop(LoopLoc, BodyGen, TripCount, Name);
3968 }
3969 
3970 // Returns an LLVM function to call for initializing loop bounds using OpenMP
3971 // static scheduling depending on `type`. Only i32 and i64 are supported by the
3972 // runtime. Always interpret integers as unsigned similarly to
3973 // CanonicalLoopInfo.
3974 static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
3975                                                   OpenMPIRBuilder &OMPBuilder) {
3976   unsigned Bitwidth = Ty->getIntegerBitWidth();
3977   if (Bitwidth == 32)
3978     return OMPBuilder.getOrCreateRuntimeFunction(
3979         M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_4u);
3980   if (Bitwidth == 64)
3981     return OMPBuilder.getOrCreateRuntimeFunction(
3982         M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_8u);
3983   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
3984 }
3985 
3986 OpenMPIRBuilder::InsertPointTy
3987 OpenMPIRBuilder::applyStaticWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
3988                                           InsertPointTy AllocaIP,
3989                                           bool NeedsBarrier) {
3990   assert(CLI->isValid() && "Requires a valid canonical loop");
3991   assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
3992          "Require dedicated allocate IP");
3993 
3994   // Set up the source location value for OpenMP runtime.
3995   Builder.restoreIP(CLI->getPreheaderIP());
3996   Builder.SetCurrentDebugLocation(DL);
3997 
3998   uint32_t SrcLocStrSize;
3999   Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
4000   Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4001 
4002   // Declare useful OpenMP runtime functions.
4003   Value *IV = CLI->getIndVar();
4004   Type *IVTy = IV->getType();
4005   FunctionCallee StaticInit = getKmpcForStaticInitForType(IVTy, M, *this);
4006   FunctionCallee StaticFini =
4007       getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);
4008 
4009   // Allocate space for computed loop bounds as expected by the "init" function.
4010   Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());
4011 
4012   Type *I32Type = Type::getInt32Ty(M.getContext());
4013   Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
4014   Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
4015   Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
4016   Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
4017 
4018   // At the end of the preheader, prepare for calling the "init" function by
4019   // storing the current loop bounds into the allocated space. A canonical loop
4020   // always iterates from 0 to trip-count with step 1. Note that "init" expects
4021   // and produces an inclusive upper bound.
4022   Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
4023   Constant *Zero = ConstantInt::get(IVTy, 0);
4024   Constant *One = ConstantInt::get(IVTy, 1);
4025   Builder.CreateStore(Zero, PLowerBound);
4026   Value *UpperBound = Builder.CreateSub(CLI->getTripCount(), One);
4027   Builder.CreateStore(UpperBound, PUpperBound);
4028   Builder.CreateStore(One, PStride);
4029 
4030   Value *ThreadNum = getOrCreateThreadID(SrcLoc);
4031 
4032   Constant *SchedulingType = ConstantInt::get(
4033       I32Type, static_cast<int>(OMPScheduleType::UnorderedStatic));
4034 
4035   // Call the "init" function and update the trip count of the loop with the
4036   // value it produced.
4037   Builder.CreateCall(StaticInit,
4038                      {SrcLoc, ThreadNum, SchedulingType, PLastIter, PLowerBound,
4039                       PUpperBound, PStride, One, Zero});
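  // For a 32-bit IV this is conceptually (argument names are illustrative):
  //   __kmpc_for_static_init_4u(&loc, tid, /*schedtype=*/unordered_static,
  //                             &p.lastiter, &p.lowerbound, &p.upperbound,
  //                             &p.stride, /*incr=*/1, /*chunk=*/0);
  // after which [p.lowerbound, p.upperbound] is this thread's inclusive
  // iteration range.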
4040   Value *LowerBound = Builder.CreateLoad(IVTy, PLowerBound);
4041   Value *InclusiveUpperBound = Builder.CreateLoad(IVTy, PUpperBound);
4042   Value *TripCountMinusOne = Builder.CreateSub(InclusiveUpperBound, LowerBound);
4043   Value *TripCount = Builder.CreateAdd(TripCountMinusOne, One);
4044   CLI->setTripCount(TripCount);
4045 
4046   // Update all uses of the induction variable except the one in the condition
4047   // block that compares it with the actual upper bound, and the increment in
4048   // the latch block.
4049 
4050   CLI->mapIndVar([&](Instruction *OldIV) -> Value * {
4051     Builder.SetInsertPoint(CLI->getBody(),
4052                            CLI->getBody()->getFirstInsertionPt());
4053     Builder.SetCurrentDebugLocation(DL);
4054     return Builder.CreateAdd(OldIV, LowerBound);
4055   });
4056 
4057   // In the "exit" block, call the "fini" function.
4058   Builder.SetInsertPoint(CLI->getExit(),
4059                          CLI->getExit()->getTerminator()->getIterator());
4060   Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});
4061 
4062   // Add the barrier if requested.
4063   if (NeedsBarrier)
4064     createBarrier(LocationDescription(Builder.saveIP(), DL),
4065                   omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
4066                   /* CheckCancelFlag */ false);
4067 
4068   InsertPointTy AfterIP = CLI->getAfterIP();
4069   CLI->invalidate();
4070 
4071   return AfterIP;
4072 }
4073 
4074 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
4075     DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
4076     bool NeedsBarrier, Value *ChunkSize) {
4077   assert(CLI->isValid() && "Requires a valid canonical loop");
4078   assert(ChunkSize && "Chunk size is required");
4079 
4080   LLVMContext &Ctx = CLI->getFunction()->getContext();
4081   Value *IV = CLI->getIndVar();
4082   Value *OrigTripCount = CLI->getTripCount();
4083   Type *IVTy = IV->getType();
4084   assert(IVTy->getIntegerBitWidth() <= 64 &&
4085          "Max supported tripcount bitwidth is 64 bits");
4086   Type *InternalIVTy = IVTy->getIntegerBitWidth() <= 32 ? Type::getInt32Ty(Ctx)
4087                                                         : Type::getInt64Ty(Ctx);
4088   Type *I32Type = Type::getInt32Ty(M.getContext());
4089   Constant *Zero = ConstantInt::get(InternalIVTy, 0);
4090   Constant *One = ConstantInt::get(InternalIVTy, 1);
4091 
4092   // Declare useful OpenMP runtime functions.
4093   FunctionCallee StaticInit =
4094       getKmpcForStaticInitForType(InternalIVTy, M, *this);
4095   FunctionCallee StaticFini =
4096       getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);
4097 
4098   // Allocate space for computed loop bounds as expected by the "init" function.
4099   Builder.restoreIP(AllocaIP);
4100   Builder.SetCurrentDebugLocation(DL);
4101   Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
4102   Value *PLowerBound =
4103       Builder.CreateAlloca(InternalIVTy, nullptr, "p.lowerbound");
4104   Value *PUpperBound =
4105       Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
4106   Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");
4107 
4108   // Set up the source location value for the OpenMP runtime.
4109   Builder.restoreIP(CLI->getPreheaderIP());
4110   Builder.SetCurrentDebugLocation(DL);
4111 
4112   // TODO: Detect overflow in ubsan or max-out with current tripcount.
4113   Value *CastedChunkSize =
4114       Builder.CreateZExtOrTrunc(ChunkSize, InternalIVTy, "chunksize");
4115   Value *CastedTripCount =
4116       Builder.CreateZExt(OrigTripCount, InternalIVTy, "tripcount");
4117 
4118   Constant *SchedulingType = ConstantInt::get(
4119       I32Type, static_cast<int>(OMPScheduleType::UnorderedStaticChunked));
4120   Builder.CreateStore(Zero, PLowerBound);
4121   Value *OrigUpperBound = Builder.CreateSub(CastedTripCount, One);
4122   Builder.CreateStore(OrigUpperBound, PUpperBound);
4123   Builder.CreateStore(One, PStride);
4124 
4125   // Call the "init" function and update the trip count of the loop with the
4126   // value it produced.
4127   uint32_t SrcLocStrSize;
4128   Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
4129   Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4130   Value *ThreadNum = getOrCreateThreadID(SrcLoc);
4131   Builder.CreateCall(StaticInit,
4132                      {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
4133                       /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
4134                       /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
4135                       /*pstride=*/PStride, /*incr=*/One,
4136                       /*chunk=*/CastedChunkSize});
4137 
4138   // Load values written by the "init" function.
4139   Value *FirstChunkStart =
4140       Builder.CreateLoad(InternalIVTy, PLowerBound, "omp_firstchunk.lb");
4141   Value *FirstChunkStop =
4142       Builder.CreateLoad(InternalIVTy, PUpperBound, "omp_firstchunk.ub");
4143   Value *FirstChunkEnd = Builder.CreateAdd(FirstChunkStop, One);
4144   Value *ChunkRange =
4145       Builder.CreateSub(FirstChunkEnd, FirstChunkStart, "omp_chunk.range");
4146   Value *NextChunkStride =
4147       Builder.CreateLoad(InternalIVTy, PStride, "omp_dispatch.stride");
4148 
4149   // Create outer "dispatch" loop for enumerating the chunks.
4150   BasicBlock *DispatchEnter = splitBB(Builder, true);
4151   Value *DispatchCounter;
4152   CanonicalLoopInfo *DispatchCLI = createCanonicalLoop(
4153       {Builder.saveIP(), DL},
4154       [&](InsertPointTy BodyIP, Value *Counter) { DispatchCounter = Counter; },
4155       FirstChunkStart, CastedTripCount, NextChunkStride,
4156       /*IsSigned=*/false, /*InclusiveStop=*/false, /*ComputeIP=*/{},
4157       "dispatch");
4158 
4159   // Remember the BasicBlocks of the dispatch loop we need, then invalidate it
4160   // so we do not have to preserve the canonical invariant.
4161   BasicBlock *DispatchBody = DispatchCLI->getBody();
4162   BasicBlock *DispatchLatch = DispatchCLI->getLatch();
4163   BasicBlock *DispatchExit = DispatchCLI->getExit();
4164   BasicBlock *DispatchAfter = DispatchCLI->getAfter();
4165   DispatchCLI->invalidate();
4166 
4167   // Rewire the original loop to become the chunk loop inside the dispatch loop.
4168   redirectTo(DispatchAfter, CLI->getAfter(), DL);
4169   redirectTo(CLI->getExit(), DispatchLatch, DL);
4170   redirectTo(DispatchBody, DispatchEnter, DL);
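  // After the rewiring, the loop nest is conceptually (with the chunk trip
  // count computed below):
  //   for (d = firstchunk.lb; d < tripcount; d += dispatch.stride) // dispatch
  //     for (iv = 0; iv < chunk.tripcount; ++iv)                   // chunk
  //       <body>(d + iv);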
4171 
4172   // Prepare the prolog of the chunk loop.
4173   Builder.restoreIP(CLI->getPreheaderIP());
4174   Builder.SetCurrentDebugLocation(DL);
4175 
4176   // Compute the number of iterations of the chunk loop.
4177   Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
4178   Value *ChunkEnd = Builder.CreateAdd(DispatchCounter, ChunkRange);
4179   Value *IsLastChunk =
4180       Builder.CreateICmpUGE(ChunkEnd, CastedTripCount, "omp_chunk.is_last");
4181   Value *CountUntilOrigTripCount =
4182       Builder.CreateSub(CastedTripCount, DispatchCounter);
4183   Value *ChunkTripCount = Builder.CreateSelect(
4184       IsLastChunk, CountUntilOrigTripCount, ChunkRange, "omp_chunk.tripcount");
4185   Value *BackcastedChunkTC =
4186       Builder.CreateTrunc(ChunkTripCount, IVTy, "omp_chunk.tripcount.trunc");
4187   CLI->setTripCount(BackcastedChunkTC);
4188 
4189   // Update all uses of the induction variable except the one in the condition
4190   // block that compares it with the actual upper bound, and the increment in
4191   // the latch block.
4192   Value *BackcastedDispatchCounter =
4193       Builder.CreateTrunc(DispatchCounter, IVTy, "omp_dispatch.iv.trunc");
4194   CLI->mapIndVar([&](Instruction *) -> Value * {
4195     Builder.restoreIP(CLI->getBodyIP());
4196     return Builder.CreateAdd(IV, BackcastedDispatchCounter);
4197   });
4198 
4199   // In the "exit" block, call the "fini" function.
4200   Builder.SetInsertPoint(DispatchExit, DispatchExit->getFirstInsertionPt());
4201   Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});
4202 
4203   // Add the barrier if requested.
4204   if (NeedsBarrier)
4205     createBarrier(LocationDescription(Builder.saveIP(), DL), OMPD_for,
4206                   /*ForceSimpleCall=*/false, /*CheckCancelFlag=*/false);
4207 
4208 #ifndef NDEBUG
4209   // Even though we currently do not support applying additional methods to it,
4210   // the chunk loop should remain a canonical loop.
4211   CLI->assertOK();
4212 #endif
4213 
4214   return {DispatchAfter, DispatchAfter->getFirstInsertionPt()};
4215 }
4216 
4217 // Returns an LLVM function to call for executing an OpenMP static worksharing
4218 // for loop depending on `type`. Only i32 and i64 are supported by the runtime.
4219 // Always interpret integers as unsigned similarly to CanonicalLoopInfo.
4220 static FunctionCallee
4221 getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
4222                             WorksharingLoopType LoopType) {
4223   unsigned Bitwidth = Ty->getIntegerBitWidth();
4224   Module &M = OMPBuilder->M;
4225   switch (LoopType) {
4226   case WorksharingLoopType::ForStaticLoop:
4227     if (Bitwidth == 32)
4228       return OMPBuilder->getOrCreateRuntimeFunction(
4229           M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
4230     if (Bitwidth == 64)
4231       return OMPBuilder->getOrCreateRuntimeFunction(
4232           M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
4233     break;
4234   case WorksharingLoopType::DistributeStaticLoop:
4235     if (Bitwidth == 32)
4236       return OMPBuilder->getOrCreateRuntimeFunction(
4237           M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
4238     if (Bitwidth == 64)
4239       return OMPBuilder->getOrCreateRuntimeFunction(
4240           M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
4241     break;
4242   case WorksharingLoopType::DistributeForStaticLoop:
4243     if (Bitwidth == 32)
4244       return OMPBuilder->getOrCreateRuntimeFunction(
4245           M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
4246     if (Bitwidth == 64)
4247       return OMPBuilder->getOrCreateRuntimeFunction(
4248           M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
4249     break;
4250   }
4251   if (Bitwidth != 32 && Bitwidth != 64) {
4252     llvm_unreachable("Unknown OpenMP loop iterator bitwidth");
4253   }
4254   llvm_unreachable("Unknown type of OpenMP worksharing loop");
4255 }
4256 
4257 // Inserts a call to proper OpenMP Device RTL function which handles
4258 // loop worksharing.
4259 static void createTargetLoopWorkshareCall(
4260     OpenMPIRBuilder *OMPBuilder, WorksharingLoopType LoopType,
4261     BasicBlock *InsertBlock, Value *Ident, Value *LoopBodyArg,
4262     Type *ParallelTaskPtr, Value *TripCount, Function &LoopBodyFn) {
4263   Type *TripCountTy = TripCount->getType();
4264   Module &M = OMPBuilder->M;
4265   IRBuilder<> &Builder = OMPBuilder->Builder;
4266   FunctionCallee RTLFn =
4267       getKmpcForStaticLoopForType(TripCountTy, OMPBuilder, LoopType);
4268   SmallVector<Value *, 8> RealArgs;
4269   RealArgs.push_back(Ident);
4270   RealArgs.push_back(Builder.CreateBitCast(&LoopBodyFn, ParallelTaskPtr));
4271   RealArgs.push_back(LoopBodyArg);
4272   RealArgs.push_back(TripCount);
4273   if (LoopType == WorksharingLoopType::DistributeStaticLoop) {
4274     RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
4275     Builder.CreateCall(RTLFn, RealArgs);
4276     return;
4277   }
4278   FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
4279       M, omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
4280   Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
4281   Value *NumThreads = Builder.CreateCall(RTLNumThreads, {});
4282 
4283   RealArgs.push_back(
4284       Builder.CreateZExtOrTrunc(NumThreads, TripCountTy, "num.threads.cast"));
4285   RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
4286   if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
4287     RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
4288   }
4289 
4290   Builder.CreateCall(RTLFn, RealArgs);
4291 }
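// For a plain worksharing loop on a 32-bit trip count, the emitted code is
// conceptually (value names are illustrative):
//   %nt = call i32 @omp_get_num_threads()
//   call void @__kmpc_for_static_loop_4u(ptr %ident, ptr %body.fn,
//       ptr %body.arg, i32 %tripcount, i32 %nt, i32 0)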
4292 
4293 static void
4294 workshareLoopTargetCallback(OpenMPIRBuilder *OMPIRBuilder,
4295                             CanonicalLoopInfo *CLI, Value *Ident,
4296                             Function &OutlinedFn, Type *ParallelTaskPtr,
4297                             const SmallVector<Instruction *, 4> &ToBeDeleted,
4298                             WorksharingLoopType LoopType) {
4299   IRBuilder<> &Builder = OMPIRBuilder->Builder;
4300   BasicBlock *Preheader = CLI->getPreheader();
4301   Value *TripCount = CLI->getTripCount();
4302 
4303   // After loop body outlining, the loop body contains only the setup of the
4304   // loop body argument structure and the call to the outlined loop body
4305   // function. First, we need to move the setup of the loop body arguments
4306   // into the loop preheader.
4307   Preheader->splice(std::prev(Preheader->end()), CLI->getBody(),
4308                     CLI->getBody()->begin(), std::prev(CLI->getBody()->end()));
4309 
4310   // The next step is to remove the whole loop: we do not need it anymore.
4311   // That's why we make an unconditional branch from the loop preheader to the
4312   // loop exit block.
4313   Builder.restoreIP({Preheader, Preheader->end()});
4314   Preheader->getTerminator()->eraseFromParent();
4315   Builder.CreateBr(CLI->getExit());
4316 
4317   // Delete dead loop blocks
4318   OpenMPIRBuilder::OutlineInfo CleanUpInfo;
4319   SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
4320   SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
4321   CleanUpInfo.EntryBB = CLI->getHeader();
4322   CleanUpInfo.ExitBB = CLI->getExit();
4323   CleanUpInfo.collectBlocks(RegionBlockSet, BlocksToBeRemoved);
4324   DeleteDeadBlocks(BlocksToBeRemoved);
4325 
4326   // Find the instruction which corresponds to the loop body argument
4327   // structure and remove the call instruction to the loop body function.
4328   Value *LoopBodyArg;
4329   User *OutlinedFnUser = OutlinedFn.getUniqueUndroppableUser();
4330   assert(OutlinedFnUser &&
4331          "Expected unique undroppable user of outlined function");
4332   CallInst *OutlinedFnCallInstruction = dyn_cast<CallInst>(OutlinedFnUser);
4333   assert(OutlinedFnCallInstruction && "Expected outlined function call");
4334   assert((OutlinedFnCallInstruction->getParent() == Preheader) &&
4335          "Expected outlined function call to be located in loop preheader");
4336   // Check in case no argument structure has been passed.
4337   if (OutlinedFnCallInstruction->arg_size() > 1)
4338     LoopBodyArg = OutlinedFnCallInstruction->getArgOperand(1);
4339   else
4340     LoopBodyArg = Constant::getNullValue(Builder.getPtrTy());
4341   OutlinedFnCallInstruction->eraseFromParent();
4342 
4343   createTargetLoopWorkshareCall(OMPIRBuilder, LoopType, Preheader, Ident,
4344                                 LoopBodyArg, ParallelTaskPtr, TripCount,
4345                                 OutlinedFn);
4346 
4347   for (auto &ToBeDeletedItem : ToBeDeleted)
4348     ToBeDeletedItem->eraseFromParent();
4349   CLI->invalidate();
4350 }
4351 
4352 OpenMPIRBuilder::InsertPointTy
4353 OpenMPIRBuilder::applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI,
4354                                           InsertPointTy AllocaIP,
4355                                           WorksharingLoopType LoopType) {
4356   uint32_t SrcLocStrSize;
4357   Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
4358   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4359 
4360   OutlineInfo OI;
4361   OI.OuterAllocaBB = CLI->getPreheader();
4362   Function *OuterFn = CLI->getPreheader()->getParent();
4363 
4364   // Instructions which need to be deleted at the end of code generation
4365   SmallVector<Instruction *, 4> ToBeDeleted;
4366 
4367   OI.OuterAllocaBB = AllocaIP.getBlock();
4368 
4369   // Mark the body loop as region which needs to be extracted
4370   OI.EntryBB = CLI->getBody();
4371   OI.ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(),
4372                                                "omp.prelatch", true);
4373 
4374   // Prepare loop body for extraction
4375   Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()});
4376 
4377   // Insert a new loop counter variable which will be used only in the loop
4378   // body.
4379   AllocaInst *NewLoopCnt = Builder.CreateAlloca(CLI->getIndVarType(), 0, "");
4380   Instruction *NewLoopCntLoad =
4381       Builder.CreateLoad(CLI->getIndVarType(), NewLoopCnt);
4382   // The new loop counter instructions are redundant in the loop preheader
4383   // once code generation for the workshare loop has finished. That's why we
4384   // mark them as ready for deletion.
4385   ToBeDeleted.push_back(NewLoopCntLoad);
4386   ToBeDeleted.push_back(NewLoopCnt);
4387 
4388   // Analyse the loop body region and find all input variables which are used
4389   // inside it.
4390   SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
4391   SmallVector<BasicBlock *, 32> Blocks;
4392   OI.collectBlocks(ParallelRegionBlockSet, Blocks);
4393   SmallVector<BasicBlock *, 32> BlocksT(ParallelRegionBlockSet.begin(),
4394                                         ParallelRegionBlockSet.end());
4395 
4396   CodeExtractorAnalysisCache CEAC(*OuterFn);
4397   CodeExtractor Extractor(Blocks,
4398                           /* DominatorTree */ nullptr,
4399                           /* AggregateArgs */ true,
4400                           /* BlockFrequencyInfo */ nullptr,
4401                           /* BranchProbabilityInfo */ nullptr,
4402                           /* AssumptionCache */ nullptr,
4403                           /* AllowVarArgs */ true,
4404                           /* AllowAlloca */ true,
4405                           /* AllocationBlock */ CLI->getPreheader(),
4406                           /* Suffix */ ".omp_wsloop",
4407                           /* AggrArgsIn0AddrSpace */ true);
4408 
4409   BasicBlock *CommonExit = nullptr;
4410   SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
4411 
4412   // Find allocas outside the loop body region which are used inside loop
4413   // body
4414   Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
4415 
4416   // We need to model the loop body region as the function f(cnt, loop_arg).
4417   // That's why we replace the loop induction variable with the new counter,
4418   // which will be one of the loop body function's arguments.
4419   SmallVector<User *> Users(CLI->getIndVar()->user_begin(),
4420                             CLI->getIndVar()->user_end());
4421   for (auto Use : Users) {
4422     if (Instruction *Inst = dyn_cast<Instruction>(Use)) {
4423       if (ParallelRegionBlockSet.count(Inst->getParent())) {
4424         Inst->replaceUsesOfWith(CLI->getIndVar(), NewLoopCntLoad);
4425       }
4426     }
4427   }
4428   // Make sure that the loop counter variable is not merged into the loop
4429   // body function argument structure and is passed as a separate variable.
4430   OI.ExcludeArgsFromAggregate.push_back(NewLoopCntLoad);
4431 
4432   // The PostOutline CB is invoked when the loop body function is outlined
4433   // and the loop body is replaced by a call to the outlined function. We
4434   // need to add a call to the OpenMP device RTL inside the loop preheader;
4435   // that function will handle the loop control logic.
4436   //
4437   OI.PostOutlineCB = [=, ToBeDeletedVec =
4438                              std::move(ToBeDeleted)](Function &OutlinedFn) {
4439     workshareLoopTargetCallback(this, CLI, Ident, OutlinedFn, ParallelTaskPtr,
4440                                 ToBeDeletedVec, LoopType);
4441   };
4442   addOutlineInfo(std::move(OI));
4443   return CLI->getAfterIP();
4444 }
4445 
4446 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoop(
4447     DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
4448     bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
4449     bool HasSimdModifier, bool HasMonotonicModifier,
4450     bool HasNonmonotonicModifier, bool HasOrderedClause,
4451     WorksharingLoopType LoopType) {
4452   if (Config.isTargetDevice())
4453     return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType);
4454   OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
4455       SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
4456       HasNonmonotonicModifier, HasOrderedClause);
4457 
4458   bool IsOrdered = (EffectiveScheduleType & OMPScheduleType::ModifierOrdered) ==
4459                    OMPScheduleType::ModifierOrdered;
4460   switch (EffectiveScheduleType & ~OMPScheduleType::ModifierMask) {
4461   case OMPScheduleType::BaseStatic:
4462     assert(!ChunkSize && "No chunk size with static-chunked schedule");
4463     if (IsOrdered)
4464       return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
4465                                        NeedsBarrier, ChunkSize);
4466     // FIXME: Monotonicity ignored?
4467     return applyStaticWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier);
4468 
4469   case OMPScheduleType::BaseStaticChunked:
4470     if (IsOrdered)
4471       return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
4472                                        NeedsBarrier, ChunkSize);
4473     // FIXME: Monotonicity ignored?
4474     return applyStaticChunkedWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier,
4475                                            ChunkSize);
4476 
4477   case OMPScheduleType::BaseRuntime:
4478   case OMPScheduleType::BaseAuto:
4479   case OMPScheduleType::BaseGreedy:
4480   case OMPScheduleType::BaseBalanced:
4481   case OMPScheduleType::BaseSteal:
4482   case OMPScheduleType::BaseGuidedSimd:
4483   case OMPScheduleType::BaseRuntimeSimd:
4484     assert(!ChunkSize &&
4485            "schedule type does not support user-defined chunk sizes");
4486     [[fallthrough]];
4487   case OMPScheduleType::BaseDynamicChunked:
4488   case OMPScheduleType::BaseGuidedChunked:
4489   case OMPScheduleType::BaseGuidedIterativeChunked:
4490   case OMPScheduleType::BaseGuidedAnalyticalChunked:
4491   case OMPScheduleType::BaseStaticBalancedChunked:
4492     return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
4493                                      NeedsBarrier, ChunkSize);
4494 
4495   default:
4496     llvm_unreachable("Unknown/unimplemented schedule kind");
4497   }
4498 }
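// A minimal sketch of how a frontend might drive this entry point on the host
// (the schedule enumerator and other argument values are illustrative
// assumptions, not prescribed by this file):
//   InsertPointTy AfterIP = OMPBuilder.applyWorkshareLoop(
//       DL, CLI, AllocaIP, /*NeedsBarrier=*/true,
//       /*SchedKind=*/omp::OMP_SCHEDULE_Default, /*ChunkSize=*/nullptr,
//       /*HasSimdModifier=*/false, /*HasMonotonicModifier=*/false,
//       /*HasNonmonotonicModifier=*/false, /*HasOrderedClause=*/false,
//       WorksharingLoopType::ForStaticLoop);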
4499 
4500 /// Returns an LLVM function to call for initializing loop bounds using OpenMP
4501 /// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
4502 /// the runtime. Always interpret integers as unsigned similarly to
4503 /// CanonicalLoopInfo.
4504 static FunctionCallee
4505 getKmpcForDynamicInitForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
4506   unsigned Bitwidth = Ty->getIntegerBitWidth();
4507   if (Bitwidth == 32)
4508     return OMPBuilder.getOrCreateRuntimeFunction(
4509         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_4u);
4510   if (Bitwidth == 64)
4511     return OMPBuilder.getOrCreateRuntimeFunction(
4512         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_8u);
4513   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
4514 }
4515 
4516 /// Returns an LLVM function to call for updating the next loop using OpenMP
4517 /// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
4518 /// the runtime. Always interpret integers as unsigned similarly to
4519 /// CanonicalLoopInfo.
4520 static FunctionCallee
4521 getKmpcForDynamicNextForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
4522   unsigned Bitwidth = Ty->getIntegerBitWidth();
4523   if (Bitwidth == 32)
4524     return OMPBuilder.getOrCreateRuntimeFunction(
4525         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_4u);
4526   if (Bitwidth == 64)
4527     return OMPBuilder.getOrCreateRuntimeFunction(
4528         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_8u);
4529   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
4530 }
4531 
4532 /// Returns an LLVM function to call for finalizing the dynamic loop using
4533 /// depending on `type`. Only i32 and i64 are supported by the runtime. Always
4534 /// interpret integers as unsigned similarly to CanonicalLoopInfo.
4535 static FunctionCallee
4536 getKmpcForDynamicFiniForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
4537   unsigned Bitwidth = Ty->getIntegerBitWidth();
4538   if (Bitwidth == 32)
4539     return OMPBuilder.getOrCreateRuntimeFunction(
4540         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_4u);
4541   if (Bitwidth == 64)
4542     return OMPBuilder.getOrCreateRuntimeFunction(
4543         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_8u);
4544   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
4545 }
4546 
4547 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyDynamicWorkshareLoop(
4548     DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
4549     OMPScheduleType SchedType, bool NeedsBarrier, Value *Chunk) {
4550   assert(CLI->isValid() && "Requires a valid canonical loop");
4551   assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
4552          "Require dedicated allocate IP");
4553   assert(isValidWorkshareLoopScheduleType(SchedType) &&
4554          "Require valid schedule type");
4555 
4556   bool Ordered = (SchedType & OMPScheduleType::ModifierOrdered) ==
4557                  OMPScheduleType::ModifierOrdered;
4558 
4559   // Set up the source location value for OpenMP runtime.
4560   Builder.SetCurrentDebugLocation(DL);
4561 
4562   uint32_t SrcLocStrSize;
4563   Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
4564   Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4565 
4566   // Declare useful OpenMP runtime functions.
4567   Value *IV = CLI->getIndVar();
4568   Type *IVTy = IV->getType();
4569   FunctionCallee DynamicInit = getKmpcForDynamicInitForType(IVTy, M, *this);
4570   FunctionCallee DynamicNext = getKmpcForDynamicNextForType(IVTy, M, *this);
4571 
4572   // Allocate space for computed loop bounds as expected by the "init" function.
4573   Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());
4574   Type *I32Type = Type::getInt32Ty(M.getContext());
4575   Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
4576   Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
4577   Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
4578   Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
4579 
4580   // At the end of the preheader, prepare for calling the "init" function by
4581   // storing the current loop bounds into the allocated space. A canonical loop
4582   // always iterates from 0 to trip-count with step 1. Note that "init" expects
4583   // and produces an inclusive upper bound.
4584   BasicBlock *PreHeader = CLI->getPreheader();
4585   Builder.SetInsertPoint(PreHeader->getTerminator());
4586   Constant *One = ConstantInt::get(IVTy, 1);
4587   Builder.CreateStore(One, PLowerBound);
4588   Value *UpperBound = CLI->getTripCount();
4589   Builder.CreateStore(UpperBound, PUpperBound);
4590   Builder.CreateStore(One, PStride);
4591 
4592   BasicBlock *Header = CLI->getHeader();
4593   BasicBlock *Exit = CLI->getExit();
4594   BasicBlock *Cond = CLI->getCond();
4595   BasicBlock *Latch = CLI->getLatch();
4596   InsertPointTy AfterIP = CLI->getAfterIP();
4597 
4598   // The CLI will be "broken" in the code below, as the loop is no longer
4599   // a valid canonical loop.
4600 
4601   if (!Chunk)
4602     Chunk = One;
4603 
4604   Value *ThreadNum = getOrCreateThreadID(SrcLoc);
4605 
4606   Constant *SchedulingType =
4607       ConstantInt::get(I32Type, static_cast<int>(SchedType));
4608 
4609   // Call the "init" function.
4610   Builder.CreateCall(DynamicInit,
4611                      {SrcLoc, ThreadNum, SchedulingType, /* LowerBound */ One,
4612                       UpperBound, /* step */ One, Chunk});
4613 
4614   // An outer loop around the existing one.
4615   BasicBlock *OuterCond = BasicBlock::Create(
4616       PreHeader->getContext(), Twine(PreHeader->getName()) + ".outer.cond",
4617       PreHeader->getParent());
4618   // The "next" call returns an i32, so we cannot reuse the IVTy constants.
4619   Builder.SetInsertPoint(OuterCond, OuterCond->getFirstInsertionPt());
4620   Value *Res =
4621       Builder.CreateCall(DynamicNext, {SrcLoc, ThreadNum, PLastIter,
4622                                        PLowerBound, PUpperBound, PStride});
4623   Constant *Zero32 = ConstantInt::get(I32Type, 0);
4624   Value *MoreWork = Builder.CreateCmp(CmpInst::ICMP_NE, Res, Zero32);
4625   Value *LowerBound =
4626       Builder.CreateSub(Builder.CreateLoad(IVTy, PLowerBound), One, "lb");
4627   Builder.CreateCondBr(MoreWork, Header, Exit);
4628 
4629   // Change PHI-node in loop header to use outer cond rather than preheader,
4630   // and set IV to the LowerBound.
4631   Instruction *Phi = &Header->front();
4632   auto *PI = cast<PHINode>(Phi);
4633   PI->setIncomingBlock(0, OuterCond);
4634   PI->setIncomingValue(0, LowerBound);
4635 
4636   // Then set the preheader to jump to OuterCond.
4637   Instruction *Term = PreHeader->getTerminator();
4638   auto *Br = cast<BranchInst>(Term);
4639   Br->setSuccessor(0, OuterCond);
4640 
4641   // Modify the inner condition:
4642   // * Use the UpperBound returned from the DynamicNext call.
4643   // * Jump to the outer loop when done with one chunk of the inner loop.
4644   Builder.SetInsertPoint(Cond, Cond->getFirstInsertionPt());
4645   UpperBound = Builder.CreateLoad(IVTy, PUpperBound, "ub");
4646   Instruction *Comp = &*Builder.GetInsertPoint();
4647   auto *CI = cast<CmpInst>(Comp);
4648   CI->setOperand(1, UpperBound);
4649   // Redirect the inner exit to branch to outer condition.
4650   Instruction *Branch = &Cond->back();
4651   auto *BI = cast<BranchInst>(Branch);
4652   assert(BI->getSuccessor(1) == Exit);
4653   BI->setSuccessor(1, OuterCond);
4654 
4655   // Call the "fini" function if "ordered" is present in the wsloop directive.
4656   if (Ordered) {
4657     Builder.SetInsertPoint(&Latch->back());
4658     FunctionCallee DynamicFini = getKmpcForDynamicFiniForType(IVTy, M, *this);
4659     Builder.CreateCall(DynamicFini, {SrcLoc, ThreadNum});
4660   }
4661 
4662   // Add the barrier if requested.
4663   if (NeedsBarrier) {
4664     Builder.SetInsertPoint(&Exit->back());
4665     createBarrier(LocationDescription(Builder.saveIP(), DL),
4666                   omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
4667                   /* CheckCancelFlag */ false);
4668   }
4669 
4670   CLI->invalidate();
4671   return AfterIP;
4672 }
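// A rough sketch of the control flow emitted above (block names abbreviated):
//
//   preheader --> outer.cond
//   outer.cond:  res = "next"(..., p.lowerbound, p.upperbound, p.stride)
//                br (res != 0), header, exit
//   cond:        compares iv against *p.upperbound; its exit edge now goes to
//                outer.cond instead of exit
//
// Each successful "next" call hands out one chunk of iterations; after the
// inner loop finishes a chunk, control returns to outer.cond for the next one.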
4673 
4674 /// Redirect all edges that branch to \p OldTarget to \p NewTarget. That is,
4675 /// after this \p OldTarget will be orphaned.
4676 static void redirectAllPredecessorsTo(BasicBlock *OldTarget,
4677                                       BasicBlock *NewTarget, DebugLoc DL) {
4678   for (BasicBlock *Pred : make_early_inc_range(predecessors(OldTarget)))
4679     redirectTo(Pred, NewTarget, DL);
4680 }
4681 
4682 /// Determine which blocks in \p BBs are reachable from outside and remove the
4683 /// ones that are not reachable from the function.
4684 static void removeUnusedBlocksFromParent(ArrayRef<BasicBlock *> BBs) {
4685   SmallPtrSet<BasicBlock *, 6> BBsToErase{BBs.begin(), BBs.end()};
4686   auto HasRemainingUses = [&BBsToErase](BasicBlock *BB) {
4687     for (Use &U : BB->uses()) {
4688       auto *UseInst = dyn_cast<Instruction>(U.getUser());
4689       if (!UseInst)
4690         continue;
4691       if (BBsToErase.count(UseInst->getParent()))
4692         continue;
4693       return true;
4694     }
4695     return false;
4696   };
4697 
4698   while (BBsToErase.remove_if(HasRemainingUses)) {
4699     // Try again if anything was removed.
4700   }
4701 
4702   SmallVector<BasicBlock *, 7> BBVec(BBsToErase.begin(), BBsToErase.end());
4703   DeleteDeadBlocks(BBVec);
4704 }
4705 
4706 CanonicalLoopInfo *
4707 OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
4708                                InsertPointTy ComputeIP) {
4709   assert(Loops.size() >= 1 && "At least one loop required");
4710   size_t NumLoops = Loops.size();
4711 
4712   // Nothing to do if there is already just one loop.
4713   if (NumLoops == 1)
4714     return Loops.front();
4715 
4716   CanonicalLoopInfo *Outermost = Loops.front();
4717   CanonicalLoopInfo *Innermost = Loops.back();
4718   BasicBlock *OrigPreheader = Outermost->getPreheader();
4719   BasicBlock *OrigAfter = Outermost->getAfter();
4720   Function *F = OrigPreheader->getParent();
4721 
4722   // Loop control blocks that may become orphaned later.
4723   SmallVector<BasicBlock *, 12> OldControlBBs;
4724   OldControlBBs.reserve(6 * Loops.size());
4725   for (CanonicalLoopInfo *Loop : Loops)
4726     Loop->collectControlBlocks(OldControlBBs);
4727 
4728   // Set up the IRBuilder for inserting the trip count computation.
4729   Builder.SetCurrentDebugLocation(DL);
4730   if (ComputeIP.isSet())
4731     Builder.restoreIP(ComputeIP);
4732   else
4733     Builder.restoreIP(Outermost->getPreheaderIP());
4734 
4735   // Derive the collapsed loop's trip count.
4736   // TODO: Find common/largest indvar type.
4737   Value *CollapsedTripCount = nullptr;
4738   for (CanonicalLoopInfo *L : Loops) {
4739     assert(L->isValid() &&
4740            "All loops to collapse must be valid canonical loops");
4741     Value *OrigTripCount = L->getTripCount();
4742     if (!CollapsedTripCount) {
4743       CollapsedTripCount = OrigTripCount;
4744       continue;
4745     }
4746 
4747     // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
4748     CollapsedTripCount = Builder.CreateMul(CollapsedTripCount, OrigTripCount,
4749                                            {}, /*HasNUW=*/true);
4750   }
4751 
4752   // Create the collapsed loop control flow.
4753   CanonicalLoopInfo *Result =
4754       createLoopSkeleton(DL, CollapsedTripCount, F,
4755                          OrigPreheader->getNextNode(), OrigAfter, "collapsed");
4756 
4757   // Build the collapsed loop body code.
4758   // Start with deriving the input loop induction variables from the collapsed
4759   // one, using a divmod scheme. To preserve the original loops' order, the
4760   // innermost loop uses the least significant bits.
4761   Builder.restoreIP(Result->getBodyIP());
4762 
4763   Value *Leftover = Result->getIndVar();
4764   SmallVector<Value *> NewIndVars;
4765   NewIndVars.resize(NumLoops);
4766   for (int i = NumLoops - 1; i >= 1; --i) {
4767     Value *OrigTripCount = Loops[i]->getTripCount();
4768 
4769     Value *NewIndVar = Builder.CreateURem(Leftover, OrigTripCount);
4770     NewIndVars[i] = NewIndVar;
4771 
4772     Leftover = Builder.CreateUDiv(Leftover, OrigTripCount);
4773   }
4774   // Outermost loop gets all the remaining bits.
4775   NewIndVars[0] = Leftover;
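  // Worked example of the divmod scheme: collapsing loops with trip counts 3
  // (outer) and 4 (inner) yields a collapsed trip count of 12. For collapsed
  // indvar 7: inner = 7 urem 4 = 3, leftover = 7 udiv 4 = 1 = outer indvar.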
4776 
4777   // Construct the loop body control flow.
4778   // We progressively construct the branch structure following the direction of
4779   // the control flow: from the leading in-between code, through the loop nest
4780   // body and the trailing in-between code, to rejoining the collapsed loop's
4781   // latch. ContinueBlock and ContinuePred keep track of the source(s) of the
4782   // next edge. If ContinueBlock is set, continue with that block. If
4783   // ContinuePred is set, use its predecessors as sources.
4784   BasicBlock *ContinueBlock = Result->getBody();
4785   BasicBlock *ContinuePred = nullptr;
4786   auto ContinueWith = [&ContinueBlock, &ContinuePred, DL](BasicBlock *Dest,
4787                                                           BasicBlock *NextSrc) {
4788     if (ContinueBlock)
4789       redirectTo(ContinueBlock, Dest, DL);
4790     else
4791       redirectAllPredecessorsTo(ContinuePred, Dest, DL);
4792 
4793     ContinueBlock = nullptr;
4794     ContinuePred = NextSrc;
4795   };
4796 
4797   // The code before the nested loop of each level.
4798   // Because we are sinking it into the nest, it will be executed more often
4799   // than the original loop. More sophisticated schemes could keep track of what
4800   // the in-between code is and instantiate it only once per thread.
4801   for (size_t i = 0; i < NumLoops - 1; ++i)
4802     ContinueWith(Loops[i]->getBody(), Loops[i + 1]->getHeader());
4803 
4804   // Connect the loop nest body.
4805   ContinueWith(Innermost->getBody(), Innermost->getLatch());
4806 
4807   // The code after the nested loop at each level.
4808   for (size_t i = NumLoops - 1; i > 0; --i)
4809     ContinueWith(Loops[i]->getAfter(), Loops[i - 1]->getLatch());
4810 
4811   // Connect the finished loop to the collapsed loop latch.
4812   ContinueWith(Result->getLatch(), nullptr);
4813 
4814   // Replace the input loops with the new collapsed loop.
4815   redirectTo(Outermost->getPreheader(), Result->getPreheader(), DL);
4816   redirectTo(Result->getAfter(), Outermost->getAfter(), DL);
4817 
4818   // Replace the input loop indvars with the derived ones.
4819   for (size_t i = 0; i < NumLoops; ++i)
4820     Loops[i]->getIndVar()->replaceAllUsesWith(NewIndVars[i]);
4821 
4822   // Remove unused parts of the input loops.
4823   removeUnusedBlocksFromParent(OldControlBBs);
4824 
4825   for (CanonicalLoopInfo *L : Loops)
4826     L->invalidate();
4827 
4828 #ifndef NDEBUG
4829   Result->assertOK();
4830 #endif
4831   return Result;
4832 }
4833 
4834 std::vector<CanonicalLoopInfo *>
4835 OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
4836                            ArrayRef<Value *> TileSizes) {
4837   assert(TileSizes.size() == Loops.size() &&
4838          "Must pass as many tile sizes as there are loops");
4839   int NumLoops = Loops.size();
4840   assert(NumLoops >= 1 && "At least one loop to tile required");
4841 
4842   CanonicalLoopInfo *OutermostLoop = Loops.front();
4843   CanonicalLoopInfo *InnermostLoop = Loops.back();
4844   Function *F = OutermostLoop->getBody()->getParent();
4845   BasicBlock *InnerEnter = InnermostLoop->getBody();
4846   BasicBlock *InnerLatch = InnermostLoop->getLatch();
4847 
4848   // Loop control blocks that may become orphaned later.
4849   SmallVector<BasicBlock *, 12> OldControlBBs;
4850   OldControlBBs.reserve(6 * Loops.size());
4851   for (CanonicalLoopInfo *Loop : Loops)
4852     Loop->collectControlBlocks(OldControlBBs);
4853 
4854   // Collect the original trip counts and induction variables to be accessible by
4855   // index. Also, the structure of the original loops is not preserved during
4856   // the construction of the tiled loops, so do it before we scavenge the BBs of
4857   // any original CanonicalLoopInfo.
4858   SmallVector<Value *, 4> OrigTripCounts, OrigIndVars;
4859   for (CanonicalLoopInfo *L : Loops) {
4860     assert(L->isValid() && "All input loops must be valid canonical loops");
4861     OrigTripCounts.push_back(L->getTripCount());
4862     OrigIndVars.push_back(L->getIndVar());
4863   }
4864 
4865   // Collect the code between loop headers. These may contain SSA definitions
4866   // that are used in the loop nest body. To be usable within the innermost
4867   // body, these BasicBlocks will be sunk into the loop nest body. That is,
4868   // these instructions may be executed more often than before the tiling.
4869   // TODO: It would be sufficient to only sink them into body of the
4870   // corresponding tile loop.
4871   SmallVector<std::pair<BasicBlock *, BasicBlock *>, 4> InbetweenCode;
4872   for (int i = 0; i < NumLoops - 1; ++i) {
4873     CanonicalLoopInfo *Surrounding = Loops[i];
4874     CanonicalLoopInfo *Nested = Loops[i + 1];
4875 
4876     BasicBlock *EnterBB = Surrounding->getBody();
4877     BasicBlock *ExitBB = Nested->getHeader();
4878     InbetweenCode.emplace_back(EnterBB, ExitBB);
4879   }
4880 
4881   // Compute the trip counts of the floor loops.
4882   Builder.SetCurrentDebugLocation(DL);
4883   Builder.restoreIP(OutermostLoop->getPreheaderIP());
4884   SmallVector<Value *, 4> FloorCount, FloorRems;
4885   for (int i = 0; i < NumLoops; ++i) {
4886     Value *TileSize = TileSizes[i];
4887     Value *OrigTripCount = OrigTripCounts[i];
4888     Type *IVType = OrigTripCount->getType();
4889 
4890     Value *FloorTripCount = Builder.CreateUDiv(OrigTripCount, TileSize);
4891     Value *FloorTripRem = Builder.CreateURem(OrigTripCount, TileSize);
4892 
4893     // 0 if the tilesize divides the tripcount, 1 otherwise.
4894     // 1 means we need an additional iteration for a partial tile.
4895     //
4896     // Unfortunately we cannot just use the roundup-formula
4897     //   (tripcount + tilesize - 1)/tilesize
4898     // because the summation might overflow. We do not want to introduce
4899     // undefined behavior where the untiled loop nest had none.
4900     Value *FloorTripOverflow =
4901         Builder.CreateICmpNE(FloorTripRem, ConstantInt::get(IVType, 0));
4902 
4903     FloorTripOverflow = Builder.CreateZExt(FloorTripOverflow, IVType);
4904     FloorTripCount =
4905         Builder.CreateAdd(FloorTripCount, FloorTripOverflow,
4906                           "omp_floor" + Twine(i) + ".tripcount", true);
4907 
4908     // Remember some values for later use.
4909     FloorCount.push_back(FloorTripCount);
4910     FloorRems.push_back(FloorTripRem);
4911   }
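    // Worked example of the overflow-safe round-up: tripcount = 10 and
    // tilesize = 4 give udiv = 2 and urem = 2, so overflow = 1 and the floor
    // loop runs 2 + 1 = 3 times, the last time for the partial tile of 2.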
4912 
4913   // Generate the new loop nest, from the outermost to the innermost.
4914   std::vector<CanonicalLoopInfo *> Result;
4915   Result.reserve(NumLoops * 2);
4916 
4917   // The basic block of the surrounding loop from which the newly generated
4918   // loop nest is entered.
4919   BasicBlock *Enter = OutermostLoop->getPreheader();
4920 
4921   // The basic block of the surrounding loop where the inner code should
4922   // continue.
4923   BasicBlock *Continue = OutermostLoop->getAfter();
4924 
4925   // Where the next loop basic block should be inserted.
4926   BasicBlock *OutroInsertBefore = InnermostLoop->getExit();
4927 
4928   auto EmbeddNewLoop =
4929       [this, DL, F, InnerEnter, &Enter, &Continue, &OutroInsertBefore](
4930           Value *TripCount, const Twine &Name) -> CanonicalLoopInfo * {
4931     CanonicalLoopInfo *EmbeddedLoop = createLoopSkeleton(
4932         DL, TripCount, F, InnerEnter, OutroInsertBefore, Name);
4933     redirectTo(Enter, EmbeddedLoop->getPreheader(), DL);
4934     redirectTo(EmbeddedLoop->getAfter(), Continue, DL);
4935 
4936     // Set up the position where the next embedded loop connects to this loop.
4937     Enter = EmbeddedLoop->getBody();
4938     Continue = EmbeddedLoop->getLatch();
4939     OutroInsertBefore = EmbeddedLoop->getLatch();
4940     return EmbeddedLoop;
4941   };
4942 
4943   auto EmbeddNewLoops = [&Result, &EmbeddNewLoop](ArrayRef<Value *> TripCounts,
4944                                                   const Twine &NameBase) {
4945     for (auto P : enumerate(TripCounts)) {
4946       CanonicalLoopInfo *EmbeddedLoop =
4947           EmbeddNewLoop(P.value(), NameBase + Twine(P.index()));
4948       Result.push_back(EmbeddedLoop);
4949     }
4950   };
4951 
4952   EmbeddNewLoops(FloorCount, "floor");
4953 
4954   // Within the innermost floor loop, emit the code that computes the tile
4955   // sizes.
4956   Builder.SetInsertPoint(Enter->getTerminator());
4957   SmallVector<Value *, 4> TileCounts;
4958   for (int i = 0; i < NumLoops; ++i) {
4959     CanonicalLoopInfo *FloorLoop = Result[i];
4960     Value *TileSize = TileSizes[i];
4961 
4962     Value *FloorIsEpilogue =
4963         Builder.CreateICmpEQ(FloorLoop->getIndVar(), FloorCount[i]);
4964     Value *TileTripCount =
4965         Builder.CreateSelect(FloorIsEpilogue, FloorRems[i], TileSize);
4966 
4967     TileCounts.push_back(TileTripCount);
4968   }
4969 
4970   // Create the tile loops.
4971   EmbeddNewLoops(TileCounts, "tile");
4972 
4973   // Insert the inbetween code into the body.
4974   BasicBlock *BodyEnter = Enter;
4975   BasicBlock *BodyEntered = nullptr;
4976   for (std::pair<BasicBlock *, BasicBlock *> P : InbetweenCode) {
4977     BasicBlock *EnterBB = P.first;
4978     BasicBlock *ExitBB = P.second;
4979 
4980     if (BodyEnter)
4981       redirectTo(BodyEnter, EnterBB, DL);
4982     else
4983       redirectAllPredecessorsTo(BodyEntered, EnterBB, DL);
4984 
4985     BodyEnter = nullptr;
4986     BodyEntered = ExitBB;
4987   }
4988 
4989   // Append the original loop nest body into the generated loop nest body.
4990   if (BodyEnter)
4991     redirectTo(BodyEnter, InnerEnter, DL);
4992   else
4993     redirectAllPredecessorsTo(BodyEntered, InnerEnter, DL);
4994   redirectAllPredecessorsTo(InnerLatch, Continue, DL);
4995 
4996   // Replace the original induction variable with an induction variable computed
4997   // from the tile and floor induction variables.
4998   Builder.restoreIP(Result.back()->getBodyIP());
4999   for (int i = 0; i < NumLoops; ++i) {
5000     CanonicalLoopInfo *FloorLoop = Result[i];
5001     CanonicalLoopInfo *TileLoop = Result[NumLoops + i];
5002     Value *OrigIndVar = OrigIndVars[i];
5003     Value *Size = TileSizes[i];
5004 
5005     Value *Scale =
5006         Builder.CreateMul(Size, FloorLoop->getIndVar(), {}, /*HasNUW=*/true);
5007     Value *Shift =
5008         Builder.CreateAdd(Scale, TileLoop->getIndVar(), {}, /*HasNUW=*/true);
5009     OrigIndVar->replaceAllUsesWith(Shift);
5010   }
5011 
5012   // Remove unused parts of the original loops.
5013   removeUnusedBlocksFromParent(OldControlBBs);
5014 
5015   for (CanonicalLoopInfo *L : Loops)
5016     L->invalidate();
5017 
5018 #ifndef NDEBUG
5019   for (CanonicalLoopInfo *GenL : Result)
5020     GenL->assertOK();
5021 #endif
5022   return Result;
5023 }
5024 
5025 /// Attach metadata \p Properties to the basic block described by \p BB. If the
5026 /// basic block already has metadata, the basic block properties are appended.
5027 static void addBasicBlockMetadata(BasicBlock *BB,
5028                                   ArrayRef<Metadata *> Properties) {
5029   // Nothing to do if no property to attach.
5030   if (Properties.empty())
5031     return;
5032 
5033   LLVMContext &Ctx = BB->getContext();
5034   SmallVector<Metadata *> NewProperties;
5035   NewProperties.push_back(nullptr);
5036 
5037   // If the basic block already has metadata, prepend it to the new metadata.
5038   MDNode *Existing = BB->getTerminator()->getMetadata(LLVMContext::MD_loop);
5039   if (Existing)
5040     append_range(NewProperties, drop_begin(Existing->operands(), 1));
5041 
5042   append_range(NewProperties, Properties);
5043   MDNode *BasicBlockID = MDNode::getDistinct(Ctx, NewProperties);
5044   BasicBlockID->replaceOperandWith(0, BasicBlockID);
5045 
5046   BB->getTerminator()->setMetadata(LLVMContext::MD_loop, BasicBlockID);
5047 }
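// The placeholder operand replaced above makes the node self-referential,
// which is LLVM's convention for distinct loop metadata. The result looks
// roughly like:
//   !0 = distinct !{!0, !1, ...}  ; attached via !llvm.loop on the terminator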
5048 
5049 /// Attach loop metadata \p Properties to the loop described by \p Loop. If the
5050 /// loop already has metadata, the loop properties are appended.
5051 static void addLoopMetadata(CanonicalLoopInfo *Loop,
5052                             ArrayRef<Metadata *> Properties) {
5053   assert(Loop->isValid() && "Expecting a valid CanonicalLoopInfo");
5054 
5055   // Attach metadata to the loop's latch
5056   BasicBlock *Latch = Loop->getLatch();
5057   assert(Latch && "A valid CanonicalLoopInfo must have a unique latch");
5058   addBasicBlockMetadata(Latch, Properties);
5059 }
5060 
5061 /// Attach llvm.access.group metadata to the memref instructions of \p Block
5062 static void addSimdMetadata(BasicBlock *Block, MDNode *AccessGroup,
5063                             LoopInfo &LI) {
5064   for (Instruction &I : *Block) {
5065     if (I.mayReadOrWriteMemory()) {
5066       // TODO: This instruction may already have an access group from
5067       // other pragmas, e.g. #pragma clang loop vectorize. Append
5068       // so that the existing metadata is not overwritten.
5069       I.setMetadata(LLVMContext::MD_access_group, AccessGroup);
5070     }
5071   }
5072 }
5073 
5074 void OpenMPIRBuilder::unrollLoopFull(DebugLoc, CanonicalLoopInfo *Loop) {
5075   LLVMContext &Ctx = Builder.getContext();
5076   addLoopMetadata(
5077       Loop, {MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
5078              MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.full"))});
5079 }
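// The metadata attached above corresponds roughly to:
//   !0 = distinct !{!0, !{!"llvm.loop.unroll.enable"},
//                       !{!"llvm.loop.unroll.full"}}
// which asks the LoopUnrollPass to fully unroll the loop.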
5080 
5081 void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) {
5082   LLVMContext &Ctx = Builder.getContext();
5083   addLoopMetadata(
5084       Loop, {
5085                 MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
5086             });
5087 }
5088 
5089 void OpenMPIRBuilder::createIfVersion(CanonicalLoopInfo *CanonicalLoop,
5090                                       Value *IfCond, ValueToValueMapTy &VMap,
5091                                       const Twine &NamePrefix) {
5092   Function *F = CanonicalLoop->getFunction();
5093 
5094   // Decide where the if branch should be inserted.
5095   Instruction *SplitBefore;
5096   if (Instruction::classof(IfCond)) {
5097     SplitBefore = dyn_cast<Instruction>(IfCond);
5098   } else {
5099     SplitBefore = CanonicalLoop->getPreheader()->getTerminator();
5100   }
5101 
5102   // TODO: We should not rely on the pass manager. Currently we use it only to
5103   // obtain the llvm::Loop that corresponds to the given CanonicalLoopInfo
5104   // object. We should have a method that returns all blocks between
5105   // CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter().
5106   FunctionAnalysisManager FAM;
5107   FAM.registerPass([]() { return DominatorTreeAnalysis(); });
5108   FAM.registerPass([]() { return LoopAnalysis(); });
5109   FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
5110 
5111   // Get the loop which needs to be cloned
5112   LoopAnalysis LIA;
5113   LoopInfo &&LI = LIA.run(*F, FAM);
5114   Loop *L = LI.getLoopFor(CanonicalLoop->getHeader());
5115 
5116   // Create additional blocks for the if statement
5117   BasicBlock *Head = SplitBefore->getParent();
5118   Instruction *HeadOldTerm = Head->getTerminator();
5119   llvm::LLVMContext &C = Head->getContext();
5120   llvm::BasicBlock *ThenBlock = llvm::BasicBlock::Create(
5121       C, NamePrefix + ".if.then", Head->getParent(), Head->getNextNode());
5122   llvm::BasicBlock *ElseBlock = llvm::BasicBlock::Create(
5123       C, NamePrefix + ".if.else", Head->getParent(), CanonicalLoop->getExit());
5124 
5125   // Create if condition branch.
5126   Builder.SetInsertPoint(HeadOldTerm);
5127   Instruction *BrInstr =
5128       Builder.CreateCondBr(IfCond, ThenBlock, /*ifFalse*/ ElseBlock);
5129   InsertPointTy IP{BrInstr->getParent(), ++BrInstr->getIterator()};
5130   // The then-block contains the branch to the OMP loop that is to be vectorized.
5131   spliceBB(IP, ThenBlock, false);
5132   ThenBlock->replaceSuccessorsPhiUsesWith(Head, ThenBlock);
5133 
5134   Builder.SetInsertPoint(ElseBlock);
5135 
5136   // Clone loop for the else branch
5137   SmallVector<BasicBlock *, 8> NewBlocks;
5138 
5139   VMap[CanonicalLoop->getPreheader()] = ElseBlock;
5140   for (BasicBlock *Block : L->getBlocks()) {
5141     BasicBlock *NewBB = CloneBasicBlock(Block, VMap, "", F);
5142     NewBB->moveBefore(CanonicalLoop->getExit());
5143     VMap[Block] = NewBB;
5144     NewBlocks.push_back(NewBB);
5145   }
5146   remapInstructionsInBlocks(NewBlocks, VMap);
5147   Builder.CreateBr(NewBlocks.front());
5148 }
5149 
5150 unsigned
5151 OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
5152                                            const StringMap<bool> &Features) {
5153   if (TargetTriple.isX86()) {
5154     if (Features.lookup("avx512f"))
5155       return 512;
5156     else if (Features.lookup("avx"))
5157       return 256;
5158     return 128;
5159   }
5160   if (TargetTriple.isPPC())
5161     return 128;
5162   if (TargetTriple.isWasm())
5163     return 128;
5164   return 0;
5165 }
5166 
5167 void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
5168                                 MapVector<Value *, Value *> AlignedVars,
5169                                 Value *IfCond, OrderKind Order,
5170                                 ConstantInt *Simdlen, ConstantInt *Safelen) {
5171   LLVMContext &Ctx = Builder.getContext();
5172 
5173   Function *F = CanonicalLoop->getFunction();
5174 
5175   // TODO: We should not rely on the pass manager. Currently we use it only to
5176   // obtain the llvm::Loop that corresponds to the given CanonicalLoopInfo
5177   // object. We should have a method that returns all blocks between
5178   // CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter().
5179   FunctionAnalysisManager FAM;
5180   FAM.registerPass([]() { return DominatorTreeAnalysis(); });
5181   FAM.registerPass([]() { return LoopAnalysis(); });
5182   FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
5183 
5184   LoopAnalysis LIA;
5185   LoopInfo &&LI = LIA.run(*F, FAM);
5186 
5187   Loop *L = LI.getLoopFor(CanonicalLoop->getHeader());
5188   if (AlignedVars.size()) {
5189     InsertPointTy IP = Builder.saveIP();
5190     Builder.SetInsertPoint(CanonicalLoop->getPreheader()->getTerminator());
5191     for (auto &AlignedItem : AlignedVars) {
5192       Value *AlignedPtr = AlignedItem.first;
5193       Value *Alignment = AlignedItem.second;
5194       Builder.CreateAlignmentAssumption(F->getDataLayout(),
5195                                         AlignedPtr, Alignment);
5196     }
5197     Builder.restoreIP(IP);
5198   }
5199 
5200   if (IfCond) {
5201     ValueToValueMapTy VMap;
5202     createIfVersion(CanonicalLoop, IfCond, VMap, "simd");
5203     // Add metadata to the cloned loop which disables vectorization
5204     Value *MappedLatch = VMap.lookup(CanonicalLoop->getLatch());
5205     assert(MappedLatch &&
5206            "Cannot find value which corresponds to original loop latch");
5207     assert(isa<BasicBlock>(MappedLatch) &&
5208            "Cannot cast mapped latch block value to BasicBlock");
5209     BasicBlock *NewLatchBlock = dyn_cast<BasicBlock>(MappedLatch);
5210     ConstantAsMetadata *BoolConst =
5211         ConstantAsMetadata::get(ConstantInt::getFalse(Type::getInt1Ty(Ctx)));
5212     addBasicBlockMetadata(
5213         NewLatchBlock,
5214         {MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.enable"),
5215                            BoolConst})});
5216   }
5217 
5218   SmallSet<BasicBlock *, 8> Reachable;
5219 
5220   // Get the basic blocks from the loop in which memref instructions
5221   // can be found.
5222   // TODO: Generalize getting all blocks inside a CanonicalLoopInfo,
5223   // preferably without running any passes.
5224   for (BasicBlock *Block : L->getBlocks()) {
5225     if (Block == CanonicalLoop->getCond() ||
5226         Block == CanonicalLoop->getHeader())
5227       continue;
5228     Reachable.insert(Block);
5229   }
5230 
5231   SmallVector<Metadata *> LoopMDList;
5232 
5233   // In the presence of a finite 'safelen', it may be unsafe to mark all
5234   // the memory instructions parallel, because loop-carried
5235   // dependences of up to 'safelen' iterations are possible.
5236   // If the order(concurrent) clause is specified, the memory instructions
5237   // are marked parallel even if 'safelen' is finite.
5238   if ((Safelen == nullptr) || (Order == OrderKind::OMP_ORDER_concurrent)) {
5239     // Add access group metadata to memory-access instructions.
5240     MDNode *AccessGroup = MDNode::getDistinct(Ctx, {});
5241     for (BasicBlock *BB : Reachable)
5242       addSimdMetadata(BB, AccessGroup, LI);
5243     // TODO: If the loop already has parallel-access metadata, the two
5244     // lists have to be combined.
5245     LoopMDList.push_back(MDNode::get(
5246         Ctx, {MDString::get(Ctx, "llvm.loop.parallel_accesses"), AccessGroup}));
5247   }
5248 
5249   // Use the above access group metadata to create loop level
5250   // metadata, which should be distinct for each loop.
5251   ConstantAsMetadata *BoolConst =
5252       ConstantAsMetadata::get(ConstantInt::getTrue(Type::getInt1Ty(Ctx)));
5253   LoopMDList.push_back(MDNode::get(
5254       Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.enable"), BoolConst}));
5255 
5256   if (Simdlen || Safelen) {
5257     // If both simdlen and safelen clauses are specified, the value of the
5258     // simdlen parameter must be less than or equal to the value of the safelen
5259     // parameter. Therefore, use safelen only in the absence of simdlen.
5260     ConstantInt *VectorizeWidth = Simdlen == nullptr ? Safelen : Simdlen;
5261     LoopMDList.push_back(
5262         MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.width"),
5263                           ConstantAsMetadata::get(VectorizeWidth)}));
5264   }
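  // For example, simdlen(4) combined with safelen(8) emits
  //   !{!"llvm.loop.vectorize.width", i32 4}
  // because simdlen takes precedence when both clauses are present.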
5265 
5266   addLoopMetadata(CanonicalLoop, LoopMDList);
5267 }
5268 
5269 /// Create the TargetMachine object to query the backend for optimization
5270 /// preferences.
5271 ///
5272 /// Ideally, this would be passed from the front-end to the OpenMPBuilder, but
5273 /// e.g. Clang does not pass it to its CodeGen layer and creates it only when
5274 /// needed for the LLVM pass pipeline. We use some default options to avoid
5275 /// having to pass too many settings from the frontend that probably do not
5276 /// matter.
5277 ///
5278 /// Currently, TargetMachine is only used sometimes by the unrollLoopPartial
5279 /// method. If we are going to use TargetMachine for more purposes, especially
5280 /// those that are sensitive to TargetOptions, RelocModel and CodeModel, it
5281 /// might become worth requiring front-ends to pass on their TargetMachine,
5282 /// or at least cache it between methods. Note that while front-ends such as Clang
5283 /// have just a single main TargetMachine per translation unit, "target-cpu" and
5284 /// "target-features" that determine the TargetMachine are per-function and can
5285 /// be overridden using __attribute__((target("OPTIONS"))).
5286 static std::unique_ptr<TargetMachine>
5287 createTargetMachine(Function *F, CodeGenOptLevel OptLevel) {
5288   Module *M = F->getParent();
5289 
5290   StringRef CPU = F->getFnAttribute("target-cpu").getValueAsString();
5291   StringRef Features = F->getFnAttribute("target-features").getValueAsString();
5292   const std::string &Triple = M->getTargetTriple();
5293 
5294   std::string Error;
5295   const llvm::Target *TheTarget = TargetRegistry::lookupTarget(Triple, Error);
5296   if (!TheTarget)
5297     return {};
5298 
5299   llvm::TargetOptions Options;
5300   return std::unique_ptr<TargetMachine>(TheTarget->createTargetMachine(
5301       Triple, CPU, Features, Options, /*RelocModel=*/std::nullopt,
5302       /*CodeModel=*/std::nullopt, OptLevel));
5303 }
5304 
5305 /// Heuristically determine the best-performing unroll factor for \p CLI. This
5306 /// depends on the target processor. We are re-using the same heuristics as the
5307 /// LoopUnrollPass.
5308 static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
5309   Function *F = CLI->getFunction();
5310 
5311   // Assume the user requests the most aggressive unrolling, even if the rest of
5312   // the code is optimized using a lower setting.
5313   CodeGenOptLevel OptLevel = CodeGenOptLevel::Aggressive;
5314   std::unique_ptr<TargetMachine> TM = createTargetMachine(F, OptLevel);
5315 
5316   FunctionAnalysisManager FAM;
5317   FAM.registerPass([]() { return TargetLibraryAnalysis(); });
5318   FAM.registerPass([]() { return AssumptionAnalysis(); });
5319   FAM.registerPass([]() { return DominatorTreeAnalysis(); });
5320   FAM.registerPass([]() { return LoopAnalysis(); });
5321   FAM.registerPass([]() { return ScalarEvolutionAnalysis(); });
5322   FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
5323   TargetIRAnalysis TIRA;
5324   if (TM)
5325     TIRA = TargetIRAnalysis(
5326         [&](const Function &F) { return TM->getTargetTransformInfo(F); });
5327   FAM.registerPass([&]() { return TIRA; });
5328 
5329   TargetIRAnalysis::Result &&TTI = TIRA.run(*F, FAM);
5330   ScalarEvolutionAnalysis SEA;
5331   ScalarEvolution &&SE = SEA.run(*F, FAM);
5332   DominatorTreeAnalysis DTA;
5333   DominatorTree &&DT = DTA.run(*F, FAM);
5334   LoopAnalysis LIA;
5335   LoopInfo &&LI = LIA.run(*F, FAM);
5336   AssumptionAnalysis ACT;
5337   AssumptionCache &&AC = ACT.run(*F, FAM);
5338   OptimizationRemarkEmitter ORE{F};
5339 
5340   Loop *L = LI.getLoopFor(CLI->getHeader());
5341   assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop");
5342 
5343   TargetTransformInfo::UnrollingPreferences UP =
5344       gatherUnrollingPreferences(L, SE, TTI,
5345                                  /*BlockFrequencyInfo=*/nullptr,
5346                                  /*ProfileSummaryInfo=*/nullptr, ORE, static_cast<int>(OptLevel),
5347                                  /*UserThreshold=*/std::nullopt,
5348                                  /*UserCount=*/std::nullopt,
5349                                  /*UserAllowPartial=*/true,
5350                                  /*UserAllowRuntime=*/true,
5351                                  /*UserUpperBound=*/std::nullopt,
5352                                  /*UserFullUnrollMaxCount=*/std::nullopt);
5353 
5354   UP.Force = true;
5355 
5356   // Account for additional optimizations taking place before the LoopUnrollPass
5357   // would unroll the loop.
5358   UP.Threshold *= UnrollThresholdFactor;
5359   UP.PartialThreshold *= UnrollThresholdFactor;
5360 
5361   // Use normal unroll factors even if the rest of the code is optimized for
5362   // size.
5363   UP.OptSizeThreshold = UP.Threshold;
5364   UP.PartialOptSizeThreshold = UP.PartialThreshold;
5365 
5366   LLVM_DEBUG(dbgs() << "Unroll heuristic thresholds:\n"
5367                     << "  Threshold=" << UP.Threshold << "\n"
5368                     << "  PartialThreshold=" << UP.PartialThreshold << "\n"
5369                     << "  OptSizeThreshold=" << UP.OptSizeThreshold << "\n"
5370                     << "  PartialOptSizeThreshold="
5371                     << UP.PartialOptSizeThreshold << "\n");
5372 
5373   // Disable peeling.
5374   TargetTransformInfo::PeelingPreferences PP =
5375       gatherPeelingPreferences(L, SE, TTI,
5376                                /*UserAllowPeeling=*/false,
5377                                /*UserAllowProfileBasedPeeling=*/false,
5378                                /*UnrollingSpecficValues=*/false);
5379 
5380   SmallPtrSet<const Value *, 32> EphValues;
5381   CodeMetrics::collectEphemeralValues(L, &AC, EphValues);
5382 
5383   // Assume that reads and writes to stack variables can be eliminated by
5384   // Mem2Reg, SROA or LICM. That is, don't count them towards the loop body's
5385   // size.
5386   for (BasicBlock *BB : L->blocks()) {
5387     for (Instruction &I : *BB) {
5388       Value *Ptr;
5389       if (auto *Load = dyn_cast<LoadInst>(&I)) {
5390         Ptr = Load->getPointerOperand();
5391       } else if (auto *Store = dyn_cast<StoreInst>(&I)) {
5392         Ptr = Store->getPointerOperand();
5393       } else
5394         continue;
5395 
5396       Ptr = Ptr->stripPointerCasts();
5397 
5398       if (auto *Alloca = dyn_cast<AllocaInst>(Ptr)) {
5399         if (Alloca->getParent() == &F->getEntryBlock())
5400           EphValues.insert(&I);
5401       }
5402     }
5403   }
5404 
5405   UnrollCostEstimator UCE(L, TTI, EphValues, UP.BEInsns);
5406 
5407   // The loop is not unrollable if it contains certain instructions.
5408   if (!UCE.canUnroll()) {
5409     LLVM_DEBUG(dbgs() << "Loop not considered unrollable\n");
5410     return 1;
5411   }
5412 
5413   LLVM_DEBUG(dbgs() << "Estimated loop size is " << UCE.getRolledLoopSize()
5414                     << "\n");
5415 
5416   // TODO: Determine trip count of \p CLI if constant, computeUnrollCount might
5417   // be able to use it.
5418   int TripCount = 0;
5419   int MaxTripCount = 0;
5420   bool MaxOrZero = false;
5421   unsigned TripMultiple = 0;
5422 
5423   bool UseUpperBound = false;
5424   computeUnrollCount(L, TTI, DT, &LI, &AC, SE, EphValues, &ORE, TripCount,
5425                      MaxTripCount, MaxOrZero, TripMultiple, UCE, UP, PP,
5426                      UseUpperBound);
5427   unsigned Factor = UP.Count;
5428   LLVM_DEBUG(dbgs() << "Suggesting unroll factor of " << Factor << "\n");
5429 
5430   // This function returns 1 to signal that the loop should not be unrolled.
5431   if (Factor == 0)
5432     return 1;
5433   return Factor;
5434 }
5435 
5436 void OpenMPIRBuilder::unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop,
5437                                         int32_t Factor,
5438                                         CanonicalLoopInfo **UnrolledCLI) {
5439   assert(Factor >= 0 && "Unroll factor must not be negative");
5440 
5441   Function *F = Loop->getFunction();
5442   LLVMContext &Ctx = F->getContext();
5443 
5444   // If the unrolled loop is not used for another loop-associated directive, it
5445   // is sufficient to add metadata for the LoopUnrollPass.
5446   if (!UnrolledCLI) {
5447     SmallVector<Metadata *, 2> LoopMetadata;
5448     LoopMetadata.push_back(
5449         MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")));
5450 
5451     if (Factor >= 1) {
5452       ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
5453           ConstantInt::get(Type::getInt32Ty(Ctx), APInt(32, Factor)));
5454       LoopMetadata.push_back(MDNode::get(
5455           Ctx, {MDString::get(Ctx, "llvm.loop.unroll.count"), FactorConst}));
5456     }
5457 
5458     addLoopMetadata(Loop, LoopMetadata);
5459     return;
5460   }
5461 
5462   // Heuristically determine the unroll factor.
5463   if (Factor == 0)
5464     Factor = computeHeuristicUnrollFactor(Loop);
5465 
5466   // No change required with unroll factor 1.
5467   if (Factor == 1) {
5468     *UnrolledCLI = Loop;
5469     return;
5470   }
5471 
5472   assert(Factor >= 2 &&
5473          "unrolling only makes sense with a factor of 2 or larger");
5474 
5475   Type *IndVarTy = Loop->getIndVarType();
5476 
5477   // Apply partial unrolling by tiling the loop by the unroll-factor, then fully
5478   // unroll the inner loop.
5479   Value *FactorVal =
5480       ConstantInt::get(IndVarTy, APInt(IndVarTy->getIntegerBitWidth(), Factor,
5481                                        /*isSigned=*/false));
5482   std::vector<CanonicalLoopInfo *> LoopNest =
5483       tileLoops(DL, {Loop}, {FactorVal});
5484   assert(LoopNest.size() == 2 && "Expect 2 loops after tiling");
5485   *UnrolledCLI = LoopNest[0];
5486   CanonicalLoopInfo *InnerLoop = LoopNest[1];
5487 
5488   // LoopUnrollPass can only fully unroll loops with constant trip count.
5489   // Unroll by the unroll factor with a fallback epilog for the remainder
5490   // iterations if necessary.
5491   ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
5492       ConstantInt::get(Type::getInt32Ty(Ctx), APInt(32, Factor)));
5493   addLoopMetadata(
5494       InnerLoop,
5495       {MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
5496        MDNode::get(
5497            Ctx, {MDString::get(Ctx, "llvm.loop.unroll.count"), FactorConst})});
5498 
5499 #ifndef NDEBUG
5500   (*UnrolledCLI)->assertOK();
5501 #endif
5502 }
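// For example, a partial unroll by a factor of 4 tiles the loop with tile
// size 4, returns the outer (floor) loop through *UnrolledCLI, and marks the
// inner (tile) loop with llvm.loop.unroll.count = 4 so that the LoopUnrollPass
// unrolls it, with an epilog covering the remainder iterations if needed.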
5503 
5504 OpenMPIRBuilder::InsertPointTy
5505 OpenMPIRBuilder::createCopyPrivate(const LocationDescription &Loc,
5506                                    llvm::Value *BufSize, llvm::Value *CpyBuf,
5507                                    llvm::Value *CpyFn, llvm::Value *DidIt) {
5508   if (!updateToLocation(Loc))
5509     return Loc.IP;
5510 
5511   uint32_t SrcLocStrSize;
5512   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5513   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5514   Value *ThreadId = getOrCreateThreadID(Ident);
5515 
5516   llvm::Value *DidItLD = Builder.CreateLoad(Builder.getInt32Ty(), DidIt);
5517 
5518   Value *Args[] = {Ident, ThreadId, BufSize, CpyBuf, CpyFn, DidItLD};
5519 
5520   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_copyprivate);
5521   Builder.CreateCall(Fn, Args);
5522 
5523   return Builder.saveIP();
5524 }
5525 
5526 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSingle(
5527     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
5528     FinalizeCallbackTy FiniCB, bool IsNowait, ArrayRef<llvm::Value *> CPVars,
5529     ArrayRef<llvm::Function *> CPFuncs) {
5530 
5531   if (!updateToLocation(Loc))
5532     return Loc.IP;
5533 
5534   // If needed, allocate and initialize `DidIt` with 0.
5535   // DidIt is a flag variable: 1 = single thread; 0 = not the single thread.
5536   llvm::Value *DidIt = nullptr;
5537   if (!CPVars.empty()) {
5538     DidIt = Builder.CreateAlloca(llvm::Type::getInt32Ty(Builder.getContext()));
5539     Builder.CreateStore(Builder.getInt32(0), DidIt);
5540   }
5541 
5542   Directive OMPD = Directive::OMPD_single;
5543   uint32_t SrcLocStrSize;
5544   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5545   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5546   Value *ThreadId = getOrCreateThreadID(Ident);
5547   Value *Args[] = {Ident, ThreadId};
5548 
5549   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_single);
5550   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
5551 
5552   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_single);
5553   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
5554 
5555   auto FiniCBWrapper = [&](InsertPointTy IP) {
5556     FiniCB(IP);
5557 
5558     // The thread that executes the single region must set `DidIt` to 1.
5559     // This is used by __kmpc_copyprivate, to know if the caller is the
5560     // single thread or not.
5561     if (DidIt)
5562       Builder.CreateStore(Builder.getInt32(1), DidIt);
5563   };
5564 
5565   // Generates the following:
5566   // if (__kmpc_single()) {
5567   //   ... single region ...
5568   //   __kmpc_end_single
5569   // }
5570   // __kmpc_copyprivate
5571   // __kmpc_barrier
5572 
5573   EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCBWrapper,
5574                        /*Conditional*/ true,
5575                        /*hasFinalize*/ true);
5576 
5577   if (DidIt) {
5578     for (size_t I = 0, E = CPVars.size(); I < E; ++I)
5579       // NOTE BufSize is currently unused, so just pass 0.
5580       createCopyPrivate(LocationDescription(Builder.saveIP(), Loc.DL),
5581                         /*BufSize=*/ConstantInt::get(Int64, 0), CPVars[I],
5582                         CPFuncs[I], DidIt);
5583     // NOTE __kmpc_copyprivate already inserts a barrier
5584   } else if (!IsNowait)
5585     createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
5586                   omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
5587                   /* CheckCancelFlag */ false);
5588   return Builder.saveIP();
5589 }
5590 
5591 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCritical(
5592     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
5593     FinalizeCallbackTy FiniCB, StringRef CriticalName, Value *HintInst) {
5594 
5595   if (!updateToLocation(Loc))
5596     return Loc.IP;
5597 
5598   Directive OMPD = Directive::OMPD_critical;
5599   uint32_t SrcLocStrSize;
5600   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5601   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5602   Value *ThreadId = getOrCreateThreadID(Ident);
5603   Value *LockVar = getOMPCriticalRegionLock(CriticalName);
5604   Value *Args[] = {Ident, ThreadId, LockVar};
5605 
5606   SmallVector<llvm::Value *, 4> EnterArgs(std::begin(Args), std::end(Args));
5607   Function *RTFn = nullptr;
5608   if (HintInst) {
5609     // Add Hint to entry Args and create call
5610     EnterArgs.push_back(HintInst);
5611     RTFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_critical_with_hint);
5612   } else {
5613     RTFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_critical);
5614   }
5615   Instruction *EntryCall = Builder.CreateCall(RTFn, EnterArgs);
5616 
5617   Function *ExitRTLFn =
5618       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_critical);
5619   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
5620 
5621   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
5622                               /*Conditional*/ false, /*hasFinalize*/ true);
5623 }
5624 
5625 OpenMPIRBuilder::InsertPointTy
5626 OpenMPIRBuilder::createOrderedDepend(const LocationDescription &Loc,
5627                                      InsertPointTy AllocaIP, unsigned NumLoops,
5628                                      ArrayRef<llvm::Value *> StoreValues,
5629                                      const Twine &Name, bool IsDependSource) {
5630   assert(
5631       llvm::all_of(StoreValues,
5632                    [](Value *SV) { return SV->getType()->isIntegerTy(64); }) &&
5633       "OpenMP runtime requires depend vec with i64 type");
5634 
5635   if (!updateToLocation(Loc))
5636     return Loc.IP;
5637 
5638   // Allocate space for the depend vector and generate the alloca instruction.
5639   auto *ArrI64Ty = ArrayType::get(Int64, NumLoops);
5640   Builder.restoreIP(AllocaIP);
5641   AllocaInst *ArgsBase = Builder.CreateAlloca(ArrI64Ty, nullptr, Name);
5642   ArgsBase->setAlignment(Align(8));
5643   Builder.restoreIP(Loc.IP);
5644 
5645   // Store each index value at its offset in the depend vector.
5646   for (unsigned I = 0; I < NumLoops; ++I) {
5647     Value *DependAddrGEPIter = Builder.CreateInBoundsGEP(
5648         ArrI64Ty, ArgsBase, {Builder.getInt64(0), Builder.getInt64(I)});
5649     StoreInst *STInst = Builder.CreateStore(StoreValues[I], DependAddrGEPIter);
5650     STInst->setAlignment(Align(8));
5651   }
5652 
5653   Value *DependBaseAddrGEP = Builder.CreateInBoundsGEP(
5654       ArrI64Ty, ArgsBase, {Builder.getInt64(0), Builder.getInt64(0)});
5655 
5656   uint32_t SrcLocStrSize;
5657   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5658   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5659   Value *ThreadId = getOrCreateThreadID(Ident);
5660   Value *Args[] = {Ident, ThreadId, DependBaseAddrGEP};
5661 
5662   Function *RTLFn = nullptr;
5663   if (IsDependSource)
5664     RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_doacross_post);
5665   else
5666     RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_doacross_wait);
5667   Builder.CreateCall(RTLFn, Args);
5668 
5669   return Builder.saveIP();
5670 }
5671 
5672 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createOrderedThreadsSimd(
5673     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
5674     FinalizeCallbackTy FiniCB, bool IsThreads) {
5675   if (!updateToLocation(Loc))
5676     return Loc.IP;
5677 
5678   Directive OMPD = Directive::OMPD_ordered;
5679   Instruction *EntryCall = nullptr;
5680   Instruction *ExitCall = nullptr;
5681 
5682   if (IsThreads) {
5683     uint32_t SrcLocStrSize;
5684     Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5685     Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5686     Value *ThreadId = getOrCreateThreadID(Ident);
5687     Value *Args[] = {Ident, ThreadId};
5688 
5689     Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_ordered);
5690     EntryCall = Builder.CreateCall(EntryRTLFn, Args);
5691 
5692     Function *ExitRTLFn =
5693         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_ordered);
5694     ExitCall = Builder.CreateCall(ExitRTLFn, Args);
5695   }
5696 
5697   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
5698                               /*Conditional*/ false, /*hasFinalize*/ true);
5699 }
5700 
5701 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::EmitOMPInlinedRegion(
5702     Directive OMPD, Instruction *EntryCall, Instruction *ExitCall,
5703     BodyGenCallbackTy BodyGenCB, FinalizeCallbackTy FiniCB, bool Conditional,
5704     bool HasFinalize, bool IsCancellable) {
5705 
5706   if (HasFinalize)
5707     FinalizationStack.push_back({FiniCB, OMPD, IsCancellable});
5708 
5709   // Create the inlined region's entry and body blocks, in preparation
5710   // for conditional creation.
5711   BasicBlock *EntryBB = Builder.GetInsertBlock();
5712   Instruction *SplitPos = EntryBB->getTerminator();
5713   if (!isa_and_nonnull<BranchInst>(SplitPos))
5714     SplitPos = new UnreachableInst(Builder.getContext(), EntryBB);
5715   BasicBlock *ExitBB = EntryBB->splitBasicBlock(SplitPos, "omp_region.end");
5716   BasicBlock *FiniBB =
5717       EntryBB->splitBasicBlock(EntryBB->getTerminator(), "omp_region.finalize");
5718 
5719   Builder.SetInsertPoint(EntryBB->getTerminator());
5720   emitCommonDirectiveEntry(OMPD, EntryCall, ExitBB, Conditional);
5721 
5722   // generate body
5723   BodyGenCB(/* AllocaIP */ InsertPointTy(),
5724             /* CodeGenIP */ Builder.saveIP());
5725 
5726   // emit exit call and do any needed finalization.
5727   auto FinIP = InsertPointTy(FiniBB, FiniBB->getFirstInsertionPt());
5728   assert(FiniBB->getTerminator()->getNumSuccessors() == 1 &&
5729          FiniBB->getTerminator()->getSuccessor(0) == ExitBB &&
5730          "Unexpected control flow graph state!!");
5731   emitCommonDirectiveExit(OMPD, FinIP, ExitCall, HasFinalize);
5732   assert(FiniBB->getUniquePredecessor()->getUniqueSuccessor() == FiniBB &&
5733          "Unexpected Control Flow State!");
5734   MergeBlockIntoPredecessor(FiniBB);
5735 
5736   // If we are skipping the region of a non-conditional, remove the exit
5737   // block, and clear the builder's insertion point.
5738   assert(SplitPos->getParent() == ExitBB &&
5739          "Unexpected Insertion point location!");
5740   auto merged = MergeBlockIntoPredecessor(ExitBB);
5741   BasicBlock *ExitPredBB = SplitPos->getParent();
5742   auto InsertBB = merged ? ExitPredBB : ExitBB;
5743   if (!isa_and_nonnull<BranchInst>(SplitPos))
5744     SplitPos->eraseFromParent();
5745   Builder.SetInsertPoint(InsertBB);
5746 
5747   return Builder.saveIP();
5748 }
5749 
5750 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveEntry(
5751     Directive OMPD, Value *EntryCall, BasicBlock *ExitBB, bool Conditional) {
5752   // If there is nothing to do, return the current insertion point.
5753   if (!Conditional || !EntryCall)
5754     return Builder.saveIP();
5755 
5756   BasicBlock *EntryBB = Builder.GetInsertBlock();
5757   Value *CallBool = Builder.CreateIsNotNull(EntryCall);
5758   auto *ThenBB = BasicBlock::Create(M.getContext(), "omp_region.body");
5759   auto *UI = new UnreachableInst(Builder.getContext(), ThenBB);
5760 
5761   // Emit thenBB and set the Builder's insertion point there for
5762   // body generation next. Place the block after the current block.
5763   Function *CurFn = EntryBB->getParent();
5764   CurFn->insert(std::next(EntryBB->getIterator()), ThenBB);
5765 
5766   // Move the entry branch to the end of ThenBB, and replace it with a
5767   // conditional branch (if-stmt).
5768   Instruction *EntryBBTI = EntryBB->getTerminator();
5769   Builder.CreateCondBr(CallBool, ThenBB, ExitBB);
5770   EntryBBTI->removeFromParent();
5771   Builder.SetInsertPoint(UI);
5772   Builder.Insert(EntryBBTI);
5773   UI->eraseFromParent();
5774   Builder.SetInsertPoint(ThenBB->getTerminator());
5775 
5776   // return an insertion point to ExitBB.
5777   return IRBuilder<>::InsertPoint(ExitBB, ExitBB->getFirstInsertionPt());
5778 }
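// A rough sketch of the CFG emitted for the conditional case:
//   entry:           %c = icmp ne <EntryCall result>, 0
//                    br %c, label %omp_region.body, label %<ExitBB>
//   omp_region.body: <body is generated here>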
5779 
5780 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveExit(
5781     omp::Directive OMPD, InsertPointTy FinIP, Instruction *ExitCall,
5782     bool HasFinalize) {
5783 
5784   Builder.restoreIP(FinIP);
5785 
5786   // If there is finalization to do, emit it before the exit call
5787   if (HasFinalize) {
5788     assert(!FinalizationStack.empty() &&
5789            "Unexpected finalization stack state!");
5790 
5791     FinalizationInfo Fi = FinalizationStack.pop_back_val();
5792     assert(Fi.DK == OMPD && "Unexpected Directive for Finalization call!");
5793 
5794     Fi.FiniCB(FinIP);
5795 
5796     BasicBlock *FiniBB = FinIP.getBlock();
5797     Instruction *FiniBBTI = FiniBB->getTerminator();
5798 
5799     // Set the Builder's IP for call creation.
5800     Builder.SetInsertPoint(FiniBBTI);
5801   }
5802 
5803   if (!ExitCall)
5804     return Builder.saveIP();
5805 
5806   // Place the ExitCall as the last instruction before the finalization block's terminator.
5807   ExitCall->removeFromParent();
5808   Builder.Insert(ExitCall);
5809 
5810   return IRBuilder<>::InsertPoint(ExitCall->getParent(),
5811                                   ExitCall->getIterator());
5812 }
5813 
5814 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCopyinClauseBlocks(
5815     InsertPointTy IP, Value *MasterAddr, Value *PrivateAddr,
5816     llvm::IntegerType *IntPtrTy, bool BranchtoEnd) {
5817   if (!IP.isSet())
5818     return IP;
5819 
5820   IRBuilder<>::InsertPointGuard IPG(Builder);
5821 
5822   // Creates the following CFG structure:
5823   //    OMP_Entry : (MasterAddr != PrivateAddr)?
5824   //       F     T
5825   //       |      \
5826   //       |     copyin.not.master
5827   //       |      /
5828   //       v     /
5829   //   copyin.not.master.end
5830   //         |
5831   //         v
5832   //   OMP.Entry.Next
5833 
5834   BasicBlock *OMP_Entry = IP.getBlock();
5835   Function *CurFn = OMP_Entry->getParent();
5836   BasicBlock *CopyBegin =
5837       BasicBlock::Create(M.getContext(), "copyin.not.master", CurFn);
5838   BasicBlock *CopyEnd = nullptr;
5839 
5840   // If the entry block is terminated, split it to preserve the branch to the
5841   // following basic block (i.e. OMP.Entry.Next); otherwise, leave everything as is.
5842   if (isa_and_nonnull<BranchInst>(OMP_Entry->getTerminator())) {
5843     CopyEnd = OMP_Entry->splitBasicBlock(OMP_Entry->getTerminator(),
5844                                          "copyin.not.master.end");
5845     OMP_Entry->getTerminator()->eraseFromParent();
5846   } else {
5847     CopyEnd =
5848         BasicBlock::Create(M.getContext(), "copyin.not.master.end", CurFn);
5849   }
5850 
5851   Builder.SetInsertPoint(OMP_Entry);
5852   Value *MasterPtr = Builder.CreatePtrToInt(MasterAddr, IntPtrTy);
5853   Value *PrivatePtr = Builder.CreatePtrToInt(PrivateAddr, IntPtrTy);
5854   Value *cmp = Builder.CreateICmpNE(MasterPtr, PrivatePtr);
5855   Builder.CreateCondBr(cmp, CopyBegin, CopyEnd);
5856 
5857   Builder.SetInsertPoint(CopyBegin);
5858   if (BranchtoEnd)
5859     Builder.SetInsertPoint(Builder.CreateBr(CopyEnd));
5860 
5861   return Builder.saveIP();
5862 }
5863 
5864 CallInst *OpenMPIRBuilder::createOMPAlloc(const LocationDescription &Loc,
5865                                           Value *Size, Value *Allocator,
5866                                           std::string Name) {
5867   IRBuilder<>::InsertPointGuard IPG(Builder);
5868   updateToLocation(Loc);
5869 
5870   uint32_t SrcLocStrSize;
5871   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5872   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5873   Value *ThreadId = getOrCreateThreadID(Ident);
5874   Value *Args[] = {ThreadId, Size, Allocator};
5875 
5876   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_alloc);
5877 
5878   return Builder.CreateCall(Fn, Args, Name);
5879 }
5880 
5881 CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
5882                                          Value *Addr, Value *Allocator,
5883                                          std::string Name) {
5884   IRBuilder<>::InsertPointGuard IPG(Builder);
5885   updateToLocation(Loc);
5886 
5887   uint32_t SrcLocStrSize;
5888   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5889   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5890   Value *ThreadId = getOrCreateThreadID(Ident);
5891   Value *Args[] = {ThreadId, Addr, Allocator};
5892   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_free);
5893   return Builder.CreateCall(Fn, Args, Name);
5894 }
5895 
5896 CallInst *OpenMPIRBuilder::createOMPInteropInit(
5897     const LocationDescription &Loc, Value *InteropVar,
5898     omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
5899     Value *DependenceAddress, bool HaveNowaitClause) {
5900   IRBuilder<>::InsertPointGuard IPG(Builder);
5901   updateToLocation(Loc);
5902 
5903   uint32_t SrcLocStrSize;
5904   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5905   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5906   Value *ThreadId = getOrCreateThreadID(Ident);
5907   if (Device == nullptr)
5908     Device = ConstantInt::get(Int32, -1);
5909   Constant *InteropTypeVal = ConstantInt::get(Int32, (int)InteropType);
5910   if (NumDependences == nullptr) {
5911     NumDependences = ConstantInt::get(Int32, 0);
5912     PointerType *PointerTypeVar = PointerType::getUnqual(M.getContext());
5913     DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
5914   }
5915   Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
5916   Value *Args[] = {
5917       Ident,  ThreadId,       InteropVar,        InteropTypeVal,
5918       Device, NumDependences, DependenceAddress, HaveNowaitClauseVal};
5919 
5920   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_init);
5921 
5922   return Builder.CreateCall(Fn, Args);
5923 }
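// Illustrative call emitted by createOMPInteropInit when Device and the
// dependence arguments are left unset (argument names are examples):
//   call void @__tgt_interop_init(ptr @ident, i32 %gtid, ptr %interop,
//                                 i32 %interop_type, i32 -1, i32 0,
//                                 ptr null, i32 0)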
5924 
5925 CallInst *OpenMPIRBuilder::createOMPInteropDestroy(
5926     const LocationDescription &Loc, Value *InteropVar, Value *Device,
5927     Value *NumDependences, Value *DependenceAddress, bool HaveNowaitClause) {
5928   IRBuilder<>::InsertPointGuard IPG(Builder);
5929   updateToLocation(Loc);
5930 
5931   uint32_t SrcLocStrSize;
5932   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5933   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5934   Value *ThreadId = getOrCreateThreadID(Ident);
5935   if (Device == nullptr)
5936     Device = ConstantInt::get(Int32, -1);
5937   if (NumDependences == nullptr) {
5938     NumDependences = ConstantInt::get(Int32, 0);
5939     PointerType *PointerTypeVar = PointerType::getUnqual(M.getContext());
5940     DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
5941   }
5942   Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
5943   Value *Args[] = {
5944       Ident,          ThreadId,          InteropVar,         Device,
5945       NumDependences, DependenceAddress, HaveNowaitClauseVal};
5946 
5947   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_destroy);
5948 
5949   return Builder.CreateCall(Fn, Args);
5950 }
5951 
5952 CallInst *OpenMPIRBuilder::createOMPInteropUse(const LocationDescription &Loc,
5953                                                Value *InteropVar, Value *Device,
5954                                                Value *NumDependences,
5955                                                Value *DependenceAddress,
5956                                                bool HaveNowaitClause) {
5957   IRBuilder<>::InsertPointGuard IPG(Builder);
5958   updateToLocation(Loc);
5959   uint32_t SrcLocStrSize;
5960   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5961   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5962   Value *ThreadId = getOrCreateThreadID(Ident);
5963   if (Device == nullptr)
5964     Device = ConstantInt::get(Int32, -1);
5965   if (NumDependences == nullptr) {
5966     NumDependences = ConstantInt::get(Int32, 0);
5967     PointerType *PointerTypeVar = PointerType::getUnqual(M.getContext());
5968     DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
5969   }
5970   Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
5971   Value *Args[] = {
5972       Ident,          ThreadId,          InteropVar,         Device,
5973       NumDependences, DependenceAddress, HaveNowaitClauseVal};
5974 
5975   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_use);
5976 
5977   return Builder.CreateCall(Fn, Args);
5978 }
5979 
5980 CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
5981     const LocationDescription &Loc, llvm::Value *Pointer,
5982     llvm::ConstantInt *Size, const llvm::Twine &Name) {
5983   IRBuilder<>::InsertPointGuard IPG(Builder);
5984   updateToLocation(Loc);
5985 
5986   uint32_t SrcLocStrSize;
5987   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5988   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5989   Value *ThreadId = getOrCreateThreadID(Ident);
5990   Constant *ThreadPrivateCache =
5991       getOrCreateInternalVariable(Int8PtrPtr, Name.str());
5992   llvm::Value *Args[] = {Ident, ThreadId, Pointer, Size, ThreadPrivateCache};
5993 
5994   Function *Fn =
5995       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_threadprivate_cached);
5996 
5997   return Builder.CreateCall(Fn, Args);
5998 }
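// Illustrative call emitted by createCachedThreadPrivate (value names are
// examples, not verbatim output):
//   %tp = call ptr @__kmpc_threadprivate_cached(ptr @ident, i32 %gtid,
//                                               ptr %var, i64 %size,
//                                               ptr @var.cache)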
5999 
6000 OpenMPIRBuilder::InsertPointTy
6001 OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
6002                                   int32_t MinThreadsVal, int32_t MaxThreadsVal,
6003                                   int32_t MinTeamsVal, int32_t MaxTeamsVal) {
6004   if (!updateToLocation(Loc))
6005     return Loc.IP;
6006 
6007   uint32_t SrcLocStrSize;
6008   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6009   Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6010   Constant *IsSPMDVal = ConstantInt::getSigned(
6011       Int8, IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
6012   Constant *UseGenericStateMachineVal = ConstantInt::getSigned(Int8, !IsSPMD);
6013   Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Int8, true);
6014   Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Int16, 0);
6015 
6016   Function *Kernel = Builder.GetInsertBlock()->getParent();
6017 
6018   // Manifest the launch configuration in the metadata matching the kernel
6019   // environment.
6020   if (MinTeamsVal > 1 || MaxTeamsVal > 0)
6021     writeTeamsForKernel(T, *Kernel, MinTeamsVal, MaxTeamsVal);
6022 
6023   // For max values, < 0 means unset, == 0 means set but unknown.
6024   if (MaxThreadsVal < 0)
6025     MaxThreadsVal = std::max(
6026         int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), MinThreadsVal);
6027 
6028   if (MaxThreadsVal > 0)
6029     writeThreadBoundsForKernel(T, *Kernel, MinThreadsVal, MaxThreadsVal);
6030 
6031   Constant *MinThreads = ConstantInt::getSigned(Int32, MinThreadsVal);
6032   Constant *MaxThreads = ConstantInt::getSigned(Int32, MaxThreadsVal);
6033   Constant *MinTeams = ConstantInt::getSigned(Int32, MinTeamsVal);
6034   Constant *MaxTeams = ConstantInt::getSigned(Int32, MaxTeamsVal);
6035   Constant *ReductionDataSize = ConstantInt::getSigned(Int32, 0);
6036   Constant *ReductionBufferLength = ConstantInt::getSigned(Int32, 0);
6037 
6038   // We need to strip the debug prefix to get the correct kernel name.
6039   StringRef KernelName = Kernel->getName();
6040   const std::string DebugPrefix = "_debug__";
6041   if (KernelName.ends_with(DebugPrefix))
6042     KernelName = KernelName.drop_back(DebugPrefix.length());
6043 
6044   Function *Fn = getOrCreateRuntimeFunctionPtr(
6045       omp::RuntimeFunction::OMPRTL___kmpc_target_init);
6046   const DataLayout &DL = Fn->getDataLayout();
6047 
6048   Twine DynamicEnvironmentName = KernelName + "_dynamic_environment";
6049   Constant *DynamicEnvironmentInitializer =
6050       ConstantStruct::get(DynamicEnvironment, {DebugIndentionLevelVal});
6051   GlobalVariable *DynamicEnvironmentGV = new GlobalVariable(
6052       M, DynamicEnvironment, /*IsConstant=*/false, GlobalValue::WeakODRLinkage,
6053       DynamicEnvironmentInitializer, DynamicEnvironmentName,
6054       /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
6055       DL.getDefaultGlobalsAddressSpace());
6056   DynamicEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);
6057 
6058   Constant *DynamicEnvironment =
6059       DynamicEnvironmentGV->getType() == DynamicEnvironmentPtr
6060           ? DynamicEnvironmentGV
6061           : ConstantExpr::getAddrSpaceCast(DynamicEnvironmentGV,
6062                                            DynamicEnvironmentPtr);
6063 
6064   Constant *ConfigurationEnvironmentInitializer = ConstantStruct::get(
6065       ConfigurationEnvironment, {
6066                                     UseGenericStateMachineVal,
6067                                     MayUseNestedParallelismVal,
6068                                     IsSPMDVal,
6069                                     MinThreads,
6070                                     MaxThreads,
6071                                     MinTeams,
6072                                     MaxTeams,
6073                                     ReductionDataSize,
6074                                     ReductionBufferLength,
6075                                 });
6076   Constant *KernelEnvironmentInitializer = ConstantStruct::get(
6077       KernelEnvironment, {
6078                              ConfigurationEnvironmentInitializer,
6079                              Ident,
6080                              DynamicEnvironment,
6081                          });
6082   std::string KernelEnvironmentName =
6083       (KernelName + "_kernel_environment").str();
6084   GlobalVariable *KernelEnvironmentGV = new GlobalVariable(
6085       M, KernelEnvironment, /*IsConstant=*/true, GlobalValue::WeakODRLinkage,
6086       KernelEnvironmentInitializer, KernelEnvironmentName,
6087       /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
6088       DL.getDefaultGlobalsAddressSpace());
6089   KernelEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);
6090 
6091   Constant *KernelEnvironment =
6092       KernelEnvironmentGV->getType() == KernelEnvironmentPtr
6093           ? KernelEnvironmentGV
6094           : ConstantExpr::getAddrSpaceCast(KernelEnvironmentGV,
6095                                            KernelEnvironmentPtr);
6096   Value *KernelLaunchEnvironment = Kernel->getArg(0);
6097   CallInst *ThreadKind =
6098       Builder.CreateCall(Fn, {KernelEnvironment, KernelLaunchEnvironment});
6099 
6100   Value *ExecUserCode = Builder.CreateICmpEQ(
6101       ThreadKind, ConstantInt::get(ThreadKind->getType(), -1),
6102       "exec_user_code");
6103 
6104   // ThreadKind = __kmpc_target_init(...)
6105   // if (ThreadKind == -1)
6106   //   user_code
6107   // else
6108   //   return;
6109 
6110   auto *UI = Builder.CreateUnreachable();
6111   BasicBlock *CheckBB = UI->getParent();
6112   BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(UI, "user_code.entry");
6113 
6114   BasicBlock *WorkerExitBB = BasicBlock::Create(
6115       CheckBB->getContext(), "worker.exit", CheckBB->getParent());
6116   Builder.SetInsertPoint(WorkerExitBB);
6117   Builder.CreateRetVoid();
6118 
6119   auto *CheckBBTI = CheckBB->getTerminator();
6120   Builder.SetInsertPoint(CheckBBTI);
6121   Builder.CreateCondBr(ExecUserCode, UI->getParent(), WorkerExitBB);
6122 
6123   CheckBBTI->eraseFromParent();
6124   UI->eraseFromParent();
6125 
6126   // Continue in the "user_code" block, see diagram above and in
6127   // openmp/libomptarget/deviceRTLs/common/include/target.h.
6128   return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt());
6129 }
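// Putting the pieces above together, for a kernel "foo" the emitted globals
// and entry check look roughly like this (illustrative only):
//   @foo_dynamic_environment = weak_odr protected global ...
//   @foo_kernel_environment = weak_odr protected constant ...
//   %tk = call i32 @__kmpc_target_init(ptr @foo_kernel_environment,
//                                      ptr %dyn.launch.env)
//   %exec_user_code = icmp eq i32 %tk, -1
//   br i1 %exec_user_code, label %user_code.entry, label %worker.exit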
6130 
6131 void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc,
6132                                          int32_t TeamsReductionDataSize,
6133                                          int32_t TeamsReductionBufferLength) {
6134   if (!updateToLocation(Loc))
6135     return;
6136 
6137   Function *Fn = getOrCreateRuntimeFunctionPtr(
6138       omp::RuntimeFunction::OMPRTL___kmpc_target_deinit);
6139 
6140   Builder.CreateCall(Fn, {});
6141 
6142   if (!TeamsReductionBufferLength || !TeamsReductionDataSize)
6143     return;
6144 
6145   Function *Kernel = Builder.GetInsertBlock()->getParent();
6146   // We need to strip the debug prefix to get the correct kernel name.
6147   StringRef KernelName = Kernel->getName();
6148   const std::string DebugPrefix = "_debug__";
6149   if (KernelName.ends_with(DebugPrefix))
6150     KernelName = KernelName.drop_back(DebugPrefix.length());
6151   auto *KernelEnvironmentGV =
6152       M.getNamedGlobal((KernelName + "_kernel_environment").str());
6153   assert(KernelEnvironmentGV && "Expected kernel environment global\n");
6154   auto *KernelEnvironmentInitializer = KernelEnvironmentGV->getInitializer();
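  // The indices used below address the ConfigurationEnvironment member of the
  // kernel environment initializer built in createTargetInit: {0, 7} is
  // ReductionDataSize and {0, 8} is ReductionBufferLength.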
6155   auto *NewInitializer = ConstantFoldInsertValueInstruction(
6156       KernelEnvironmentInitializer,
6157       ConstantInt::get(Int32, TeamsReductionDataSize), {0, 7});
6158   NewInitializer = ConstantFoldInsertValueInstruction(
6159       NewInitializer, ConstantInt::get(Int32, TeamsReductionBufferLength),
6160       {0, 8});
6161   KernelEnvironmentGV->setInitializer(NewInitializer);
6162 }
6163 
6164 static MDNode *getNVPTXMDNode(Function &Kernel, StringRef Name) {
6165   Module &M = *Kernel.getParent();
6166   NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
6167   for (auto *Op : MD->operands()) {
6168     if (Op->getNumOperands() != 3)
6169       continue;
6170     auto *KernelOp = dyn_cast<ConstantAsMetadata>(Op->getOperand(0));
6171     if (!KernelOp || KernelOp->getValue() != &Kernel)
6172       continue;
6173     auto *Prop = dyn_cast<MDString>(Op->getOperand(1));
6174     if (!Prop || Prop->getString() != Name)
6175       continue;
6176     return Op;
6177   }
6178   return nullptr;
6179 }
6180 
6181 static void updateNVPTXMetadata(Function &Kernel, StringRef Name, int32_t Value,
6182                                 bool Min) {
6183   // Update the named metadata (e.g. "maxntidx") for NVIDIA, or add it if missing.
6184   MDNode *ExistingOp = getNVPTXMDNode(Kernel, Name);
6185   if (ExistingOp) {
6186     auto *OldVal = cast<ConstantAsMetadata>(ExistingOp->getOperand(2));
6187     int32_t OldLimit = cast<ConstantInt>(OldVal->getValue())->getZExtValue();
6188     ExistingOp->replaceOperandWith(
6189         2, ConstantAsMetadata::get(ConstantInt::get(
6190                OldVal->getValue()->getType(),
6191                Min ? std::min(OldLimit, Value) : std::max(OldLimit, Value))));
6192   } else {
6193     LLVMContext &Ctx = Kernel.getContext();
6194     Metadata *MDVals[] = {ConstantAsMetadata::get(&Kernel),
6195                           MDString::get(Ctx, Name),
6196                           ConstantAsMetadata::get(
6197                               ConstantInt::get(Type::getInt32Ty(Ctx), Value))};
6198     // Append metadata to nvvm.annotations
6199     Module &M = *Kernel.getParent();
6200     NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
6201     MD->addOperand(MDNode::get(Ctx, MDVals));
6202   }
6203 }
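// Illustrative effect: if "nvvm.annotations" already contains
//   !{ptr @kernel, !"maxntidx", i32 256}
// then updateNVPTXMetadata(Kernel, "maxntidx", 128, /*Min=*/true) rewrites
// the value operand to i32 128, i.e. min(256, 128).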
6204 
6205 std::pair<int32_t, int32_t>
6206 OpenMPIRBuilder::readThreadBoundsForKernel(const Triple &T, Function &Kernel) {
6207   int32_t ThreadLimit =
6208       Kernel.getFnAttributeAsParsedInteger("omp_target_thread_limit");
6209 
6210   if (T.isAMDGPU()) {
6211     const auto &Attr = Kernel.getFnAttribute("amdgpu-flat-work-group-size");
6212     if (!Attr.isValid() || !Attr.isStringAttribute())
6213       return {0, ThreadLimit};
6214     auto [LBStr, UBStr] = Attr.getValueAsString().split(',');
6215     int32_t LB, UB;
6216     if (!llvm::to_integer(UBStr, UB, 10))
6217       return {0, ThreadLimit};
6218     UB = ThreadLimit ? std::min(ThreadLimit, UB) : UB;
6219     if (!llvm::to_integer(LBStr, LB, 10))
6220       return {0, UB};
6221     return {LB, UB};
6222   }
6223 
6224   if (MDNode *ExistingOp = getNVPTXMDNode(Kernel, "maxntidx")) {
6225     auto *OldVal = cast<ConstantAsMetadata>(ExistingOp->getOperand(2));
6226     int32_t UB = cast<ConstantInt>(OldVal->getValue())->getZExtValue();
6227     return {0, ThreadLimit ? std::min(ThreadLimit, UB) : UB};
6228   }
6229   return {0, ThreadLimit};
6230 }
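// Worked example: with "amdgpu-flat-work-group-size"="1,256" and an
// "omp_target_thread_limit" of 128, the AMDGPU path above returns {1, 128};
// the upper bound is clamped to the thread limit.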
6231 
6232 void OpenMPIRBuilder::writeThreadBoundsForKernel(const Triple &T,
6233                                                  Function &Kernel, int32_t LB,
6234                                                  int32_t UB) {
6235   Kernel.addFnAttr("omp_target_thread_limit", std::to_string(UB));
6236 
6237   if (T.isAMDGPU()) {
6238     Kernel.addFnAttr("amdgpu-flat-work-group-size",
6239                      llvm::utostr(LB) + "," + llvm::utostr(UB));
6240     return;
6241   }
6242 
6243   updateNVPTXMetadata(Kernel, "maxntidx", UB, /*Min=*/true);
6244 }
6245 
6246 std::pair<int32_t, int32_t>
6247 OpenMPIRBuilder::readTeamBoundsForKernel(const Triple &, Function &Kernel) {
6248   // TODO: Read from backend annotations if available.
6249   return {0, Kernel.getFnAttributeAsParsedInteger("omp_target_num_teams")};
6250 }
6251 
6252 void OpenMPIRBuilder::writeTeamsForKernel(const Triple &T, Function &Kernel,
6253                                           int32_t LB, int32_t UB) {
6254   if (T.isNVPTX())
6255     if (UB > 0)
6256       updateNVPTXMetadata(Kernel, "maxclusterrank", UB, /*Min=*/true);
6257   if (T.isAMDGPU())
6258     Kernel.addFnAttr("amdgpu-max-num-workgroups", llvm::utostr(LB) + ",1,1");
6259 
6260   Kernel.addFnAttr("omp_target_num_teams", std::to_string(LB));
6261 }
6262 
6263 void OpenMPIRBuilder::setOutlinedTargetRegionFunctionAttributes(
6264     Function *OutlinedFn) {
6265   if (Config.isTargetDevice()) {
6266     OutlinedFn->setLinkage(GlobalValue::WeakODRLinkage);
6267     // TODO: Determine if DSO local can be set to true.
6268     OutlinedFn->setDSOLocal(false);
6269     OutlinedFn->setVisibility(GlobalValue::ProtectedVisibility);
6270     if (T.isAMDGCN())
6271       OutlinedFn->setCallingConv(CallingConv::AMDGPU_KERNEL);
6272   }
6273 }
6274 
6275 Constant *OpenMPIRBuilder::createOutlinedFunctionID(Function *OutlinedFn,
6276                                                     StringRef EntryFnIDName) {
6277   if (Config.isTargetDevice()) {
6278     assert(OutlinedFn && "The outlined function must exist if embedded");
6279     return OutlinedFn;
6280   }
6281 
6282   return new GlobalVariable(
6283       M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
6284       Constant::getNullValue(Builder.getInt8Ty()), EntryFnIDName);
6285 }
6286 
6287 Constant *OpenMPIRBuilder::createTargetRegionEntryAddr(Function *OutlinedFn,
6288                                                        StringRef EntryFnName) {
6289   if (OutlinedFn)
6290     return OutlinedFn;
6291 
6292   assert(!M.getGlobalVariable(EntryFnName, true) &&
6293          "Named kernel already exists?");
6294   return new GlobalVariable(
6295       M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::InternalLinkage,
6296       Constant::getNullValue(Builder.getInt8Ty()), EntryFnName);
6297 }
6298 
6299 void OpenMPIRBuilder::emitTargetRegionFunction(
6300     TargetRegionEntryInfo &EntryInfo,
6301     FunctionGenCallback &GenerateFunctionCallback, bool IsOffloadEntry,
6302     Function *&OutlinedFn, Constant *&OutlinedFnID) {
6303 
6304   SmallString<64> EntryFnName;
6305   OffloadInfoManager.getTargetRegionEntryFnName(EntryFnName, EntryInfo);
6306 
6307   OutlinedFn = Config.isTargetDevice() || !Config.openMPOffloadMandatory()
6308                    ? GenerateFunctionCallback(EntryFnName)
6309                    : nullptr;
6310 
6311   // If this target outline function is not an offload entry, we don't need to
6312   // register it. This may be the case for a false if clause, or if there are
6313   // no OpenMP targets.
6314   if (!IsOffloadEntry)
6315     return;
6316 
6317   std::string EntryFnIDName =
6318       Config.isTargetDevice()
6319           ? std::string(EntryFnName)
6320           : createPlatformSpecificName({EntryFnName, "region_id"});
6321 
6322   OutlinedFnID = registerTargetRegionFunction(EntryInfo, OutlinedFn,
6323                                               EntryFnName, EntryFnIDName);
6324 }
6325 
6326 Constant *OpenMPIRBuilder::registerTargetRegionFunction(
6327     TargetRegionEntryInfo &EntryInfo, Function *OutlinedFn,
6328     StringRef EntryFnName, StringRef EntryFnIDName) {
6329   if (OutlinedFn)
6330     setOutlinedTargetRegionFunctionAttributes(OutlinedFn);
6331   auto OutlinedFnID = createOutlinedFunctionID(OutlinedFn, EntryFnIDName);
6332   auto EntryAddr = createTargetRegionEntryAddr(OutlinedFn, EntryFnName);
6333   OffloadInfoManager.registerTargetRegionEntryInfo(
6334       EntryInfo, EntryAddr, OutlinedFnID,
6335       OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion);
6336   return OutlinedFnID;
6337 }
6338 
6339 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
6340     const LocationDescription &Loc, InsertPointTy AllocaIP,
6341     InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
6342     TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
6343     omp::RuntimeFunction *MapperFunc,
6344     function_ref<InsertPointTy(InsertPointTy CodeGenIP, BodyGenTy BodyGenType)>
6345         BodyGenCB,
6346     function_ref<void(unsigned int, Value *)> DeviceAddrCB,
6347     function_ref<Value *(unsigned int)> CustomMapperCB, Value *SrcLocInfo) {
6348   if (!updateToLocation(Loc))
6349     return InsertPointTy();
6350 
6351   // Disable TargetData CodeGen in the device pass.
6352   if (Config.IsTargetDevice.value_or(false)) {
6353     if (BodyGenCB)
6354       Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv));
6355     return Builder.saveIP();
6356   }
6357 
6358   Builder.restoreIP(CodeGenIP);
6359   bool IsStandAlone = !BodyGenCB;
6360   MapInfosTy *MapInfo;
6361   // Generate the code for the opening of the data environment. Capture all the
6362   // arguments of the runtime call by reference because they are used in the
6363   // closing of the region.
6364   auto BeginThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6365     MapInfo = &GenMapInfoCB(Builder.saveIP());
6366     emitOffloadingArrays(AllocaIP, Builder.saveIP(), *MapInfo, Info,
6367                          /*IsNonContiguous=*/true, DeviceAddrCB,
6368                          CustomMapperCB);
6369 
6370     TargetDataRTArgs RTArgs;
6371     emitOffloadingArraysArgument(Builder, RTArgs, Info,
6372                                  !MapInfo->Names.empty());
6373 
6374     // Emit the number of elements in the offloading arrays.
6375     Value *PointerNum = Builder.getInt32(Info.NumberOfPtrs);
6376 
6377     // Source location for the ident struct
6378     if (!SrcLocInfo) {
6379       uint32_t SrcLocStrSize;
6380       Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6381       SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6382     }
6383 
6384     Value *OffloadingArgs[] = {SrcLocInfo,           DeviceID,
6385                                PointerNum,           RTArgs.BasePointersArray,
6386                                RTArgs.PointersArray, RTArgs.SizesArray,
6387                                RTArgs.MapTypesArray, RTArgs.MapNamesArray,
6388                                RTArgs.MappersArray};
6389 
6390     if (IsStandAlone) {
6391       assert(MapperFunc && "MapperFunc missing for standalone target data");
6392       Builder.CreateCall(getOrCreateRuntimeFunctionPtr(*MapperFunc),
6393                          OffloadingArgs);
6394     } else {
6395       Function *BeginMapperFunc = getOrCreateRuntimeFunctionPtr(
6396           omp::OMPRTL___tgt_target_data_begin_mapper);
6397 
6398       Builder.CreateCall(BeginMapperFunc, OffloadingArgs);
6399 
6400       for (auto DeviceMap : Info.DevicePtrInfoMap) {
6401         if (isa<AllocaInst>(DeviceMap.second.second)) {
6402           auto *LI =
6403               Builder.CreateLoad(Builder.getPtrTy(), DeviceMap.second.first);
6404           Builder.CreateStore(LI, DeviceMap.second.second);
6405         }
6406       }
6407 
6408       // If device pointer privatization is required, emit the body of the
6409       // region here. It will have to be duplicated: with and without
6410       // privatization.
6411       Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::Priv));
6412     }
6413   };
6414 
6415   // If we need device pointer privatization, we need to emit the body of the
6416   // region with no privatization in the 'else' branch of the conditional.
6417   // Otherwise, we don't have to do anything.
6418   auto BeginElseGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6419     Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::DupNoPriv));
6420   };
6421 
6422   // Generate code for the closing of the data region.
6423   auto EndThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6424     TargetDataRTArgs RTArgs;
6425     emitOffloadingArraysArgument(Builder, RTArgs, Info, !MapInfo->Names.empty(),
6426                                  /*ForEndCall=*/true);
6427 
6428     // Emit the number of elements in the offloading arrays.
6429     Value *PointerNum = Builder.getInt32(Info.NumberOfPtrs);
6430 
6431     // Source location for the ident struct
6432     if (!SrcLocInfo) {
6433       uint32_t SrcLocStrSize;
6434       Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6435       SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6436     }
6437 
6438     Value *OffloadingArgs[] = {SrcLocInfo,           DeviceID,
6439                                PointerNum,           RTArgs.BasePointersArray,
6440                                RTArgs.PointersArray, RTArgs.SizesArray,
6441                                RTArgs.MapTypesArray, RTArgs.MapNamesArray,
6442                                RTArgs.MappersArray};
6443     Function *EndMapperFunc =
6444         getOrCreateRuntimeFunctionPtr(omp::OMPRTL___tgt_target_data_end_mapper);
6445 
6446     Builder.CreateCall(EndMapperFunc, OffloadingArgs);
6447   };
6448 
6449   // We don't have to do anything to close the region if the if clause evaluates
6450   // to false.
6451   auto EndElseGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
6452 
6453   if (BodyGenCB) {
6454     if (IfCond) {
6455       emitIfClause(IfCond, BeginThenGen, BeginElseGen, AllocaIP);
6456     } else {
6457       BeginThenGen(AllocaIP, Builder.saveIP());
6458     }
6459 
6460     // If we don't require privatization of device pointers, we emit the body in
6461     // between the runtime calls. This avoids duplicating the body code.
6462     Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv));
6463 
6464     if (IfCond) {
6465       emitIfClause(IfCond, EndThenGen, EndElseGen, AllocaIP);
6466     } else {
6467       EndThenGen(AllocaIP, Builder.saveIP());
6468     }
6469   } else {
6470     if (IfCond) {
6471       emitIfClause(IfCond, BeginThenGen, EndElseGen, AllocaIP);
6472     } else {
6473       BeginThenGen(AllocaIP, Builder.saveIP());
6474     }
6475   }
6476 
6477   return Builder.saveIP();
6478 }
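// For a non-standalone data region (BodyGenCB present) without an if clause,
// the generated host code has roughly this shape (illustrative):
//   call void @__tgt_target_data_begin_mapper(ptr @ident, i64 %device_id,
//       i32 %num_ptrs, ptr %baseptrs, ptr %ptrs, ptr %sizes, ptr %maptypes,
//       ptr %mapnames, ptr %mappers)
//   ;; region body emitted via BodyGenCB
//   call void @__tgt_target_data_end_mapper(...) ; same argument shape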
6479 
6480 FunctionCallee
6481 OpenMPIRBuilder::createForStaticInitFunction(unsigned IVSize, bool IVSigned,
6482                                              bool IsGPUDistribute) {
6483   assert((IVSize == 32 || IVSize == 64) &&
6484          "IV size is not compatible with the omp runtime");
6485   RuntimeFunction Name;
6486   if (IsGPUDistribute)
6487     Name = IVSize == 32
6488                ? (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_4
6489                            : omp::OMPRTL___kmpc_distribute_static_init_4u)
6490                : (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_8
6491                            : omp::OMPRTL___kmpc_distribute_static_init_8u);
6492   else
6493     Name = IVSize == 32 ? (IVSigned ? omp::OMPRTL___kmpc_for_static_init_4
6494                                     : omp::OMPRTL___kmpc_for_static_init_4u)
6495                         : (IVSigned ? omp::OMPRTL___kmpc_for_static_init_8
6496                                     : omp::OMPRTL___kmpc_for_static_init_8u);
6497 
6498   return getOrCreateRuntimeFunction(M, Name);
6499 }
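// For example, IVSize == 32 with IVSigned == true selects
// __kmpc_for_static_init_4 (or __kmpc_distribute_static_init_4 when
// IsGPUDistribute is set), while IVSize == 64 with IVSigned == false selects
// the _8u variants.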
6500 
6501 FunctionCallee OpenMPIRBuilder::createDispatchInitFunction(unsigned IVSize,
6502                                                            bool IVSigned) {
6503   assert((IVSize == 32 || IVSize == 64) &&
6504          "IV size is not compatible with the omp runtime");
6505   RuntimeFunction Name = IVSize == 32
6506                              ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_4
6507                                          : omp::OMPRTL___kmpc_dispatch_init_4u)
6508                              : (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_8
6509                                          : omp::OMPRTL___kmpc_dispatch_init_8u);
6510 
6511   return getOrCreateRuntimeFunction(M, Name);
6512 }
6513 
6514 FunctionCallee OpenMPIRBuilder::createDispatchNextFunction(unsigned IVSize,
6515                                                            bool IVSigned) {
6516   assert((IVSize == 32 || IVSize == 64) &&
6517          "IV size is not compatible with the omp runtime");
6518   RuntimeFunction Name = IVSize == 32
6519                              ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_4
6520                                          : omp::OMPRTL___kmpc_dispatch_next_4u)
6521                              : (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_8
6522                                          : omp::OMPRTL___kmpc_dispatch_next_8u);
6523 
6524   return getOrCreateRuntimeFunction(M, Name);
6525 }
6526 
6527 FunctionCallee OpenMPIRBuilder::createDispatchFiniFunction(unsigned IVSize,
6528                                                            bool IVSigned) {
6529   assert((IVSize == 32 || IVSize == 64) &&
6530          "IV size is not compatible with the omp runtime");
6531   RuntimeFunction Name = IVSize == 32
6532                              ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_4
6533                                          : omp::OMPRTL___kmpc_dispatch_fini_4u)
6534                              : (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_8
6535                                          : omp::OMPRTL___kmpc_dispatch_fini_8u);
6536 
6537   return getOrCreateRuntimeFunction(M, Name);
6538 }
6539 
6540 FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
6541   return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit);
6542 }
6543 
6544 static Function *createOutlinedFunction(
6545     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName,
6546     SmallVectorImpl<Value *> &Inputs,
6547     OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
6548     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
6549   SmallVector<Type *> ParameterTypes;
6550   if (OMPBuilder.Config.isTargetDevice()) {
6551     // Add the "implicit" runtime argument we use to provide launch specific
6552     // information for target devices.
6553     auto *Int8PtrTy = PointerType::getUnqual(Builder.getContext());
6554     ParameterTypes.push_back(Int8PtrTy);
6555 
6556     // All parameters to target devices are passed as pointers
6557     // or i64. This assumes 64-bit address spaces/pointers.
6558     for (auto &Arg : Inputs)
6559       ParameterTypes.push_back(Arg->getType()->isPointerTy()
6560                                    ? Arg->getType()
6561                                    : Type::getInt64Ty(Builder.getContext()));
6562   } else {
6563     for (auto &Arg : Inputs)
6564       ParameterTypes.push_back(Arg->getType());
6565   }
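  // For example, on the device the inputs (ptr %a, float %x) yield the
  // parameter types (ptr /*dyn_ptr*/, ptr, i64): pointers are kept as-is and
  // every non-pointer input is passed through an i64 parameter.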
6566 
6567   auto FuncType = FunctionType::get(Builder.getVoidTy(), ParameterTypes,
6568                                     /*isVarArg*/ false);
6569   auto Func = Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName,
6570                                Builder.GetInsertBlock()->getModule());
6571 
6572   // Save insert point.
6573   auto OldInsertPoint = Builder.saveIP();
6574 
6575   // Generate the region into the function.
6576   BasicBlock *EntryBB = BasicBlock::Create(Builder.getContext(), "entry", Func);
6577   Builder.SetInsertPoint(EntryBB);
6578 
6579   // Insert target init call in the device compilation pass.
6580   if (OMPBuilder.Config.isTargetDevice())
6581     Builder.restoreIP(OMPBuilder.createTargetInit(Builder, /*IsSPMD*/ false));
6582 
6583   BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
6584 
6585   // As we embed the user code in the middle of our target region after we
6586   // generate entry code, we must move whatever allocas we can into the entry
6587   // block to avoid breaking possible optimizations for the device.
6588   if (OMPBuilder.Config.isTargetDevice())
6589     OMPBuilder.ConstantAllocaRaiseCandidates.emplace_back(Func);
6590 
6591   // Generate the body; then insert target deinit in the device compilation pass.
6592   Builder.restoreIP(CBFunc(Builder.saveIP(), Builder.saveIP()));
6593   if (OMPBuilder.Config.isTargetDevice())
6594     OMPBuilder.createTargetDeinit(Builder);
6595 
6596   // Insert return instruction.
6597   Builder.CreateRetVoid();
6598 
6599   // New Alloca IP at entry point of created device function.
6600   Builder.SetInsertPoint(EntryBB->getFirstNonPHI());
6601   auto AllocaIP = Builder.saveIP();
6602 
6603   Builder.SetInsertPoint(UserCodeEntryBB->getFirstNonPHIOrDbg());
6604 
6605   // Skip the artificial dyn_ptr on the device.
6606   const auto &ArgRange =
6607       OMPBuilder.Config.isTargetDevice()
6608           ? make_range(Func->arg_begin() + 1, Func->arg_end())
6609           : Func->args();
6610 
6611   auto ReplaceValue = [](Value *Input, Value *InputCopy, Function *Func) {
6612     // Things like GEPs can come in the form of Constants. Constants and
6613     // ConstantExprs do not know what they are contained in, so we must dig a
6614     // little to find an instruction before we can tell whether they are used
6615     // inside the function we are outlining. We also replace the original
6616     // constant expression with an equivalent new instruction: an instruction
6617     // allows easy modification in the following loop, because we then know
6618     // that the constant (now an instruction) is owned by our target function
6619     // and replaceUsesOfWith can be invoked on it (this is not possible with
6620     // constants). A brand new instruction also lets us be cautious, as it is
6621     // possible the old expression was used inside the function but also
6622     // exists and is used externally (unlikely by the nature of a Constant,
6623     // but still possible).
6624     // NOTE: We cannot remove dead constants that have been rewritten to
6625     // instructions at this stage; doing so risks breaking later lowering, as
6626     // we could still be in the process of lowering the module from MLIR to
6627     // LLVM-IR and the MLIR lowering may still require the original constants
6628     // of which we have created rewritten versions.
6629     if (auto *Const = dyn_cast<Constant>(Input))
6630       convertUsersOfConstantsToInstructions(Const, Func, false);
6631 
6632     // Collect all the instructions
6633     for (User *User : make_early_inc_range(Input->users()))
6634       if (auto *Instr = dyn_cast<Instruction>(User))
6635         if (Instr->getFunction() == Func)
6636           Instr->replaceUsesOfWith(Input, InputCopy);
6637   };
6638 
6639   SmallVector<std::pair<Value *, Value *>> DeferredReplacement;
6640 
6641   // Rewrite uses of input values to parameters.
6642   for (auto InArg : zip(Inputs, ArgRange)) {
6643     Value *Input = std::get<0>(InArg);
6644     Argument &Arg = std::get<1>(InArg);
6645     Value *InputCopy = nullptr;
6646 
6647     Builder.restoreIP(
6648         ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP()));
6649 
6650     // In certain cases a Global may be set up for replacement; however, this
6651     // Global may be used in multiple arguments to the kernel, just segmented
6652     // apart. For example, if we have a global array that is sectioned into
6653     // multiple mappings (technically not legal in OpenMP, but there is a case
6654     // in Fortran for Common Blocks where this is necessary), we will end up
6655     // with GEPs into this array inside the kernel that refer to the Global but
6656     // are, for all intents and purposes, separate arguments to the kernel. If
6657     // we have mapped a segment that requires a GEP into the 0-th index, it
6658     // will fold into a reference to the Global; if we then encounter this
6659     // folded GEP during replacement, all of the references to the Global in
6660     // the kernel will be replaced with the argument we have generated that
6661     // corresponds to it, including any other GEPs that refer to the Global and
6662     // that may be other arguments. This would invalidate all of the preceding
6663     // mapped arguments that refer to the same Global but are separate
6664     // segments. To prevent this, we defer global processing until all other
6665     // processing has been performed.
6666     if (llvm::isa<llvm::GlobalValue>(std::get<0>(InArg))) {
6669       DeferredReplacement.push_back(std::make_pair(Input, InputCopy));
6670       continue;
6671     }
6672 
6673     ReplaceValue(Input, InputCopy, Func);
6674   }
6675 
6676   // Replace all of our deferred Input values, currently just Globals.
6677   for (auto Deferred : DeferredReplacement)
6678     ReplaceValue(std::get<0>(Deferred), std::get<1>(Deferred), Func);
6679 
6680   // Restore insert point.
6681   Builder.restoreIP(OldInsertPoint);
6682 
6683   return Func;
6684 }
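// The function generated in the device pass thus has roughly this shape
// (illustrative; the real name comes from the offload entry mangling):
//   define internal void @__omp_offloading_<...>(ptr %dyn_ptr, ...) {
//   entry:
//     %tk = call i32 @__kmpc_target_init(ptr @<name>_kernel_environment,
//                                        ptr %dyn_ptr)
//     ...                                ; user code emitted via CBFunc
//     call void @__kmpc_target_deinit()
//     ret void
//   }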
6685 
6686 /// Create an entry point for a target task. It will have the following
6687 /// signature:
6688 ///   void @.omp_target_task_proxy_func(i32 %thread.id, ptr %task)
6689 /// This function is called from emitTargetTask once the code to launch the
6690 /// target kernel has been outlined.
6691 static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
6692                                              IRBuilderBase &Builder,
6693                                              CallInst *StaleCI) {
6694   Module &M = OMPBuilder.M;
6695   // KernelLaunchFunction is the target launch function, i.e.
6696   // the function that sets up kernel arguments and calls
6697   // __tgt_target_kernel to launch the kernel on the device.
6698   //
6699   Function *KernelLaunchFunction = StaleCI->getCalledFunction();
6700 
6701   // StaleCI is the CallInst which is the call to the outlined
6702   // target kernel launch function. If there are values that the
6703   // outlined function uses then these are aggregated into a structure
6704   // which is passed as the second argument. If not, then there's
6705   // only one argument, the threadID. So, StaleCI can be
6706   //
6707   // %structArg = alloca { ptr, ptr }, align 8
6708   // %gep_ = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 0
6709   // store ptr %20, ptr %gep_, align 8
6710   // %gep_8 = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 1
6711   // store ptr %21, ptr %gep_8, align 8
6712   // call void @_QQmain..omp_par.1(i32 %global.tid.val6, ptr %structArg)
6713   //
6714   // OR
6715   //
6716   // call void @_QQmain..omp_par.1(i32 %global.tid.val6)
6717   OpenMPIRBuilder::InsertPointTy IP(StaleCI->getParent(),
6718                                     StaleCI->getIterator());
6719   LLVMContext &Ctx = StaleCI->getParent()->getContext();
6720   Type *ThreadIDTy = Type::getInt32Ty(Ctx);
6721   Type *TaskPtrTy = OMPBuilder.TaskPtr;
6722   Type *TaskTy = OMPBuilder.Task;
6723   auto ProxyFnTy =
6724       FunctionType::get(Builder.getVoidTy(), {ThreadIDTy, TaskPtrTy},
6725                         /* isVarArg */ false);
6726   auto ProxyFn = Function::Create(ProxyFnTy, GlobalValue::InternalLinkage,
6727                                   ".omp_target_task_proxy_func",
6728                                   Builder.GetInsertBlock()->getModule());
6729   ProxyFn->getArg(0)->setName("thread.id");
6730   ProxyFn->getArg(1)->setName("task");
6731 
6732   BasicBlock *EntryBB =
6733       BasicBlock::Create(Builder.getContext(), "entry", ProxyFn);
6734   Builder.SetInsertPoint(EntryBB);
6735 
6736   bool HasShareds = StaleCI->arg_size() > 1;
6737   // TODO: This is a temporary assert to prove to ourselves that
6738   // the outlined target launch function is always going to have
6739   // at most two arguments if there is any data shared between
6740   // host and device.
6741   assert((!HasShareds || (StaleCI->arg_size() == 2)) &&
6742          "StaleCI with shareds should have exactly two arguments.");
6743   if (HasShareds) {
6744     auto *ArgStructAlloca = dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
6745     assert(ArgStructAlloca &&
6746            "Unable to find the alloca instruction corresponding to arguments "
6747            "for extracted function");
6748     auto *ArgStructType =
6749         dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
6750 
6751     AllocaInst *NewArgStructAlloca =
6752         Builder.CreateAlloca(ArgStructType, nullptr, "structArg");
6753     Value *TaskT = ProxyFn->getArg(1);
6754     Value *ThreadId = ProxyFn->getArg(0);
6755     Value *SharedsSize =
6756         Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
6757 
6758     Value *Shareds = Builder.CreateStructGEP(TaskTy, TaskT, 0);
6759     LoadInst *LoadShared =
6760         Builder.CreateLoad(PointerType::getUnqual(Ctx), Shareds);
6761 
6762     Builder.CreateMemCpy(
6763         NewArgStructAlloca, NewArgStructAlloca->getAlign(), LoadShared,
6764         LoadShared->getPointerAlignment(M.getDataLayout()), SharedsSize);
6765 
6766     Builder.CreateCall(KernelLaunchFunction, {ThreadId, NewArgStructAlloca});
6767   }
6768   Builder.CreateRetVoid();
6769   return ProxyFn;
6770 }
6771 static void emitTargetOutlinedFunction(
6772     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6773     TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn,
6774     Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs,
6775     OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
6776     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
6777 
6778   OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
6779       [&OMPBuilder, &Builder, &Inputs, &CBFunc,
6780        &ArgAccessorFuncCB](StringRef EntryFnName) {
6781         return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs,
6782                                       CBFunc, ArgAccessorFuncCB);
6783       };
6784 
6785   OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction, true,
6786                                       OutlinedFn, OutlinedFnID);
6787 }
6788 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
6789     Function *OutlinedFn, Value *OutlinedFnID,
6790     EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
6791     Value *DeviceID, Value *RTLoc, OpenMPIRBuilder::InsertPointTy AllocaIP,
6792     SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
6793     bool HasNoWait) {
6794 
6795   // When we arrive at this function, the target region itself has been
6796   // outlined into the function OutlinedFn.
6797   // So at this point, for
6798   // --------------------------------------------------
6799   //   void user_code_that_offloads(...) {
6800   //     omp target depend(..) map(from:a) map(to:b, c)
6801   //        a = b + c
6802   //   }
6803   //
6804   // --------------------------------------------------
6805   //
6806   // we have
6807   //
6808   // --------------------------------------------------
6809   //
6810   //   void user_code_that_offloads(...) {
6811   //     %.offload_baseptrs = alloca [3 x ptr], align 8
6812   //     %.offload_ptrs = alloca [3 x ptr], align 8
6813   //     %.offload_mappers = alloca [3 x ptr], align 8
6814   //     ;; target region has been outlined and now we need to
6815   //     ;; offload to it via a target task.
6816   //   }
6817   //   void outlined_device_function(ptr a, ptr b, ptr c) {
6818   //     *a = *b + *c
6819   //   }
6820   //
6821   // We have to now do the following
6822   // (i)   Make an offloading call to outlined_device_function using the OpenMP
6823   //       RTL. See 'kernel_launch_function' in the pseudo code below. This is
6824   //       emitted by emitKernelLaunch
6825   // (ii)  Create a task entry point function that calls kernel_launch_function
6826   //       and is the entry point for the target task. See
6827   //       '@.omp_target_task_proxy_func in the pseudocode below.
6828   // (iii) Create a task with the task entry point created in (ii)
6829   //
6830   // That is we create the following
6831   //
6832   //   void user_code_that_offloads(...) {
6833   //     %.offload_baseptrs = alloca [3 x ptr], align 8
6834   //     %.offload_ptrs = alloca [3 x ptr], align 8
6835   //     %.offload_mappers = alloca [3 x ptr], align 8
6836   //
6837   //     %structArg = alloca { ptr, ptr, ptr }, align 8
6838   //     %strucArg[0] = %.offload_baseptrs
6839   //     %strucArg[1] = %.offload_ptrs
6840   //     %strucArg[2] = %.offload_mappers
6841   //     proxy_target_task = @__kmpc_omp_task_alloc(...,
6842   //                                               @.omp_target_task_proxy_func)
6843   //     memcpy(proxy_target_task->shareds, %structArg, sizeof(structArg))
6844   //     dependencies_array = ...
6845   //     ;; if nowait not present
6846   //     call @__kmpc_omp_wait_deps(..., dependencies_array)
6847   //     call @__kmpc_omp_task_begin_if0(...)
6848   //     call @.omp_target_task_proxy_func(i32 thread_id, ptr %proxy_target_task)
6849   //     call @__kmpc_omp_task_complete_if0(...)
6850   //   }
6851   //
6852   //   define internal void @.omp_target_task_proxy_func(i32 %thread.id,
6853   //                                                     ptr %task) {
6854   //       %structArg = alloca {ptr, ptr, ptr}
6855   //       %shared_data = load (getelementptr %task, 0, 0)
6856   //       memcpy(%structArg, %shared_data, sizeof(structArg))
6857   //       kernel_launch_function(%thread.id, %structArg)
6858   //   }
6859   //
6860   //   We need the proxy function because the signature of the task entry point
6861   //   expected by kmpc_omp_task is always the same and will be different from
6862   //   that of the kernel_launch function.
6863   //
6864   //   kernel_launch_function is generated by emitKernelLaunch and has the
6865   //   always_inline attribute.
6866   //   void kernel_launch_function(thread_id,
6867   //                               structArg) alwaysinline {
6868   //       %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
6869   //       offload_baseptrs = load(getelementptr structArg, 0, 0)
6870   //       offload_ptrs = load(getelementptr structArg, 0, 1)
6871   //       offload_mappers = load(getelementptr structArg, 0, 2)
6872   //       ; setup kernel_args using offload_baseptrs, offload_ptrs and
6873   //       ; offload_mappers
6874   //       call i32 @__tgt_target_kernel(...,
6875   //                                     outlined_device_function,
6876   //                                     ptr %kernel_args)
6877   //   }
6878   //   void outlined_device_function(ptr a, ptr b, ptr c) {
6879   //      *a = *b + *c
6880   //   }
6881   //
6882   BasicBlock *TargetTaskBodyBB =
6883       splitBB(Builder, /*CreateBranch=*/true, "target.task.body");
6884   BasicBlock *TargetTaskAllocaBB =
6885       splitBB(Builder, /*CreateBranch=*/true, "target.task.alloca");
6886 
6887   InsertPointTy TargetTaskAllocaIP(TargetTaskAllocaBB,
6888                                    TargetTaskAllocaBB->begin());
6889   InsertPointTy TargetTaskBodyIP(TargetTaskBodyBB, TargetTaskBodyBB->begin());
6890 
6891   OutlineInfo OI;
6892   OI.EntryBB = TargetTaskAllocaBB;
6893   OI.OuterAllocaBB = AllocaIP.getBlock();
6894 
6895   // Add the thread ID argument.
6896   SmallVector<Instruction *, 4> ToBeDeleted;
6897   OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
6898       Builder, AllocaIP, ToBeDeleted, TargetTaskAllocaIP, "global.tid", false));
6899 
6900   Builder.restoreIP(TargetTaskBodyIP);
6901 
6902   // emitKernelLaunch makes the necessary runtime call to offload the kernel.
6903   // We then outline all that code into a separate function
6904   // ('kernel_launch_function' in the pseudo code above). This function is then
6905   // called by the target task proxy function (see
6906   // '@.omp_target_task_proxy_func' in the pseudo code above).
6907   // '@.omp_target_task_proxy_func' is generated by emitTargetTaskProxyFunction.
6908   Builder.restoreIP(emitKernelLaunch(Builder, OutlinedFn, OutlinedFnID,
6909                                      EmitTargetCallFallbackCB, Args, DeviceID,
6910                                      RTLoc, TargetTaskAllocaIP));
6911 
6912   OI.ExitBB = Builder.saveIP().getBlock();
6913   OI.PostOutlineCB = [this, ToBeDeleted, Dependencies,
6914                       HasNoWait](Function &OutlinedFn) mutable {
6915     assert(OutlinedFn.getNumUses() == 1 &&
6916            "there must be a single user for the outlined function");
6917 
6918     CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
6919     bool HasShareds = StaleCI->arg_size() > 1;
6920 
6921     Function *ProxyFn = emitTargetTaskProxyFunction(*this, Builder, StaleCI);
6922 
6923     LLVM_DEBUG(dbgs() << "Proxy task entry function created: " << *ProxyFn
6924                       << "\n");
6925 
6926     Builder.SetInsertPoint(StaleCI);
6927 
6928     // Gather the arguments for emitting the runtime call.
6929     uint32_t SrcLocStrSize;
6930     Constant *SrcLocStr =
6931         getOrCreateSrcLocStr(LocationDescription(Builder), SrcLocStrSize);
6932     Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6933 
6934     // @__kmpc_omp_task_alloc
6935     Function *TaskAllocFn =
6936         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);
6937 
6938     // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID) for the runtime
6939     // call.
6940     Value *ThreadID = getOrCreateThreadID(Ident);
6941 
6942     // Argument - `sizeof_kmp_task_t` (TaskSize)
6943     // TaskSize refers to the size in bytes of the kmp_task_t data structure
6944     // including private vars accessed in task.
6945     // TODO: add kmp_task_t_with_privates (privates)
6946     Value *TaskSize =
6947         Builder.getInt64(M.getDataLayout().getTypeStoreSize(Task));
6948 
6949     // Argument - `sizeof_shareds` (SharedsSize)
6950     // SharedsSize refers to the shareds array size in the kmp_task_t data
6951     // structure.
6952     Value *SharedsSize = Builder.getInt64(0);
6953     if (HasShareds) {
6954       auto *ArgStructAlloca = dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
6955       assert(ArgStructAlloca &&
6956              "Unable to find the alloca instruction corresponding to arguments "
6957              "for extracted function");
6958       auto *ArgStructType =
6959           dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
6960       assert(ArgStructType && "Unable to find struct type corresponding to "
6961                               "arguments for extracted function");
6962       SharedsSize =
6963           Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
6964     }
6965 
6966     // Argument - `flags`
6967     // Task is tied iff (Flags & 1) == 1.
6968     // Task is untied iff (Flags & 1) == 0.
6969     // Task is final iff (Flags & 2) == 2.
6970     // Task is not final iff (Flags & 2) == 0.
6971     // A target task is not final and is untied.
6972     Value *Flags = Builder.getInt32(0);
6973 
6974     // Emit the @__kmpc_omp_task_alloc runtime call
6975     // The runtime call returns a pointer to an area where the task captured
6976     // variables must be copied before the task is run (TaskData)
6977     CallInst *TaskData = Builder.CreateCall(
6978         TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
6979                       /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
6980                       /*task_func=*/ProxyFn});
6981 
6982     if (HasShareds) {
6983       Value *Shareds = StaleCI->getArgOperand(1);
6984       Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
6985       Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData);
6986       Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
6987                            SharedsSize);
6988     }
6989 
6990     Value *DepArray = emitTaskDependencies(*this, Dependencies);
6991 
6992     // ---------------------------------------------------------------
6993     // V5.2 13.8 target construct
6994     // If the nowait clause is present, execution of the target task
6995     // may be deferred. If the nowait clause is not present, the target task is
6996     // an included task.
6997     // ---------------------------------------------------------------
6998     // The above means that the lack of a nowait on the target construct
6999     // translates to '#pragma omp task if(0)'
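    // Illustratively, assuming user code along the lines of
    //   #pragma omp target          // no nowait clause
    //   { /* region */ }
    // the generated target task behaves as if it were
    //   #pragma omp task if(0)
    //   { /* region */ }
    // which is realized by the __kmpc_omp_task_begin_if0 /
    // __kmpc_omp_task_complete_if0 pair emitted below.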
7000     if (!HasNoWait) {
7001       if (DepArray) {
7002         Function *TaskWaitFn =
7003             getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_wait_deps);
7004         Builder.CreateCall(
7005             TaskWaitFn,
7006             {/*loc_ref=*/Ident, /*gtid=*/ThreadID,
7007              /*ndeps=*/Builder.getInt32(Dependencies.size()),
7008              /*dep_list=*/DepArray,
7009              /*ndeps_noalias=*/ConstantInt::get(Builder.getInt32Ty(), 0),
7010              /*noalias_dep_list=*/
7011              ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
7012       }
7013       // Included task.
7014       Function *TaskBeginFn =
7015           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0);
7016       Function *TaskCompleteFn =
7017           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
7018       Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
7019       CallInst *CI = nullptr;
7020       if (HasShareds)
7021         CI = Builder.CreateCall(ProxyFn, {ThreadID, TaskData});
7022       else
7023         CI = Builder.CreateCall(ProxyFn, {ThreadID});
7024       CI->setDebugLoc(StaleCI->getDebugLoc());
7025       Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
7026     } else if (DepArray) {
7027       // HasNoWait - meaning the task may be deferred. Call
7028       // __kmpc_omp_task_with_deps if there are dependencies,
7029       // else call __kmpc_omp_task
7030       Function *TaskFn =
7031           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
7032       Builder.CreateCall(
7033           TaskFn,
7034           {Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()),
7035            DepArray, ConstantInt::get(Builder.getInt32Ty(), 0),
7036            ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
7037     } else {
7038       // Emit the @__kmpc_omp_task runtime call to spawn the task
7039       Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
7040       Builder.CreateCall(TaskFn, {Ident, ThreadID, TaskData});
7041     }
7042 
7043     StaleCI->eraseFromParent();
7044     llvm::for_each(llvm::reverse(ToBeDeleted),
7045                    [](Instruction *I) { I->eraseFromParent(); });
7046   };
7047   addOutlineInfo(std::move(OI));
7048 
7049   LLVM_DEBUG(dbgs() << "Insert block after emitKernelLaunch = \n"
7050                     << *(Builder.GetInsertBlock()) << "\n");
7051   LLVM_DEBUG(dbgs() << "Module after emitKernelLaunch = \n"
7052                     << *(Builder.GetInsertBlock()->getParent()->getParent())
7053                     << "\n");
7054   return Builder.saveIP();
7055 }
7056 static void emitTargetCall(
7057     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7058     OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
7059     Constant *OutlinedFnID, int32_t NumTeams, int32_t NumThreads,
7060     SmallVectorImpl<Value *> &Args,
7061     OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7062     SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {}) {
7063 
7064   OpenMPIRBuilder::TargetDataInfo Info(
7065       /*RequiresDevicePointerInfo=*/false,
7066       /*SeparateBeginEndCalls=*/true);
7067 
7068   OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
7069   OMPBuilder.emitOffloadingArrays(AllocaIP, Builder.saveIP(), MapInfo, Info,
7070                                   /*IsNonContiguous=*/true);
7071 
7072   OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7073   OMPBuilder.emitOffloadingArraysArgument(Builder, RTArgs, Info,
7074                                           !MapInfo.Names.empty());
7075 
7076   // Fallback for emitKernelLaunch: call the host version of the outlined function.
7077   auto &&EmitTargetCallFallbackCB =
7078       [&](OpenMPIRBuilder::InsertPointTy IP) -> OpenMPIRBuilder::InsertPointTy {
7079     Builder.restoreIP(IP);
7080     Builder.CreateCall(OutlinedFn, Args);
7081     return Builder.saveIP();
7082   };
7083 
7084   unsigned NumTargetItems = MapInfo.BasePointers.size();
7085   // TODO: Use correct device ID
7086   Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
7087   Value *NumTeamsVal = Builder.getInt32(NumTeams);
7088   Value *NumThreadsVal = Builder.getInt32(NumThreads);
7089   uint32_t SrcLocStrSize;
7090   Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
7091   Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
7092                                              llvm::omp::IdentFlag(0), 0);
7093   // TODO: Use correct NumIterations
7094   Value *NumIterations = Builder.getInt64(0);
7095   // TODO: Use correct DynCGGroupMem
7096   Value *DynCGGroupMem = Builder.getInt32(0);
7097 
7098   bool HasNoWait = false;
7099   bool HasDependencies = Dependencies.size() > 0;
7100   bool RequiresOuterTargetTask = HasNoWait || HasDependencies;
7101 
7102   OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, NumIterations,
7103                                           NumTeamsVal, NumThreadsVal,
7104                                           DynCGGroupMem, HasNoWait);
7105 
7106   // The presence of certain clauses on the target directive requires the
7107   // explicit generation of the target task.
7108   if (RequiresOuterTargetTask) {
7109     Builder.restoreIP(OMPBuilder.emitTargetTask(
7110         OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID,
7111         RTLoc, AllocaIP, Dependencies, HasNoWait));
7112   } else {
7113     Builder.restoreIP(OMPBuilder.emitKernelLaunch(
7114         Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs,
7115         DeviceID, RTLoc, AllocaIP));
7116   }
7117 }
7118 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
7119     const LocationDescription &Loc, InsertPointTy AllocaIP,
7120     InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, int32_t NumTeams,
7121     int32_t NumThreads, SmallVectorImpl<Value *> &Args,
7122     GenMapInfoCallbackTy GenMapInfoCB,
7123     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
7124     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
7125     SmallVector<DependData> Dependencies) {
7126 
7127   if (!updateToLocation(Loc))
7128     return InsertPointTy();
7129 
7130   Builder.restoreIP(CodeGenIP);
7131 
7132   Function *OutlinedFn;
7133   Constant *OutlinedFnID;
7134   // The target region is outlined into its own function. The LLVM IR for
7135   // the target region itself is generated using the callbacks CBFunc
7136   // and ArgAccessorFuncCB
7137   emitTargetOutlinedFunction(*this, Builder, EntryInfo, OutlinedFn,
7138                              OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB);
7139 
7140   // If we are not on the target device, then we need to generate code
7141   // to make a remote call (offload) to the previously outlined function
7142   // that represents the target region. Do that now.
7143   if (!Config.isTargetDevice())
7144     emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
7145                    NumThreads, Args, GenMapInfoCB, Dependencies);
7146   return Builder.saveIP();
7147 }
7148 
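// Joins the given parts using FirstSeparator before the first part and
// Separator between subsequent parts. An illustrative example:
//   getNameWithSeparators({"a", "b", "c"}, "$", ".") yields "$a.b.c"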
7149 std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
7150                                                    StringRef FirstSeparator,
7151                                                    StringRef Separator) {
7152   SmallString<128> Buffer;
7153   llvm::raw_svector_ostream OS(Buffer);
7154   StringRef Sep = FirstSeparator;
7155   for (StringRef Part : Parts) {
7156     OS << Sep << Part;
7157     Sep = Separator;
7158   }
7159   return OS.str().str();
7160 }
7161 
7162 std::string
7163 OpenMPIRBuilder::createPlatformSpecificName(ArrayRef<StringRef> Parts) const {
7164   return OpenMPIRBuilder::getNameWithSeparators(Parts, Config.firstSeparator(),
7165                                                 Config.separator());
7166 }
7167 
7168 GlobalVariable *
7169 OpenMPIRBuilder::getOrCreateInternalVariable(Type *Ty, const StringRef &Name,
7170                                              unsigned AddressSpace) {
7171   auto &Elem = *InternalVars.try_emplace(Name, nullptr).first;
7172   if (Elem.second) {
7173     assert(Elem.second->getValueType() == Ty &&
7174            "OMP internal variable has different type than requested");
7175   } else {
7176     // TODO: investigate the appropriate linkage type used for the global
7177     // variable for possibly changing that to internal or private, or maybe
7178     // create different versions of the function for different OMP internal
7179     // variables.
7180     auto Linkage = this->M.getTargetTriple().rfind("wasm32") == 0
7181                        ? GlobalValue::ExternalLinkage
7182                        : GlobalValue::CommonLinkage;
7183     auto *GV = new GlobalVariable(M, Ty, /*IsConstant=*/false, Linkage,
7184                                   Constant::getNullValue(Ty), Elem.first(),
7185                                   /*InsertBefore=*/nullptr,
7186                                   GlobalValue::NotThreadLocal, AddressSpace);
7187     const DataLayout &DL = M.getDataLayout();
7188     const llvm::Align TypeAlign = DL.getABITypeAlign(Ty);
7189     const llvm::Align PtrAlign = DL.getPointerABIAlignment(AddressSpace);
7190     GV->setAlignment(std::max(TypeAlign, PtrAlign));
7191     Elem.second = GV;
7192   }
7193 
7194   return Elem.second;
7195 }
7196 
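// For a critical region named "foo" this builds the name
// ".gomp_critical_user_foo.var" (via getNameWithSeparators above) and
// returns, or creates, the matching internal variable of type
// KmpCriticalNameTy.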
7197 Value *OpenMPIRBuilder::getOMPCriticalRegionLock(StringRef CriticalName) {
7198   std::string Prefix = Twine("gomp_critical_user_", CriticalName).str();
7199   std::string Name = getNameWithSeparators({Prefix, "var"}, ".", ".");
7200   return getOrCreateInternalVariable(KmpCriticalNameTy, Name);
7201 }
7202 
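// Computes the size of BasePtr's type via the classic null-GEP idiom;
// conceptually this emits (illustrative IR, types elided):
//   %gep  = getelementptr %ty, ptr null, i32 1
//   %size = ptrtoint ptr %gep to i64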
7203 Value *OpenMPIRBuilder::getSizeInBytes(Value *BasePtr) {
7204   LLVMContext &Ctx = Builder.getContext();
7205   Value *Null =
7206       Constant::getNullValue(PointerType::getUnqual(BasePtr->getContext()));
7207   Value *SizeGep =
7208       Builder.CreateGEP(BasePtr->getType(), Null, Builder.getInt32(1));
7209   Value *SizePtrToInt = Builder.CreatePtrToInt(SizeGep, Type::getInt64Ty(Ctx));
7210   return SizePtrToInt;
7211 }
7212 
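// Emits a private constant array holding the map-type flags; for two
// tofrom mappings this looks roughly like (illustrative, flag values are
// an example):
//   @.offload_maptypes = private unnamed_addr constant [2 x i64] [i64 3, i64 3]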
7213 GlobalVariable *
7214 OpenMPIRBuilder::createOffloadMaptypes(SmallVectorImpl<uint64_t> &Mappings,
7215                                        std::string VarName) {
7216   llvm::Constant *MaptypesArrayInit =
7217       llvm::ConstantDataArray::get(M.getContext(), Mappings);
7218   auto *MaptypesArrayGlobal = new llvm::GlobalVariable(
7219       M, MaptypesArrayInit->getType(),
7220       /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MaptypesArrayInit,
7221       VarName);
7222   MaptypesArrayGlobal->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
7223   return MaptypesArrayGlobal;
7224 }
7225 
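// For NumOperands == 2 the allocas created below look roughly like
// (illustrative):
//   %.offload_baseptrs = alloca [2 x ptr]
//   %.offload_ptrs     = alloca [2 x ptr]
//   %.offload_sizes    = alloca [2 x i64]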
7226 void OpenMPIRBuilder::createMapperAllocas(const LocationDescription &Loc,
7227                                           InsertPointTy AllocaIP,
7228                                           unsigned NumOperands,
7229                                           struct MapperAllocas &MapperAllocas) {
7230   if (!updateToLocation(Loc))
7231     return;
7232 
7233   auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
7234   auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
7235   Builder.restoreIP(AllocaIP);
7236   AllocaInst *ArgsBase = Builder.CreateAlloca(
7237       ArrI8PtrTy, /* ArraySize = */ nullptr, ".offload_baseptrs");
7238   AllocaInst *Args = Builder.CreateAlloca(ArrI8PtrTy, /* ArraySize = */ nullptr,
7239                                           ".offload_ptrs");
7240   AllocaInst *ArgSizes = Builder.CreateAlloca(
7241       ArrI64Ty, /* ArraySize = */ nullptr, ".offload_sizes");
7242   Builder.restoreIP(Loc.IP);
7243   MapperAllocas.ArgsBase = ArgsBase;
7244   MapperAllocas.Args = Args;
7245   MapperAllocas.ArgSizes = ArgSizes;
7246 }
7247 
7248 void OpenMPIRBuilder::emitMapperCall(const LocationDescription &Loc,
7249                                      Function *MapperFunc, Value *SrcLocInfo,
7250                                      Value *MaptypesArg, Value *MapnamesArg,
7251                                      struct MapperAllocas &MapperAllocas,
7252                                      int64_t DeviceID, unsigned NumOperands) {
7253   if (!updateToLocation(Loc))
7254     return;
7255 
7256   auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
7257   auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
7258   Value *ArgsBaseGEP =
7259       Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.ArgsBase,
7260                                 {Builder.getInt32(0), Builder.getInt32(0)});
7261   Value *ArgsGEP =
7262       Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.Args,
7263                                 {Builder.getInt32(0), Builder.getInt32(0)});
7264   Value *ArgSizesGEP =
7265       Builder.CreateInBoundsGEP(ArrI64Ty, MapperAllocas.ArgSizes,
7266                                 {Builder.getInt32(0), Builder.getInt32(0)});
7267   Value *NullPtr =
7268       Constant::getNullValue(PointerType::getUnqual(Int8Ptr->getContext()));
7269   Builder.CreateCall(MapperFunc,
7270                      {SrcLocInfo, Builder.getInt64(DeviceID),
7271                       Builder.getInt32(NumOperands), ArgsBaseGEP, ArgsGEP,
7272                       ArgSizesGEP, MaptypesArg, MapnamesArg, NullPtr});
7273 }
7274 
7275 void OpenMPIRBuilder::emitOffloadingArraysArgument(IRBuilderBase &Builder,
7276                                                    TargetDataRTArgs &RTArgs,
7277                                                    TargetDataInfo &Info,
7278                                                    bool EmitDebug,
7279                                                    bool ForEndCall) {
7280   assert((!ForEndCall || Info.separateBeginEndCalls()) &&
7281          "expected region end call to runtime only when end call is separate");
7282   auto UnqualPtrTy = PointerType::getUnqual(M.getContext());
7283   auto VoidPtrTy = UnqualPtrTy;
7284   auto VoidPtrPtrTy = UnqualPtrTy;
7285   auto Int64Ty = Type::getInt64Ty(M.getContext());
7286   auto Int64PtrTy = UnqualPtrTy;
7287 
7288   if (!Info.NumberOfPtrs) {
7289     RTArgs.BasePointersArray = ConstantPointerNull::get(VoidPtrPtrTy);
7290     RTArgs.PointersArray = ConstantPointerNull::get(VoidPtrPtrTy);
7291     RTArgs.SizesArray = ConstantPointerNull::get(Int64PtrTy);
7292     RTArgs.MapTypesArray = ConstantPointerNull::get(Int64PtrTy);
7293     RTArgs.MapNamesArray = ConstantPointerNull::get(VoidPtrPtrTy);
7294     RTArgs.MappersArray = ConstantPointerNull::get(VoidPtrPtrTy);
7295     return;
7296   }
7297 
7298   RTArgs.BasePointersArray = Builder.CreateConstInBoundsGEP2_32(
7299       ArrayType::get(VoidPtrTy, Info.NumberOfPtrs),
7300       Info.RTArgs.BasePointersArray,
7301       /*Idx0=*/0, /*Idx1=*/0);
7302   RTArgs.PointersArray = Builder.CreateConstInBoundsGEP2_32(
7303       ArrayType::get(VoidPtrTy, Info.NumberOfPtrs), Info.RTArgs.PointersArray,
7304       /*Idx0=*/0,
7305       /*Idx1=*/0);
7306   RTArgs.SizesArray = Builder.CreateConstInBoundsGEP2_32(
7307       ArrayType::get(Int64Ty, Info.NumberOfPtrs), Info.RTArgs.SizesArray,
7308       /*Idx0=*/0, /*Idx1=*/0);
7309   RTArgs.MapTypesArray = Builder.CreateConstInBoundsGEP2_32(
7310       ArrayType::get(Int64Ty, Info.NumberOfPtrs),
7311       ForEndCall && Info.RTArgs.MapTypesArrayEnd ? Info.RTArgs.MapTypesArrayEnd
7312                                                  : Info.RTArgs.MapTypesArray,
7313       /*Idx0=*/0,
7314       /*Idx1=*/0);
7315 
7316   // Only emit the map names array if debug information is
7317   // requested.
7318   if (!EmitDebug)
7319     RTArgs.MapNamesArray = ConstantPointerNull::get(VoidPtrPtrTy);
7320   else
7321     RTArgs.MapNamesArray = Builder.CreateConstInBoundsGEP2_32(
7322         ArrayType::get(VoidPtrTy, Info.NumberOfPtrs), Info.RTArgs.MapNamesArray,
7323         /*Idx0=*/0,
7324         /*Idx1=*/0);
7325   // If there is no user-defined mapper, set the mapper array to nullptr to
7326   // avoid an unnecessary data privatization
7327   if (!Info.HasMapper)
7328     RTArgs.MappersArray = ConstantPointerNull::get(VoidPtrPtrTy);
7329   else
7330     RTArgs.MappersArray =
7331         Builder.CreatePointerCast(Info.RTArgs.MappersArray, VoidPtrPtrTy);
7332 }
7333 
7334 void OpenMPIRBuilder::emitNonContiguousDescriptor(InsertPointTy AllocaIP,
7335                                                   InsertPointTy CodeGenIP,
7336                                                   MapInfosTy &CombinedInfo,
7337                                                   TargetDataInfo &Info) {
7338   MapInfosTy::StructNonContiguousInfo &NonContigInfo =
7339       CombinedInfo.NonContigInfo;
7340 
7341   // Build an array of struct descriptor_dim and then assign it to
7342   // offload_args.
7343   //
7344   // struct descriptor_dim {
7345   //  uint64_t offset;
7346   //  uint64_t count;
7347   //  uint64_t stride
7348   // };
7349   Type *Int64Ty = Builder.getInt64Ty();
7350   StructType *DimTy = StructType::create(
7351       M.getContext(), ArrayRef<Type *>({Int64Ty, Int64Ty, Int64Ty}),
7352       "struct.descriptor_dim");
7353 
7354   enum { OffsetFD = 0, CountFD, StrideFD };
7355   // We need two index variables here since the size of "Dims" is the same as
7356   // the size of Components; however, the size of offset, count, and stride is
7357   // equal to the number of base declarations that are non-contiguous.
7358   for (unsigned I = 0, L = 0, E = NonContigInfo.Dims.size(); I < E; ++I) {
7359     // Skip emitting IR if the dimension size is 1, since it cannot be
7360     // non-contiguous.
7361     if (NonContigInfo.Dims[I] == 1)
7362       continue;
7363     Builder.restoreIP(AllocaIP);
7364     ArrayType *ArrayTy = ArrayType::get(DimTy, NonContigInfo.Dims[I]);
7365     AllocaInst *DimsAddr =
7366         Builder.CreateAlloca(ArrayTy, /* ArraySize = */ nullptr, "dims");
7367     Builder.restoreIP(CodeGenIP);
7368     for (unsigned II = 0, EE = NonContigInfo.Dims[I]; II < EE; ++II) {
7369       unsigned RevIdx = EE - II - 1;
7370       Value *DimsLVal = Builder.CreateInBoundsGEP(
7371           DimsAddr->getAllocatedType(), DimsAddr,
7372           {Builder.getInt64(0), Builder.getInt64(II)});
7373       // Offset
7374       Value *OffsetLVal = Builder.CreateStructGEP(DimTy, DimsLVal, OffsetFD);
7375       Builder.CreateAlignedStore(
7376           NonContigInfo.Offsets[L][RevIdx], OffsetLVal,
7377           M.getDataLayout().getPrefTypeAlign(OffsetLVal->getType()));
7378       // Count
7379       Value *CountLVal = Builder.CreateStructGEP(DimTy, DimsLVal, CountFD);
7380       Builder.CreateAlignedStore(
7381           NonContigInfo.Counts[L][RevIdx], CountLVal,
7382           M.getDataLayout().getPrefTypeAlign(CountLVal->getType()));
7383       // Stride
7384       Value *StrideLVal = Builder.CreateStructGEP(DimTy, DimsLVal, StrideFD);
7385       Builder.CreateAlignedStore(
7386           NonContigInfo.Strides[L][RevIdx], StrideLVal,
7387           M.getDataLayout().getPrefTypeAlign(CountLVal->getType()));
7388     }
7389     // args[I] = &dims
7390     Builder.restoreIP(CodeGenIP);
7391     Value *DAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
7392         DimsAddr, Builder.getPtrTy());
7393     Value *P = Builder.CreateConstInBoundsGEP2_32(
7394         ArrayType::get(Builder.getPtrTy(), Info.NumberOfPtrs),
7395         Info.RTArgs.PointersArray, 0, I);
7396     Builder.CreateAlignedStore(
7397         DAddr, P, M.getDataLayout().getPrefTypeAlign(Builder.getPtrTy()));
7398     ++L;
7399   }
7400 }
7401 
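// A rough sketch of the output for two maps with constant sizes
// (illustrative; the names come from the code below):
//   @.offload_sizes    = private unnamed_addr constant [2 x i64] [...]
//   @.offload_maptypes = private unnamed_addr constant [2 x i64] [...]
//   %.offload_baseptrs = alloca [2 x ptr]
//   %.offload_ptrs     = alloca [2 x ptr]
//   %.offload_mappers  = alloca [2 x ptr]
// followed by per-entry stores into the pointer and mapper arrays.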
7402 void OpenMPIRBuilder::emitOffloadingArrays(
7403     InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
7404     TargetDataInfo &Info, bool IsNonContiguous,
7405     function_ref<void(unsigned int, Value *)> DeviceAddrCB,
7406     function_ref<Value *(unsigned int)> CustomMapperCB) {
7407 
7408   // Reset the array information.
7409   Info.clearArrayInfo();
7410   Info.NumberOfPtrs = CombinedInfo.BasePointers.size();
7411 
7412   if (Info.NumberOfPtrs == 0)
7413     return;
7414 
7415   Builder.restoreIP(AllocaIP);
7416   // Detect whether any captured size requires runtime evaluation; if none
7417   // does, a constant array can eventually be used.
7418   ArrayType *PointerArrayType =
7419       ArrayType::get(Builder.getPtrTy(), Info.NumberOfPtrs);
7420 
7421   Info.RTArgs.BasePointersArray = Builder.CreateAlloca(
7422       PointerArrayType, /* ArraySize = */ nullptr, ".offload_baseptrs");
7423 
7424   Info.RTArgs.PointersArray = Builder.CreateAlloca(
7425       PointerArrayType, /* ArraySize = */ nullptr, ".offload_ptrs");
7426   AllocaInst *MappersArray = Builder.CreateAlloca(
7427       PointerArrayType, /* ArraySize = */ nullptr, ".offload_mappers");
7428   Info.RTArgs.MappersArray = MappersArray;
7429 
7430   // If we don't have any VLA types or other types that require runtime
7431   // evaluation, we can use a constant array for the map sizes, otherwise we
7432   // need to fill up the arrays as we do for the pointers.
7433   Type *Int64Ty = Builder.getInt64Ty();
7434   SmallVector<Constant *> ConstSizes(CombinedInfo.Sizes.size(),
7435                                      ConstantInt::get(Int64Ty, 0));
7436   SmallBitVector RuntimeSizes(CombinedInfo.Sizes.size());
7437   for (unsigned I = 0, E = CombinedInfo.Sizes.size(); I < E; ++I) {
7438     if (auto *CI = dyn_cast<Constant>(CombinedInfo.Sizes[I])) {
7439       if (!isa<ConstantExpr>(CI) && !isa<GlobalValue>(CI)) {
7440         if (IsNonContiguous &&
7441             static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
7442                 CombinedInfo.Types[I] &
7443                 OpenMPOffloadMappingFlags::OMP_MAP_NON_CONTIG))
7444           ConstSizes[I] =
7445               ConstantInt::get(Int64Ty, CombinedInfo.NonContigInfo.Dims[I]);
7446         else
7447           ConstSizes[I] = CI;
7448         continue;
7449       }
7450     }
7451     RuntimeSizes.set(I);
7452   }
7453 
7454   if (RuntimeSizes.all()) {
7455     ArrayType *SizeArrayType = ArrayType::get(Int64Ty, Info.NumberOfPtrs);
7456     Info.RTArgs.SizesArray = Builder.CreateAlloca(
7457         SizeArrayType, /* ArraySize = */ nullptr, ".offload_sizes");
7458     Builder.restoreIP(CodeGenIP);
7459   } else {
7460     auto *SizesArrayInit = ConstantArray::get(
7461         ArrayType::get(Int64Ty, ConstSizes.size()), ConstSizes);
7462     std::string Name = createPlatformSpecificName({"offload_sizes"});
7463     auto *SizesArrayGbl =
7464         new GlobalVariable(M, SizesArrayInit->getType(), /*isConstant=*/true,
7465                            GlobalValue::PrivateLinkage, SizesArrayInit, Name);
7466     SizesArrayGbl->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
7467 
7468     if (!RuntimeSizes.any()) {
7469       Info.RTArgs.SizesArray = SizesArrayGbl;
7470     } else {
7471       unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(0);
7472       Align OffloadSizeAlign = M.getDataLayout().getABIIntegerTypeAlignment(64);
7473       ArrayType *SizeArrayType = ArrayType::get(Int64Ty, Info.NumberOfPtrs);
7474       AllocaInst *Buffer = Builder.CreateAlloca(
7475           SizeArrayType, /* ArraySize = */ nullptr, ".offload_sizes");
7476       Buffer->setAlignment(OffloadSizeAlign);
7477       Builder.restoreIP(CodeGenIP);
7478       Builder.CreateMemCpy(
7479           Buffer, M.getDataLayout().getPrefTypeAlign(Buffer->getType()),
7480           SizesArrayGbl, OffloadSizeAlign,
7481           Builder.getIntN(
7482               IndexSize,
7483               Buffer->getAllocationSize(M.getDataLayout())->getFixedValue()));
7484 
7485       Info.RTArgs.SizesArray = Buffer;
7486     }
7487     Builder.restoreIP(CodeGenIP);
7488   }
7489 
7490   // The map types are always constant so we don't need to generate code to
7491   // fill arrays. Instead, we create an array constant.
7492   SmallVector<uint64_t, 4> Mapping;
7493   for (auto mapFlag : CombinedInfo.Types)
7494     Mapping.push_back(
7495         static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
7496             mapFlag));
7497   std::string MaptypesName = createPlatformSpecificName({"offload_maptypes"});
7498   auto *MapTypesArrayGbl = createOffloadMaptypes(Mapping, MaptypesName);
7499   Info.RTArgs.MapTypesArray = MapTypesArrayGbl;
7500 
7501   // The information types are only built if provided.
7502   if (!CombinedInfo.Names.empty()) {
7503     std::string MapnamesName = createPlatformSpecificName({"offload_mapnames"});
7504     auto *MapNamesArrayGbl =
7505         createOffloadMapnames(CombinedInfo.Names, MapnamesName);
7506     Info.RTArgs.MapNamesArray = MapNamesArrayGbl;
7507   } else {
7508     Info.RTArgs.MapNamesArray =
7509         Constant::getNullValue(PointerType::getUnqual(Builder.getContext()));
7510   }
7511 
7512   // If there's a present map type modifier, it must not be applied to the end
7513   // of a region, so generate a separate map type array in that case.
7514   if (Info.separateBeginEndCalls()) {
7515     bool EndMapTypesDiffer = false;
7516     for (uint64_t &Type : Mapping) {
7517       if (Type & static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
7518                      OpenMPOffloadMappingFlags::OMP_MAP_PRESENT)) {
7519         Type &= ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
7520             OpenMPOffloadMappingFlags::OMP_MAP_PRESENT);
7521         EndMapTypesDiffer = true;
7522       }
7523     }
7524     if (EndMapTypesDiffer) {
7525       MapTypesArrayGbl = createOffloadMaptypes(Mapping, MaptypesName);
7526       Info.RTArgs.MapTypesArrayEnd = MapTypesArrayGbl;
7527     }
7528   }
7529 
7530   PointerType *PtrTy = Builder.getPtrTy();
7531   for (unsigned I = 0; I < Info.NumberOfPtrs; ++I) {
7532     Value *BPVal = CombinedInfo.BasePointers[I];
7533     Value *BP = Builder.CreateConstInBoundsGEP2_32(
7534         ArrayType::get(PtrTy, Info.NumberOfPtrs), Info.RTArgs.BasePointersArray,
7535         0, I);
7536     Builder.CreateAlignedStore(BPVal, BP,
7537                                M.getDataLayout().getPrefTypeAlign(PtrTy));
7538 
7539     if (Info.requiresDevicePointerInfo()) {
7540       if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Pointer) {
7541         CodeGenIP = Builder.saveIP();
7542         Builder.restoreIP(AllocaIP);
7543         Info.DevicePtrInfoMap[BPVal] = {BP, Builder.CreateAlloca(PtrTy)};
7544         Builder.restoreIP(CodeGenIP);
7545         if (DeviceAddrCB)
7546           DeviceAddrCB(I, Info.DevicePtrInfoMap[BPVal].second);
7547       } else if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Address) {
7548         Info.DevicePtrInfoMap[BPVal] = {BP, BP};
7549         if (DeviceAddrCB)
7550           DeviceAddrCB(I, BP);
7551       }
7552     }
7553 
7554     Value *PVal = CombinedInfo.Pointers[I];
7555     Value *P = Builder.CreateConstInBoundsGEP2_32(
7556         ArrayType::get(PtrTy, Info.NumberOfPtrs), Info.RTArgs.PointersArray, 0,
7557         I);
7558     // TODO: Check that the alignment is correct.
7559     Builder.CreateAlignedStore(PVal, P,
7560                                M.getDataLayout().getPrefTypeAlign(PtrTy));
7561 
7562     if (RuntimeSizes.test(I)) {
7563       Value *S = Builder.CreateConstInBoundsGEP2_32(
7564           ArrayType::get(Int64Ty, Info.NumberOfPtrs), Info.RTArgs.SizesArray,
7565           /*Idx0=*/0,
7566           /*Idx1=*/I);
7567       Builder.CreateAlignedStore(Builder.CreateIntCast(CombinedInfo.Sizes[I],
7568                                                        Int64Ty,
7569                                                        /*isSigned=*/true),
7570                                  S, M.getDataLayout().getPrefTypeAlign(PtrTy));
7571     }
7572     // Fill up the mapper array.
7573     unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(0);
7574     Value *MFunc = ConstantPointerNull::get(PtrTy);
7575     if (CustomMapperCB)
7576       if (Value *CustomMFunc = CustomMapperCB(I))
7577         MFunc = Builder.CreatePointerCast(CustomMFunc, PtrTy);
7578     Value *MAddr = Builder.CreateInBoundsGEP(
7579         MappersArray->getAllocatedType(), MappersArray,
7580         {Builder.getIntN(IndexSize, 0), Builder.getIntN(IndexSize, I)});
7581     Builder.CreateAlignedStore(
7582         MFunc, MAddr, M.getDataLayout().getPrefTypeAlign(MAddr->getType()));
7583   }
7584 
7585   if (!IsNonContiguous || CombinedInfo.NonContigInfo.Offsets.empty() ||
7586       Info.NumberOfPtrs == 0)
7587     return;
7588   emitNonContiguousDescriptor(AllocaIP, CodeGenIP, CombinedInfo, Info);
7589 }
7590 
7591 void OpenMPIRBuilder::emitBranch(BasicBlock *Target) {
7592   BasicBlock *CurBB = Builder.GetInsertBlock();
7593 
7594   if (!CurBB || CurBB->getTerminator()) {
7595     // If there is no insert point or the previous block is already
7596     // terminated, don't touch it.
7597   } else {
7598     // Otherwise, create a fall-through branch.
7599     Builder.CreateBr(Target);
7600   }
7601 
7602   Builder.ClearInsertionPoint();
7603 }
7604 
7605 void OpenMPIRBuilder::emitBlock(BasicBlock *BB, Function *CurFn,
7606                                 bool IsFinished) {
7607   BasicBlock *CurBB = Builder.GetInsertBlock();
7608 
7609   // Fall out of the current block (if necessary).
7610   emitBranch(BB);
7611 
7612   if (IsFinished && BB->use_empty()) {
7613     BB->eraseFromParent();
7614     return;
7615   }
7616 
7617   // Place the block after the current block, if possible, or else at
7618   // the end of the function.
7619   if (CurBB && CurBB->getParent())
7620     CurFn->insert(std::next(CurBB->getIterator()), BB);
7621   else
7622     CurFn->insert(CurFn->end(), BB);
7623   Builder.SetInsertPoint(BB);
7624 }
7625 
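// When Cond does not constant-fold, the emitted control flow looks like
// (illustrative):
//   br i1 %cond, label %omp_if.then, label %omp_if.else
// omp_if.then:
//   ; ThenGen code
//   br label %omp_if.end
// omp_if.else:
//   ; ElseGen code
//   br label %omp_if.end
// omp_if.end: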
7626 void OpenMPIRBuilder::emitIfClause(Value *Cond, BodyGenCallbackTy ThenGen,
7627                                    BodyGenCallbackTy ElseGen,
7628                                    InsertPointTy AllocaIP) {
7629   // If the condition constant folds and can be elided, try to avoid emitting
7630   // the condition and the dead arm of the if/else.
7631   if (auto *CI = dyn_cast<ConstantInt>(Cond)) {
7632     auto CondConstant = CI->getSExtValue();
7633     if (CondConstant)
7634       ThenGen(AllocaIP, Builder.saveIP());
7635     else
7636       ElseGen(AllocaIP, Builder.saveIP());
7637     return;
7638   }
7639 
7640   Function *CurFn = Builder.GetInsertBlock()->getParent();
7641 
7642   // Otherwise, the condition did not fold, or we couldn't elide it.  Just
7643   // emit the conditional branch.
7644   BasicBlock *ThenBlock = BasicBlock::Create(M.getContext(), "omp_if.then");
7645   BasicBlock *ElseBlock = BasicBlock::Create(M.getContext(), "omp_if.else");
7646   BasicBlock *ContBlock = BasicBlock::Create(M.getContext(), "omp_if.end");
7647   Builder.CreateCondBr(Cond, ThenBlock, ElseBlock);
7648   // Emit the 'then' code.
7649   emitBlock(ThenBlock, CurFn);
7650   ThenGen(AllocaIP, Builder.saveIP());
7651   emitBranch(ContBlock);
7652   // Emit the 'else' code if present.
7653   // There is no need to emit line number for unconditional branch.
7654   emitBlock(ElseBlock, CurFn);
7655   ElseGen(AllocaIP, Builder.saveIP());
7656   // There is no need to emit line number for unconditional branch.
7657   emitBranch(ContBlock);
7658   // Emit the continuation block for code after the if.
7659   emitBlock(ContBlock, CurFn, /*IsFinished=*/true);
7660 }
7661 
7662 bool OpenMPIRBuilder::checkAndEmitFlushAfterAtomic(
7663     const LocationDescription &Loc, llvm::AtomicOrdering AO, AtomicKind AK) {
7664   assert(!(AO == AtomicOrdering::NotAtomic ||
7665            AO == llvm::AtomicOrdering::Unordered) &&
7666          "Unexpected Atomic Ordering.");
7667 
7668   bool Flush = false;
7669   llvm::AtomicOrdering FlushAO = AtomicOrdering::Monotonic;
7670 
7671   switch (AK) {
7672   case Read:
7673     if (AO == AtomicOrdering::Acquire || AO == AtomicOrdering::AcquireRelease ||
7674         AO == AtomicOrdering::SequentiallyConsistent) {
7675       FlushAO = AtomicOrdering::Acquire;
7676       Flush = true;
7677     }
7678     break;
7679   case Write:
7680   case Compare:
7681   case Update:
7682     if (AO == AtomicOrdering::Release || AO == AtomicOrdering::AcquireRelease ||
7683         AO == AtomicOrdering::SequentiallyConsistent) {
7684       FlushAO = AtomicOrdering::Release;
7685       Flush = true;
7686     }
7687     break;
7688   case Capture:
7689     switch (AO) {
7690     case AtomicOrdering::Acquire:
7691       FlushAO = AtomicOrdering::Acquire;
7692       Flush = true;
7693       break;
7694     case AtomicOrdering::Release:
7695       FlushAO = AtomicOrdering::Release;
7696       Flush = true;
7697       break;
7698     case AtomicOrdering::AcquireRelease:
7699     case AtomicOrdering::SequentiallyConsistent:
7700       FlushAO = AtomicOrdering::AcquireRelease;
7701       Flush = true;
7702       break;
7703     default:
7704       // do nothing - leave silently.
7705       break;
7706     }
7707   }
7708 
7709   if (Flush) {
7710     // The flush runtime call does not take a memory ordering yet. We still
7711     // resolve which atomic ordering would apply so it can be passed along once
7712     // the runtime supports it; for now we just issue the flush call.
7713     // TODO: pass `FlushAO` after memory ordering support is added
7714     (void)FlushAO;
7715     emitFlush(Loc);
7716   }
7717 
7718   // For AO == AtomicOrdering::Monotonic and all other combinations,
7719   // do nothing.
7720   return Flush;
7721 }
7722 
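// For a non-integer element type (e.g. float) the read below is performed
// as a same-width integer load plus a bitcast, roughly (illustrative,
// ordering shown is just an example):
//   %0 = load atomic i32, ptr %x monotonic, align 4
//   %1 = bitcast i32 %0 to float        ; "atomic.flt.cast"
//   store float %1, ptr %v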
7723 OpenMPIRBuilder::InsertPointTy
7724 OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
7725                                   AtomicOpValue &X, AtomicOpValue &V,
7726                                   AtomicOrdering AO) {
7727   if (!updateToLocation(Loc))
7728     return Loc.IP;
7729 
7730   assert(X.Var->getType()->isPointerTy() &&
7731          "OMP Atomic expects a pointer to target memory");
7732   Type *XElemTy = X.ElemTy;
7733   assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
7734           XElemTy->isPointerTy()) &&
7735          "OMP atomic read expected a scalar type");
7736 
7737   Value *XRead = nullptr;
7738 
7739   if (XElemTy->isIntegerTy()) {
7740     LoadInst *XLD =
7741         Builder.CreateLoad(XElemTy, X.Var, X.IsVolatile, "omp.atomic.read");
7742     XLD->setAtomic(AO);
7743     XRead = cast<Value>(XLD);
7744   } else {
7745     // We need to perform the atomic op as an integer.
7746     IntegerType *IntCastTy =
7747         IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
7748     LoadInst *XLoad =
7749         Builder.CreateLoad(IntCastTy, X.Var, X.IsVolatile, "omp.atomic.load");
7750     XLoad->setAtomic(AO);
7751     if (XElemTy->isFloatingPointTy()) {
7752       XRead = Builder.CreateBitCast(XLoad, XElemTy, "atomic.flt.cast");
7753     } else {
7754       XRead = Builder.CreateIntToPtr(XLoad, XElemTy, "atomic.ptr.cast");
7755     }
7756   }
7757   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Read);
7758   Builder.CreateStore(XRead, V.Var, V.IsVolatile);
7759   return Builder.saveIP();
7760 }
7761 
7762 OpenMPIRBuilder::InsertPointTy
7763 OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
7764                                    AtomicOpValue &X, Value *Expr,
7765                                    AtomicOrdering AO) {
7766   if (!updateToLocation(Loc))
7767     return Loc.IP;
7768 
7769   assert(X.Var->getType()->isPointerTy() &&
7770          "OMP Atomic expects a pointer to target memory");
7771   Type *XElemTy = X.ElemTy;
7772   assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
7773           XElemTy->isPointerTy()) &&
7774          "OMP atomic write expected a scalar type");
7775 
7776   if (XElemTy->isIntegerTy()) {
7777     StoreInst *XSt = Builder.CreateStore(Expr, X.Var, X.IsVolatile);
7778     XSt->setAtomic(AO);
7779   } else {
7780     // We need to bitcast and perform the atomic op as an integer.
7781     IntegerType *IntCastTy =
7782         IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
7783     Value *ExprCast =
7784         Builder.CreateBitCast(Expr, IntCastTy, "atomic.src.int.cast");
7785     StoreInst *XSt = Builder.CreateStore(ExprCast, X.Var, X.IsVolatile);
7786     XSt->setAtomic(AO);
7787   }
7788 
7789   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Write);
7790   return Builder.saveIP();
7791 }
7792 
7793 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicUpdate(
7794     const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
7795     Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
7796     AtomicUpdateCallbackTy &UpdateOp, bool IsXBinopExpr) {
7797   assert(!isConflictIP(Loc.IP, AllocaIP) && "IPs must not be ambiguous");
7798   if (!updateToLocation(Loc))
7799     return Loc.IP;
7800 
7801   LLVM_DEBUG({
7802     Type *XTy = X.Var->getType();
7803     assert(XTy->isPointerTy() &&
7804            "OMP Atomic expects a pointer to target memory");
7805     Type *XElemTy = X.ElemTy;
7806     assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
7807             XElemTy->isPointerTy()) &&
7808            "OMP atomic update expected a scalar type");
7809     assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
7810            (RMWOp != AtomicRMWInst::UMax) && (RMWOp != AtomicRMWInst::UMin) &&
7811            "OpenMP atomic does not support LT or GT operations");
7812   });
7813 
7814   emitAtomicUpdate(AllocaIP, X.Var, X.ElemTy, Expr, AO, RMWOp, UpdateOp,
7815                    X.IsVolatile, IsXBinopExpr);
7816   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Update);
7817   return Builder.saveIP();
7818 }
7819 
7820 // FIXME: Duplicating AtomicExpand
7821 Value *OpenMPIRBuilder::emitRMWOpAsInstruction(Value *Src1, Value *Src2,
7822                                                AtomicRMWInst::BinOp RMWOp) {
7823   switch (RMWOp) {
7824   case AtomicRMWInst::Add:
7825     return Builder.CreateAdd(Src1, Src2);
7826   case AtomicRMWInst::Sub:
7827     return Builder.CreateSub(Src1, Src2);
7828   case AtomicRMWInst::And:
7829     return Builder.CreateAnd(Src1, Src2);
7830   case AtomicRMWInst::Nand:
7831     return Builder.CreateNeg(Builder.CreateAnd(Src1, Src2));
7832   case AtomicRMWInst::Or:
7833     return Builder.CreateOr(Src1, Src2);
7834   case AtomicRMWInst::Xor:
7835     return Builder.CreateXor(Src1, Src2);
7836   case AtomicRMWInst::Xchg:
7837   case AtomicRMWInst::FAdd:
7838   case AtomicRMWInst::FSub:
7839   case AtomicRMWInst::BAD_BINOP:
7840   case AtomicRMWInst::Max:
7841   case AtomicRMWInst::Min:
7842   case AtomicRMWInst::UMax:
7843   case AtomicRMWInst::UMin:
7844   case AtomicRMWInst::FMax:
7845   case AtomicRMWInst::FMin:
7846   case AtomicRMWInst::UIncWrap:
7847   case AtomicRMWInst::UDecWrap:
7848     llvm_unreachable("Unsupported atomic update operation");
7849   }
7850   llvm_unreachable("Unsupported atomic update operation");
7851 }
7852 
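// On the non-atomicrmw path below, the update is lowered to a
// compare-and-exchange loop, roughly (illustrative; iN is the integer type
// of the same width as XElemTy):
//   %old = load atomic iN, ptr %x
//   br label %cont
// cont:
//   %phi  = phi iN [ %old, %entry ], [ %prev, %cont ]
//   ; UpdateOp(%phi) is stored to a scratch alloca and reloaded as iN
//   %pair = cmpxchg ptr %x, iN %phi, iN %desired
//   %prev = extractvalue { iN, i1 } %pair, 0
//   %ok   = extractvalue { iN, i1 } %pair, 1
//   br i1 %ok, label %exit, label %cont
// exit: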
7853 std::pair<Value *, Value *> OpenMPIRBuilder::emitAtomicUpdate(
7854     InsertPointTy AllocaIP, Value *X, Type *XElemTy, Value *Expr,
7855     AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
7856     AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, bool IsXBinopExpr) {
7857   // TODO: handle the case where XElemTy is not byte-sized or not a power of 2
7858   // or a complex datatype.
7859   bool emitRMWOp = false;
7860   switch (RMWOp) {
7861   case AtomicRMWInst::Add:
7862   case AtomicRMWInst::And:
7863   case AtomicRMWInst::Nand:
7864   case AtomicRMWInst::Or:
7865   case AtomicRMWInst::Xor:
7866   case AtomicRMWInst::Xchg:
7867     emitRMWOp = XElemTy;
7868     break;
7869   case AtomicRMWInst::Sub:
7870     emitRMWOp = (IsXBinopExpr && XElemTy);
7871     break;
7872   default:
7873     emitRMWOp = false;
7874   }
7875   emitRMWOp &= XElemTy->isIntegerTy();
7876 
7877   std::pair<Value *, Value *> Res;
7878   if (emitRMWOp) {
7879     Res.first = Builder.CreateAtomicRMW(RMWOp, X, Expr, llvm::MaybeAlign(), AO);
7880     // Not needed except for postfix captures. Generate anyway for
7881     // consistency with the else branch; any DCE pass will remove it.
7882     // AtomicRMWInst::Xchg does not have a corresponding instruction.
7883     if (RMWOp == AtomicRMWInst::Xchg)
7884       Res.second = Res.first;
7885     else
7886       Res.second = emitRMWOpAsInstruction(Res.first, Expr, RMWOp);
7887   } else {
7888     IntegerType *IntCastTy =
7889         IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
7890     LoadInst *OldVal =
7891         Builder.CreateLoad(IntCastTy, X, X->getName() + ".atomic.load");
7892     OldVal->setAtomic(AO);
7893     // CurBB
7894     // |     /---\
7895     // ContBB    |
7896     // |     \---/
7897     // ExitBB
7898     BasicBlock *CurBB = Builder.GetInsertBlock();
7899     Instruction *CurBBTI = CurBB->getTerminator();
7900     CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
7901     BasicBlock *ExitBB =
7902         CurBB->splitBasicBlock(CurBBTI, X->getName() + ".atomic.exit");
7903     BasicBlock *ContBB = CurBB->splitBasicBlock(CurBB->getTerminator(),
7904                                                 X->getName() + ".atomic.cont");
7905     ContBB->getTerminator()->eraseFromParent();
7906     Builder.restoreIP(AllocaIP);
7907     AllocaInst *NewAtomicAddr = Builder.CreateAlloca(XElemTy);
7908     NewAtomicAddr->setName(X->getName() + "x.new.val");
7909     Builder.SetInsertPoint(ContBB);
7910     llvm::PHINode *PHI = Builder.CreatePHI(OldVal->getType(), 2);
7911     PHI->addIncoming(OldVal, CurBB);
7912     bool IsIntTy = XElemTy->isIntegerTy();
7913     Value *OldExprVal = PHI;
7914     if (!IsIntTy) {
7915       if (XElemTy->isFloatingPointTy()) {
7916         OldExprVal = Builder.CreateBitCast(PHI, XElemTy,
7917                                            X->getName() + ".atomic.fltCast");
7918       } else {
7919         OldExprVal = Builder.CreateIntToPtr(PHI, XElemTy,
7920                                             X->getName() + ".atomic.ptrCast");
7921       }
7922     }
7923 
7924     Value *Upd = UpdateOp(OldExprVal, Builder);
7925     Builder.CreateStore(Upd, NewAtomicAddr);
7926     LoadInst *DesiredVal = Builder.CreateLoad(IntCastTy, NewAtomicAddr);
7927     AtomicOrdering Failure =
7928         llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
7929     AtomicCmpXchgInst *Result = Builder.CreateAtomicCmpXchg(
7930         X, PHI, DesiredVal, llvm::MaybeAlign(), AO, Failure);
7931     Result->setVolatile(VolatileX);
7932     Value *PreviousVal = Builder.CreateExtractValue(Result, /*Idxs=*/0);
7933     Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
7934     PHI->addIncoming(PreviousVal, Builder.GetInsertBlock());
7935     Builder.CreateCondBr(SuccessFailureVal, ExitBB, ContBB);
7936 
7937     Res.first = OldExprVal;
7938     Res.second = Upd;
7939 
7940     // Set the insertion point in the exit block.
7941     if (UnreachableInst *ExitTI =
7942             dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
7943       CurBBTI->eraseFromParent();
7944       Builder.SetInsertPoint(ExitBB);
7945     } else {
7946       Builder.SetInsertPoint(ExitTI);
7947     }
7948   }
7949 
7950   return Res;
7951 }
7952 
7953 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCapture(
7954     const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
7955     AtomicOpValue &V, Value *Expr, AtomicOrdering AO,
7956     AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp,
7957     bool UpdateExpr, bool IsPostfixUpdate, bool IsXBinopExpr) {
7958   if (!updateToLocation(Loc))
7959     return Loc.IP;
7960 
7961   LLVM_DEBUG({
7962     Type *XTy = X.Var->getType();
7963     assert(XTy->isPointerTy() &&
7964            "OMP Atomic expects a pointer to target memory");
7965     Type *XElemTy = X.ElemTy;
7966     assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
7967             XElemTy->isPointerTy()) &&
7968            "OMP atomic capture expected a scalar type");
7969     assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
7970            "OpenMP atomic does not support LT or GT operations");
7971   });
7972 
7973   // If UpdateExpr is 'x' updated with some `expr` not based on 'x',
7974   // 'x' is simply atomically rewritten with 'expr'.
7975   AtomicRMWInst::BinOp AtomicOp = (UpdateExpr ? RMWOp : AtomicRMWInst::Xchg);
7976   std::pair<Value *, Value *> Result =
7977       emitAtomicUpdate(AllocaIP, X.Var, X.ElemTy, Expr, AO, AtomicOp, UpdateOp,
7978                        X.IsVolatile, IsXBinopExpr);
7979 
7980   Value *CapturedVal = (IsPostfixUpdate ? Result.first : Result.second);
7981   Builder.CreateStore(CapturedVal, V.Var, V.IsVolatile);
7982 
7983   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Capture);
7984   return Builder.saveIP();
7985 }
7986 
7987 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
7988     const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
7989     AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
7990     omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
7991     bool IsFailOnly) {
7992 
7993   AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
7994   return createAtomicCompare(Loc, X, V, R, E, D, AO, Op, IsXBinopExpr,
7995                              IsPostfixUpdate, IsFailOnly, Failure);
7996 }
7997 
7998 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
7999     const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
8000     AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
8001     omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
8002     bool IsFailOnly, AtomicOrdering Failure) {
8003 
8004   if (!updateToLocation(Loc))
8005     return Loc.IP;
8006 
8007   assert(X.Var->getType()->isPointerTy() &&
8008          "OMP atomic expects a pointer to target memory");
8009   // compare capture
8010   if (V.Var) {
8011     assert(V.Var->getType()->isPointerTy() && "v.var must be of pointer type");
8012     assert(V.ElemTy == X.ElemTy && "x and v must be of same type");
8013   }
8014 
8015   bool IsInteger = E->getType()->isIntegerTy();
8016 
8017   if (Op == OMPAtomicCompareOp::EQ) {
8018     AtomicCmpXchgInst *Result = nullptr;
8019     if (!IsInteger) {
8020       IntegerType *IntCastTy =
8021           IntegerType::get(M.getContext(), X.ElemTy->getScalarSizeInBits());
8022       Value *EBCast = Builder.CreateBitCast(E, IntCastTy);
8023       Value *DBCast = Builder.CreateBitCast(D, IntCastTy);
8024       Result = Builder.CreateAtomicCmpXchg(X.Var, EBCast, DBCast, MaybeAlign(),
8025                                            AO, Failure);
8026     } else {
8027       Result =
8028           Builder.CreateAtomicCmpXchg(X.Var, E, D, MaybeAlign(), AO, Failure);
8029     }
8030 
8031     if (V.Var) {
8032       Value *OldValue = Builder.CreateExtractValue(Result, /*Idxs=*/0);
8033       if (!IsInteger)
8034         OldValue = Builder.CreateBitCast(OldValue, X.ElemTy);
8035       assert(OldValue->getType() == V.ElemTy &&
8036              "OldValue and V must be of same type");
8037       if (IsPostfixUpdate) {
8038         Builder.CreateStore(OldValue, V.Var, V.IsVolatile);
8039       } else {
8040         Value *SuccessOrFail = Builder.CreateExtractValue(Result, /*Idxs=*/1);
8041         if (IsFailOnly) {
8042           // CurBB----
8043           //   |     |
8044           //   v     |
8045           // ContBB  |
8046           //   |     |
8047           //   v     |
8048           // ExitBB <-
8049           //
8050           // where ContBB only contains the store of old value to 'v'.
8051           BasicBlock *CurBB = Builder.GetInsertBlock();
8052           Instruction *CurBBTI = CurBB->getTerminator();
8053           CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
8054           BasicBlock *ExitBB = CurBB->splitBasicBlock(
8055               CurBBTI, X.Var->getName() + ".atomic.exit");
8056           BasicBlock *ContBB = CurBB->splitBasicBlock(
8057               CurBB->getTerminator(), X.Var->getName() + ".atomic.cont");
8058           ContBB->getTerminator()->eraseFromParent();
8059           CurBB->getTerminator()->eraseFromParent();
8060 
8061           Builder.CreateCondBr(SuccessOrFail, ExitBB, ContBB);
8062 
8063           Builder.SetInsertPoint(ContBB);
8064           Builder.CreateStore(OldValue, V.Var);
8065           Builder.CreateBr(ExitBB);
8066 
8067           if (UnreachableInst *ExitTI =
8068                   dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
8069             CurBBTI->eraseFromParent();
8070             Builder.SetInsertPoint(ExitBB);
8071           } else {
8072             Builder.SetInsertPoint(ExitTI);
8073           }
8074         } else {
8075           Value *CapturedValue =
8076               Builder.CreateSelect(SuccessOrFail, E, OldValue);
8077           Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
8078         }
8079       }
8080     }
8081     // The comparison result has to be stored.
8082     if (R.Var) {
8083       assert(R.Var->getType()->isPointerTy() &&
8084              "r.var must be of pointer type");
8085       assert(R.ElemTy->isIntegerTy() && "r must be of integral type");
8086 
8087       Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
8088       Value *ResultCast = R.IsSigned
8089                               ? Builder.CreateSExt(SuccessFailureVal, R.ElemTy)
8090                               : Builder.CreateZExt(SuccessFailureVal, R.ElemTy);
8091       Builder.CreateStore(ResultCast, R.Var, R.IsVolatile);
8092     }
8093   } else {
8094     assert((Op == OMPAtomicCompareOp::MAX || Op == OMPAtomicCompareOp::MIN) &&
8095            "Op should be either max or min at this point");
8096     assert(!IsFailOnly && "IsFailOnly is only valid when the comparison is ==");
8097 
8098     // Reverse the ordop as the OpenMP forms are different from LLVM forms.
8099     // Let's take max as example.
8100     // OpenMP form:
8101     // x = x > expr ? expr : x;
8102     // LLVM form:
8103     // *ptr = *ptr > val ? *ptr : val;
8104     // We need to transform to LLVM form.
8105     // x = x <= expr ? x : expr;
8106     AtomicRMWInst::BinOp NewOp;
8107     if (IsXBinopExpr) {
8108       if (IsInteger) {
8109         if (X.IsSigned)
8110           NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
8111                                                 : AtomicRMWInst::Max;
8112         else
8113           NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
8114                                                 : AtomicRMWInst::UMax;
8115       } else {
8116         NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMin
8117                                               : AtomicRMWInst::FMax;
8118       }
8119     } else {
8120       if (IsInteger) {
8121         if (X.IsSigned)
8122           NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
8123                                                 : AtomicRMWInst::Min;
8124         else
8125           NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
8126                                                 : AtomicRMWInst::UMin;
8127       } else {
8128         NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMax
8129                                               : AtomicRMWInst::FMin;
8130       }
8131     }
8132 
8133     AtomicRMWInst *OldValue =
8134         Builder.CreateAtomicRMW(NewOp, X.Var, E, MaybeAlign(), AO);
8135     if (V.Var) {
8136       Value *CapturedValue = nullptr;
8137       if (IsPostfixUpdate) {
8138         CapturedValue = OldValue;
8139       } else {
8140         CmpInst::Predicate Pred;
8141         switch (NewOp) {
8142         case AtomicRMWInst::Max:
8143           Pred = CmpInst::ICMP_SGT;
8144           break;
8145         case AtomicRMWInst::UMax:
8146           Pred = CmpInst::ICMP_UGT;
8147           break;
8148         case AtomicRMWInst::FMax:
8149           Pred = CmpInst::FCMP_OGT;
8150           break;
8151         case AtomicRMWInst::Min:
8152           Pred = CmpInst::ICMP_SLT;
8153           break;
8154         case AtomicRMWInst::UMin:
8155           Pred = CmpInst::ICMP_ULT;
8156           break;
8157         case AtomicRMWInst::FMin:
8158           Pred = CmpInst::FCMP_OLT;
8159           break;
8160         default:
8161           llvm_unreachable("unexpected comparison op");
8162         }
8163         Value *NonAtomicCmp = Builder.CreateCmp(Pred, OldValue, E);
8164         CapturedValue = Builder.CreateSelect(NonAtomicCmp, E, OldValue);
8165       }
8166       Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
8167     }
8168   }
8169 
8170   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Compare);
8171 
8172   return Builder.saveIP();
8173 }
8174 
8175 OpenMPIRBuilder::InsertPointTy
8176 OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
8177                              BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower,
8178                              Value *NumTeamsUpper, Value *ThreadLimit,
8179                              Value *IfExpr) {
8180   if (!updateToLocation(Loc))
8181     return InsertPointTy();
8182 
8183   uint32_t SrcLocStrSize;
8184   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
8185   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
8186   Function *CurrentFunction = Builder.GetInsertBlock()->getParent();
8187 
8188   // The outer allocation basic block is the entry block of the current function.
8189   BasicBlock &OuterAllocaBB = CurrentFunction->getEntryBlock();
8190   if (&OuterAllocaBB == Builder.GetInsertBlock()) {
8191     BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.entry");
8192     Builder.SetInsertPoint(BodyBB, BodyBB->begin());
8193   }
8194 
8195   // The current basic block is split into four basic blocks. After outlining,
8196   // they will be mapped as follows:
8197   // ```
8198   // def current_fn() {
8199   //   current_basic_block:
8200   //     br label %teams.exit
8201   //   teams.exit:
8202   //     ; instructions after teams
8203   // }
8204   //
8205   // def outlined_fn() {
8206   //   teams.alloca:
8207   //     br label %teams.body
8208   //   teams.body:
8209   //     ; instructions within teams body
8210   // }
8211   // ```
8212   BasicBlock *ExitBB = splitBB(Builder, /*CreateBranch=*/true, "teams.exit");
8213   BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.body");
8214   BasicBlock *AllocaBB =
8215       splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");
8216 
8217   bool SubClausesPresent =
8218       (NumTeamsLower || NumTeamsUpper || ThreadLimit || IfExpr);
8219   // Push num_teams
8220   if (!Config.isTargetDevice() && SubClausesPresent) {
8221     assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) &&
8222            "if lowerbound is non-null, then upperbound must also be non-null "
8223            "for bounds on num_teams");
8224 
8225     if (NumTeamsUpper == nullptr)
8226       NumTeamsUpper = Builder.getInt32(0);
8227 
8228     if (NumTeamsLower == nullptr)
8229       NumTeamsLower = NumTeamsUpper;
8230 
8231     if (IfExpr) {
8232       assert(IfExpr->getType()->isIntegerTy() &&
8233              "argument to if clause must be an integer value");
8234 
8235       // upper = ifexpr ? upper : 1
8236       if (IfExpr->getType() != Int1)
8237         IfExpr = Builder.CreateICmpNE(IfExpr,
8238                                       ConstantInt::get(IfExpr->getType(), 0));
8239       NumTeamsUpper = Builder.CreateSelect(
8240           IfExpr, NumTeamsUpper, Builder.getInt32(1), "numTeamsUpper");
8241 
8242       // lower = ifexpr ? lower : 1
8243       NumTeamsLower = Builder.CreateSelect(
8244           IfExpr, NumTeamsLower, Builder.getInt32(1), "numTeamsLower");
8245     }
8246 
8247     if (ThreadLimit == nullptr)
8248       ThreadLimit = Builder.getInt32(0);
8249 
8250     Value *ThreadNum = getOrCreateThreadID(Ident);
8251     Builder.CreateCall(
8252         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams_51),
8253         {Ident, ThreadNum, NumTeamsLower, NumTeamsUpper, ThreadLimit});
8254   }
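  // Editor's sketch (hypothetical clause values): on the host, `num_teams(4)
  // thread_limit(16)` with no if clause reaches the call above as
  //   __kmpc_push_num_teams_51(ident, tid, /*lb=*/4, /*ub=*/4, /*limit=*/16)
  // since a missing lower bound defaults to the upper bound.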
8255   // Generate the body of teams.
8256   InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
8257   InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
8258   BodyGenCB(AllocaIP, CodeGenIP);
8259 
8260   OutlineInfo OI;
8261   OI.EntryBB = AllocaBB;
8262   OI.ExitBB = ExitBB;
8263   OI.OuterAllocaBB = &OuterAllocaBB;
8264 
8265   // Insert fake values for global tid and bound tid.
8266   SmallVector<Instruction *, 8> ToBeDeleted;
8267   InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
8268   OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
8269       Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "gid", true));
8270   OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
8271       Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "tid", true));
8272 
8273   auto HostPostOutlineCB = [this, Ident,
8274                             ToBeDeleted](Function &OutlinedFn) mutable {
8275     // The stale call instruction will be replaced with a new call to the
8276     // runtime function, passing the outlined function as an argument.
8277 
8278     assert(OutlinedFn.getNumUses() == 1 &&
8279            "there must be a single user for the outlined function");
8280     CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
8281     ToBeDeleted.push_back(StaleCI);
8282 
8283     assert((OutlinedFn.arg_size() == 2 || OutlinedFn.arg_size() == 3) &&
8284            "Outlined function must have two or three arguments only");
8285 
8286     bool HasShared = OutlinedFn.arg_size() == 3;
8287 
8288     OutlinedFn.getArg(0)->setName("global.tid.ptr");
8289     OutlinedFn.getArg(1)->setName("bound.tid.ptr");
8290     if (HasShared)
8291       OutlinedFn.getArg(2)->setName("data");
8292 
8293     // Call to the runtime function for teams in the current function.
8294     assert(StaleCI && "Error while outlining - no CallInst user found for the "
8295                       "outlined function.");
8296     Builder.SetInsertPoint(StaleCI);
8297     SmallVector<Value *> Args = {
8298         Ident, Builder.getInt32(StaleCI->arg_size() - 2), &OutlinedFn};
8299     if (HasShared)
8300       Args.push_back(StaleCI->getArgOperand(2));
8301     Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
8302                            omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
8303                        Args);
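    // The replacement call emitted above looks roughly like (editor's sketch,
    // names hypothetical):
    //   call void @__kmpc_fork_teams(ptr @ident, i32 <nargs>, ptr @outlined.fn
    //                                [, ptr %data])
    // where <nargs> is StaleCI->arg_size() - 2, i.e. 0 or 1 here.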
8304 
8305     llvm::for_each(llvm::reverse(ToBeDeleted),
8306                    [](Instruction *I) { I->eraseFromParent(); });
8307 
8308   };
8309 
8310   if (!Config.isTargetDevice())
8311     OI.PostOutlineCB = HostPostOutlineCB;
8312 
8313   addOutlineInfo(std::move(OI));
8314 
8315   Builder.SetInsertPoint(ExitBB, ExitBB->begin());
8316 
8317   return Builder.saveIP();
8318 }
8319 
8320 GlobalVariable *
8321 OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
8322                                        std::string VarName) {
8323   llvm::Constant *MapNamesArrayInit = llvm::ConstantArray::get(
8324       llvm::ArrayType::get(llvm::PointerType::getUnqual(M.getContext()),
8325                            Names.size()),
8326       Names);
8327   auto *MapNamesArrayGlobal = new llvm::GlobalVariable(
8328       M, MapNamesArrayInit->getType(),
8329       /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MapNamesArrayInit,
8330       VarName);
8331   return MapNamesArrayGlobal;
8332 }
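// Editor's sketch of the global this produces (names hypothetical, assuming
// two map-name strings were passed in Names):
//   @.offload_mapnames = private constant [2 x ptr] [ptr @.str.0, ptr @.str.1]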
8333 
8334 // Create all simple and struct types exposed by the runtime and remember
8335 // their llvm::PointerTypes for easy access later.
8336 void OpenMPIRBuilder::initializeTypes(Module &M) {
8337   LLVMContext &Ctx = M.getContext();
8338   StructType *T;
8339 #define OMP_TYPE(VarName, InitValue) VarName = InitValue;
8340 #define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize)                             \
8341   VarName##Ty = ArrayType::get(ElemTy, ArraySize);                             \
8342   VarName##PtrTy = PointerType::getUnqual(VarName##Ty);
8343 #define OMP_FUNCTION_TYPE(VarName, IsVarArg, ReturnType, ...)                  \
8344   VarName = FunctionType::get(ReturnType, {__VA_ARGS__}, IsVarArg);            \
8345   VarName##Ptr = PointerType::getUnqual(VarName);
8346 #define OMP_STRUCT_TYPE(VarName, StructName, Packed, ...)                      \
8347   T = StructType::getTypeByName(Ctx, StructName);                              \
8348   if (!T)                                                                      \
8349     T = StructType::create(Ctx, {__VA_ARGS__}, StructName, Packed);            \
8350   VarName = T;                                                                 \
8351   VarName##Ptr = PointerType::getUnqual(T);
8352 #include "llvm/Frontend/OpenMP/OMPKinds.def"
8353 }
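// For illustration (editor's sketch): an OMP_STRUCT_TYPE(Ident,
// "struct.ident_t", ...) record in OMPKinds.def expands to a lookup of
// "struct.ident_t" in the context, creating the struct type if absent, and
// caches both Ident and IdentPtr for later use by the builder.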
8354 
8355 void OpenMPIRBuilder::OutlineInfo::collectBlocks(
8356     SmallPtrSetImpl<BasicBlock *> &BlockSet,
8357     SmallVectorImpl<BasicBlock *> &BlockVector) {
8358   SmallVector<BasicBlock *, 32> Worklist;
8359   BlockSet.insert(EntryBB);
8360   BlockSet.insert(ExitBB);
8361 
8362   Worklist.push_back(EntryBB);
8363   while (!Worklist.empty()) {
8364     BasicBlock *BB = Worklist.pop_back_val();
8365     BlockVector.push_back(BB);
8366     for (BasicBlock *SuccBB : successors(BB))
8367       if (BlockSet.insert(SuccBB).second)
8368         Worklist.push_back(SuccBB);
8369   }
8370 }
8371 
8372 void OpenMPIRBuilder::createOffloadEntry(Constant *ID, Constant *Addr,
8373                                          uint64_t Size, int32_t Flags,
8374                                          GlobalValue::LinkageTypes,
8375                                          StringRef Name) {
8376   if (!Config.isGPU()) {
8377     llvm::offloading::emitOffloadingEntry(
8378         M, ID, Name.empty() ? Addr->getName() : Name, Size, Flags, /*Data=*/0,
8379         "omp_offloading_entries");
8380     return;
8381   }
8382   // TODO: Add support for global variables on the device after declare target
8383   // support.
8384   Function *Fn = dyn_cast<Function>(Addr);
8385   if (!Fn)
8386     return;
8387 
8388   Module &M = *(Fn->getParent());
8389   LLVMContext &Ctx = M.getContext();
8390 
8391   // Get "nvvm.annotations" metadata node.
8392   NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
8393 
8394   Metadata *MDVals[] = {
8395       ConstantAsMetadata::get(Fn), MDString::get(Ctx, "kernel"),
8396       ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(Ctx), 1))};
8397   // Append metadata to nvvm.annotations.
8398   MD->addOperand(MDNode::get(Ctx, MDVals));
8399 
8400   // Add a function attribute for the kernel.
8401   Fn->addFnAttr(Attribute::get(Ctx, "kernel"));
8402   if (T.isAMDGCN())
8403     Fn->addFnAttr("uniform-work-group-size", "true");
8404   Fn->addFnAttr(Attribute::MustProgress);
8405 }
8406 
8407 // We only generate metadata for functions that contain target regions.
8408 void OpenMPIRBuilder::createOffloadEntriesAndInfoMetadata(
8409     EmitMetadataErrorReportFunctionTy &ErrorFn) {
8410 
8411   // If there are no entries, we don't need to do anything.
8412   if (OffloadInfoManager.empty())
8413     return;
8414 
8415   LLVMContext &C = M.getContext();
8416   SmallVector<std::pair<const OffloadEntriesInfoManager::OffloadEntryInfo *,
8417                         TargetRegionEntryInfo>,
8418               16>
8419       OrderedEntries(OffloadInfoManager.size());
8420 
8421   // Auxiliary methods to create metadata values and strings.
8422   auto &&GetMDInt = [this](unsigned V) {
8423     return ConstantAsMetadata::get(ConstantInt::get(Builder.getInt32Ty(), V));
8424   };
8425 
8426   auto &&GetMDString = [&C](StringRef V) { return MDString::get(C, V); };
8427 
8428   // Create the offloading info metadata node.
8429   NamedMDNode *MD = M.getOrInsertNamedMetadata("omp_offload.info");
8430   auto &&TargetRegionMetadataEmitter =
8431       [&C, MD, &OrderedEntries, &GetMDInt, &GetMDString](
8432           const TargetRegionEntryInfo &EntryInfo,
8433           const OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion &E) {
8434         // Generate metadata for target regions. Each entry of this metadata
8435         // contains:
8436         // - Entry 0 -> Kind of this type of metadata (0).
8437         // - Entry 1 -> Device ID of the file where the entry was identified.
8438         // - Entry 2 -> File ID of the file where the entry was identified.
8439         // - Entry 3 -> Mangled name of the function where the entry was
8440         // identified.
8441         // - Entry 4 -> Line in the file where the entry was identified.
8442         // - Entry 5 -> Count of regions at this DeviceID/FilesID/Line.
8443         // - Entry 6 -> Order the entry was created.
8444         // The first element of the metadata node is the kind.
8445         Metadata *Ops[] = {
8446             GetMDInt(E.getKind()),      GetMDInt(EntryInfo.DeviceID),
8447             GetMDInt(EntryInfo.FileID), GetMDString(EntryInfo.ParentName),
8448             GetMDInt(EntryInfo.Line),   GetMDInt(EntryInfo.Count),
8449             GetMDInt(E.getOrder())};
8450 
8451         // Save this entry in the right position of the ordered entries array.
8452         OrderedEntries[E.getOrder()] = std::make_pair(&E, EntryInfo);
8453 
8454         // Add metadata to the named metadata node.
8455         MD->addOperand(MDNode::get(C, Ops));
8456       };
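  // A concrete operand thus looks like (editor's sketch, hypothetical values):
  //   !{i32 0, i32 <DeviceID>, i32 <FileID>, !"parent_fn", i32 42, i32 0, i32 3}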
8457 
8458   OffloadInfoManager.actOnTargetRegionEntriesInfo(TargetRegionMetadataEmitter);
8459 
8460   // Create a function that emits metadata for each device global variable entry.
8461   auto &&DeviceGlobalVarMetadataEmitter =
8462       [&C, &OrderedEntries, &GetMDInt, &GetMDString, MD](
8463           StringRef MangledName,
8464           const OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar &E) {
8465         // Generate metadata for global variables. Each entry of this metadata
8466         // contains:
8467         // - Entry 0 -> Kind of this type of metadata (1).
8468         // - Entry 1 -> Mangled name of the variable.
8469         // - Entry 2 -> Declare target kind.
8470         // - Entry 3 -> Order the entry was created.
8471         // The first element of the metadata node is the kind.
8472         Metadata *Ops[] = {GetMDInt(E.getKind()), GetMDString(MangledName),
8473                            GetMDInt(E.getFlags()), GetMDInt(E.getOrder())};
8474 
8475         // Save this entry in the right position of the ordered entries array.
8476         TargetRegionEntryInfo varInfo(MangledName, 0, 0, 0);
8477         OrderedEntries[E.getOrder()] = std::make_pair(&E, varInfo);
8478 
8479         // Add metadata to the named metadata node.
8480         MD->addOperand(MDNode::get(C, Ops));
8481       };
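  // A concrete operand thus looks like (editor's sketch, hypothetical values):
  //   !{i32 1, !"mangled_var_name", i32 <Flags>, i32 <Order>}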
8482 
8483   OffloadInfoManager.actOnDeviceGlobalVarEntriesInfo(
8484       DeviceGlobalVarMetadataEmitter);
8485 
8486   for (const auto &E : OrderedEntries) {
8487     assert(E.first && "All ordered entries must exist!");
8488     if (const auto *CE =
8489             dyn_cast<OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion>(
8490                 E.first)) {
8491       if (!CE->getID() || !CE->getAddress()) {
8492         // Do not blame the entry if the parent function is not emitted.
8493         TargetRegionEntryInfo EntryInfo = E.second;
8494         StringRef FnName = EntryInfo.ParentName;
8495         if (!M.getNamedValue(FnName))
8496           continue;
8497         ErrorFn(EMIT_MD_TARGET_REGION_ERROR, EntryInfo);
8498         continue;
8499       }
8500       createOffloadEntry(CE->getID(), CE->getAddress(),
8501                          /*Size=*/0, CE->getFlags(),
8502                          GlobalValue::WeakAnyLinkage);
8503     } else if (const auto *CE = dyn_cast<
8504                    OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar>(
8505                    E.first)) {
8506       OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags =
8507           static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
8508               CE->getFlags());
8509       switch (Flags) {
8510       case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter:
8511       case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo:
8512         if (Config.isTargetDevice() && Config.hasRequiresUnifiedSharedMemory())
8513           continue;
8514         if (!CE->getAddress()) {
8515           ErrorFn(EMIT_MD_DECLARE_TARGET_ERROR, E.second);
8516           continue;
8517         }
8518         // The variable has no definition - no need to add the entry.
8519         if (CE->getVarSize() == 0)
8520           continue;
8521         break;
8522       case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink:
8523         assert(((Config.isTargetDevice() && !CE->getAddress()) ||
8524                 (!Config.isTargetDevice() && CE->getAddress())) &&
8525                "Declaret target link address is set.");
8526         if (Config.isTargetDevice())
8527           continue;
8528         if (!CE->getAddress()) {
8529           ErrorFn(EMIT_MD_GLOBAL_VAR_LINK_ERROR, TargetRegionEntryInfo());
8530           continue;
8531         }
8532         break;
8533       default:
8534         break;
8535       }
8536 
8537       // Hidden or internal symbols on the device are not externally visible.
8538       // We should not attempt to register them by creating an offloading
8539       // entry. Indirect variables are handled separately on the device.
8540       if (auto *GV = dyn_cast<GlobalValue>(CE->getAddress()))
8541         if ((GV->hasLocalLinkage() || GV->hasHiddenVisibility()) &&
8542             Flags != OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
8543           continue;
8544 
8545       // Indirect globals need to use a special name that doesn't match the name
8546       // of the associated host global.
8547       if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
8548         createOffloadEntry(CE->getAddress(), CE->getAddress(), CE->getVarSize(),
8549                            Flags, CE->getLinkage(), CE->getVarName());
8550       else
8551         createOffloadEntry(CE->getAddress(), CE->getAddress(), CE->getVarSize(),
8552                            Flags, CE->getLinkage());
8553 
8554     } else {
8555       llvm_unreachable("Unsupported entry kind.");
8556     }
8557   }
8558 
8559   // Emit requires directive globals to a special entry so the runtime can
8560   // register them when the device image is loaded.
8561   // TODO: This reduces the offloading entries to a 32-bit integer. Offloading
8562   //       entries should be redesigned to better suit this use-case.
8563   if (Config.hasRequiresFlags() && !Config.isTargetDevice())
8564     offloading::emitOffloadingEntry(
8565         M, Constant::getNullValue(PointerType::getUnqual(M.getContext())),
8566         /*Name=*/"",
8567         /*Size=*/0, OffloadEntriesInfoManager::OMPTargetGlobalRegisterRequires,
8568         Config.getRequiresFlags(), "omp_offloading_entries");
8569 }
8570 
8571 void TargetRegionEntryInfo::getTargetRegionEntryFnName(
8572     SmallVectorImpl<char> &Name, StringRef ParentName, unsigned DeviceID,
8573     unsigned FileID, unsigned Line, unsigned Count) {
8574   raw_svector_ostream OS(Name);
8575   OS << "__omp_offloading" << llvm::format("_%x", DeviceID)
8576      << llvm::format("_%x_", FileID) << ParentName << "_l" << Line;
8577   if (Count)
8578     OS << "_" << Count;
8579 }
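// E.g. (editor's sketch, hypothetical values) DeviceID=0x2a, FileID=0xbeef,
// ParentName="foo", Line=12 and Count=0 yield "__omp_offloading_2a_beef_foo_l12";
// a non-zero Count appends a further "_<Count>" suffix.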
8580 
8581 void OffloadEntriesInfoManager::getTargetRegionEntryFnName(
8582     SmallVectorImpl<char> &Name, const TargetRegionEntryInfo &EntryInfo) {
8583   unsigned NewCount = getTargetRegionEntryInfoCount(EntryInfo);
8584   TargetRegionEntryInfo::getTargetRegionEntryFnName(
8585       Name, EntryInfo.ParentName, EntryInfo.DeviceID, EntryInfo.FileID,
8586       EntryInfo.Line, NewCount);
8587 }
8588 
8589 TargetRegionEntryInfo
8590 OpenMPIRBuilder::getTargetEntryUniqueInfo(FileIdentifierInfoCallbackTy CallBack,
8591                                           StringRef ParentName) {
8592   sys::fs::UniqueID ID;
8593   auto FileIDInfo = CallBack();
8594   if (auto EC = sys::fs::getUniqueID(std::get<0>(FileIDInfo), ID)) {
8595     report_fatal_error(("Unable to get unique ID for file, during "
8596                         "getTargetEntryUniqueInfo, error message: " +
8597                         EC.message())
8598                            .c_str());
8599   }
8600 
8601   return TargetRegionEntryInfo(ParentName, ID.getDevice(), ID.getFile(),
8602                                std::get<1>(FileIDInfo));
8603 }
8604 
8605 unsigned OpenMPIRBuilder::getFlagMemberOffset() {
8606   unsigned Offset = 0;
8607   for (uint64_t Remain =
8608            static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
8609                omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF);
8610        !(Remain & 1); Remain = Remain >> 1)
8611     Offset++;
8612   return Offset;
8613 }
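// Editor's sketch: with OMP_MAP_MEMBER_OF occupying the high 16 bits of the
// 64-bit flag word (mask 0xFFFF000000000000), the loop above counts 48
// trailing zero bits, so this returns 48 and getMemberOfFlag(Position) below
// shifts (Position + 1) into that field.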
8614 
8615 omp::OpenMPOffloadMappingFlags
8616 OpenMPIRBuilder::getMemberOfFlag(unsigned Position) {
8617   // Rotate by getFlagMemberOffset() bits.
8618   return static_cast<omp::OpenMPOffloadMappingFlags>(((uint64_t)Position + 1)
8619                                                      << getFlagMemberOffset());
8620 }
8621 
8622 void OpenMPIRBuilder::setCorrectMemberOfFlag(
8623     omp::OpenMPOffloadMappingFlags &Flags,
8624     omp::OpenMPOffloadMappingFlags MemberOfFlag) {
8625   // If the entry is PTR_AND_OBJ but has not been marked with the special
8626   // placeholder value 0xFFFF in the MEMBER_OF field, then it should not be
8627   // marked as MEMBER_OF.
8628   if (static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
8629           Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ) &&
8630       static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
8631           (Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF) !=
8632           omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF))
8633     return;
8634 
8635   // Reset the placeholder value to prepare the flag for the assignment of the
8636   // proper MEMBER_OF value.
8637   Flags &= ~omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
8638   Flags |= MemberOfFlag;
8639 }
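// Editor's sketch (hypothetical values): a PTR_AND_OBJ entry still carrying
// the 0xFFFF MEMBER_OF placeholder, combined with
// getMemberOfFlag(/*Position=*/2), ends up with a MEMBER_OF field of 3,
// i.e. "member of the entry at index 2".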
8640 
8641 Constant *OpenMPIRBuilder::getAddrOfDeclareTargetVar(
8642     OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
8643     OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
8644     bool IsDeclaration, bool IsExternallyVisible,
8645     TargetRegionEntryInfo EntryInfo, StringRef MangledName,
8646     std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
8647     std::vector<Triple> TargetTriple, Type *LlvmPtrTy,
8648     std::function<Constant *()> GlobalInitializer,
8649     std::function<GlobalValue::LinkageTypes()> VariableLinkage) {
8650   // TODO: convert this to utilise the IRBuilder Config rather than
8651   // a passed down argument.
8652   if (OpenMPSIMD)
8653     return nullptr;
8654 
8655   if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink ||
8656       ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
8657         CaptureClause ==
8658             OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
8659        Config.hasRequiresUnifiedSharedMemory())) {
8660     SmallString<64> PtrName;
8661     {
8662       raw_svector_ostream OS(PtrName);
8663       OS << MangledName;
8664       if (!IsExternallyVisible)
8665         OS << format("_%x", EntryInfo.FileID);
8666       OS << "_decl_tgt_ref_ptr";
8667     }
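    // E.g. (editor's sketch) an internal variable "foo" in a file with
    // FileID 0xbeef gets the pointer name "foo_beef_decl_tgt_ref_ptr", while
    // an externally visible one is named just "foo_decl_tgt_ref_ptr".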
8668 
8669     Value *Ptr = M.getNamedValue(PtrName);
8670 
8671     if (!Ptr) {
8672       GlobalValue *GlobalValue = M.getNamedValue(MangledName);
8673       Ptr = getOrCreateInternalVariable(LlvmPtrTy, PtrName);
8674 
8675       auto *GV = cast<GlobalVariable>(Ptr);
8676       GV->setLinkage(GlobalValue::WeakAnyLinkage);
8677 
8678       if (!Config.isTargetDevice()) {
8679         if (GlobalInitializer)
8680           GV->setInitializer(GlobalInitializer());
8681         else
8682           GV->setInitializer(GlobalValue);
8683       }
8684 
8685       registerTargetGlobalVariable(
8686           CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
8687           EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
8688           GlobalInitializer, VariableLinkage, LlvmPtrTy, cast<Constant>(Ptr));
8689     }
8690 
8691     return cast<Constant>(Ptr);
8692   }
8693 
8694   return nullptr;
8695 }
8696 
8697 void OpenMPIRBuilder::registerTargetGlobalVariable(
8698     OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
8699     OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
8700     bool IsDeclaration, bool IsExternallyVisible,
8701     TargetRegionEntryInfo EntryInfo, StringRef MangledName,
8702     std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
8703     std::vector<Triple> TargetTriple,
8704     std::function<Constant *()> GlobalInitializer,
8705     std::function<GlobalValue::LinkageTypes()> VariableLinkage, Type *LlvmPtrTy,
8706     Constant *Addr) {
8707   if (DeviceClause != OffloadEntriesInfoManager::OMPTargetDeviceClauseAny ||
8708       (TargetTriple.empty() && !Config.isTargetDevice()))
8709     return;
8710 
8711   OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags;
8712   StringRef VarName;
8713   int64_t VarSize;
8714   GlobalValue::LinkageTypes Linkage;
8715 
8716   if ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
8717        CaptureClause ==
8718            OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
8719       !Config.hasRequiresUnifiedSharedMemory()) {
8720     Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
8721     VarName = MangledName;
8722     GlobalValue *LlvmVal = M.getNamedValue(VarName);
8723 
8724     if (!IsDeclaration)
8725       VarSize = divideCeil(
8726           M.getDataLayout().getTypeSizeInBits(LlvmVal->getValueType()), 8);
8727     else
8728       VarSize = 0;
8729     Linkage = (VariableLinkage) ? VariableLinkage() : LlvmVal->getLinkage();
8730 
8731     // This is a workaround carried over from Clang which prevents undesired
8732     // optimisation of internal variables.
8733     if (Config.isTargetDevice() &&
8734         (!IsExternallyVisible || Linkage == GlobalValue::LinkOnceODRLinkage)) {
8735       // Do not create a "ref-variable" if the original is not also available
8736       // on the host.
8737       if (!OffloadInfoManager.hasDeviceGlobalVarEntryInfo(VarName))
8738         return;
8739 
8740       std::string RefName = createPlatformSpecificName({VarName, "ref"});
8741 
8742       if (!M.getNamedValue(RefName)) {
8743         Constant *AddrRef =
8744             getOrCreateInternalVariable(Addr->getType(), RefName);
8745         auto *GvAddrRef = cast<GlobalVariable>(AddrRef);
8746         GvAddrRef->setConstant(true);
8747         GvAddrRef->setLinkage(GlobalValue::InternalLinkage);
8748         GvAddrRef->setInitializer(Addr);
8749         GeneratedRefs.push_back(GvAddrRef);
8750       }
8751     }
8752   } else {
8753     if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink)
8754       Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
8755     else
8756       Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
8757 
8758     if (Config.isTargetDevice()) {
8759       VarName = (Addr) ? Addr->getName() : "";
8760       Addr = nullptr;
8761     } else {
8762       Addr = getAddrOfDeclareTargetVar(
8763           CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
8764           EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
8765           LlvmPtrTy, GlobalInitializer, VariableLinkage);
8766       VarName = (Addr) ? Addr->getName() : "";
8767     }
8768     VarSize = M.getDataLayout().getPointerSize();
8769     Linkage = GlobalValue::WeakAnyLinkage;
8770   }
8771 
8772   OffloadInfoManager.registerDeviceGlobalVarEntryInfo(VarName, Addr, VarSize,
8773                                                       Flags, Linkage);
8774 }
8775 
8776 /// Loads all the offload entries information from the host IR
8777 /// metadata.
8778 void OpenMPIRBuilder::loadOffloadInfoMetadata(Module &M) {
8779   // If we are in target mode, load the metadata from the host IR. This code has
8780   // to match the metadata creation in createOffloadEntriesAndInfoMetadata().
8781 
8782   NamedMDNode *MD = M.getNamedMetadata(ompOffloadInfoName);
8783   if (!MD)
8784     return;
8785 
8786   for (MDNode *MN : MD->operands()) {
8787     auto &&GetMDInt = [MN](unsigned Idx) {
8788       auto *V = cast<ConstantAsMetadata>(MN->getOperand(Idx));
8789       return cast<ConstantInt>(V->getValue())->getZExtValue();
8790     };
8791 
8792     auto &&GetMDString = [MN](unsigned Idx) {
8793       auto *V = cast<MDString>(MN->getOperand(Idx));
8794       return V->getString();
8795     };
8796 
8797     switch (GetMDInt(0)) {
8798     default:
8799       llvm_unreachable("Unexpected metadata!");
8800       break;
8801     case OffloadEntriesInfoManager::OffloadEntryInfo::
8802         OffloadingEntryInfoTargetRegion: {
8803       TargetRegionEntryInfo EntryInfo(/*ParentName=*/GetMDString(3),
8804                                       /*DeviceID=*/GetMDInt(1),
8805                                       /*FileID=*/GetMDInt(2),
8806                                       /*Line=*/GetMDInt(4),
8807                                       /*Count=*/GetMDInt(5));
8808       OffloadInfoManager.initializeTargetRegionEntryInfo(EntryInfo,
8809                                                          /*Order=*/GetMDInt(6));
8810       break;
8811     }
8812     case OffloadEntriesInfoManager::OffloadEntryInfo::
8813         OffloadingEntryInfoDeviceGlobalVar:
8814       OffloadInfoManager.initializeDeviceGlobalVarEntryInfo(
8815           /*MangledName=*/GetMDString(1),
8816           static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
8817               /*Flags=*/GetMDInt(2)),
8818           /*Order=*/GetMDInt(3));
8819       break;
8820     }
8821   }
8822 }
8823 
8824 void OpenMPIRBuilder::loadOffloadInfoMetadata(StringRef HostFilePath) {
8825   if (HostFilePath.empty())
8826     return;
8827 
8828   auto Buf = MemoryBuffer::getFile(HostFilePath);
8829   if (std::error_code Err = Buf.getError()) {
8830     report_fatal_error(("error opening host file from host file path inside of "
8831                         "OpenMPIRBuilder: " +
8832                         Err.message())
8833                            .c_str());
8834   }
8835 
8836   LLVMContext Ctx;
8837   auto M = expectedToErrorOrAndEmitErrors(
8838       Ctx, parseBitcodeFile(Buf.get()->getMemBufferRef(), Ctx));
8839   if (std::error_code Err = M.getError()) {
8840     report_fatal_error(
8841         ("error parsing host file inside of OpenMPIRBuilder: " + Err.message())
8842             .c_str());
8843   }
8844 
8845   loadOffloadInfoMetadata(*M.get());
8846 }
8847 
8848 //===----------------------------------------------------------------------===//
8849 // OffloadEntriesInfoManager
8850 //===----------------------------------------------------------------------===//
8851 
8852 bool OffloadEntriesInfoManager::empty() const {
8853   return OffloadEntriesTargetRegion.empty() &&
8854          OffloadEntriesDeviceGlobalVar.empty();
8855 }
8856 
8857 unsigned OffloadEntriesInfoManager::getTargetRegionEntryInfoCount(
8858     const TargetRegionEntryInfo &EntryInfo) const {
8859   auto It = OffloadEntriesTargetRegionCount.find(
8860       getTargetRegionEntryCountKey(EntryInfo));
8861   if (It == OffloadEntriesTargetRegionCount.end())
8862     return 0;
8863   return It->second;
8864 }
8865 
8866 void OffloadEntriesInfoManager::incrementTargetRegionEntryInfoCount(
8867     const TargetRegionEntryInfo &EntryInfo) {
8868   OffloadEntriesTargetRegionCount[getTargetRegionEntryCountKey(EntryInfo)] =
8869       EntryInfo.Count + 1;
8870 }
8871 
8872 /// Initialize target region entry.
8873 void OffloadEntriesInfoManager::initializeTargetRegionEntryInfo(
8874     const TargetRegionEntryInfo &EntryInfo, unsigned Order) {
8875   OffloadEntriesTargetRegion[EntryInfo] =
8876       OffloadEntryInfoTargetRegion(Order, /*Addr=*/nullptr, /*ID=*/nullptr,
8877                                    OMPTargetRegionEntryTargetRegion);
8878   ++OffloadingEntriesNum;
8879 }
8880 
8881 void OffloadEntriesInfoManager::registerTargetRegionEntryInfo(
8882     TargetRegionEntryInfo EntryInfo, Constant *Addr, Constant *ID,
8883     OMPTargetRegionEntryKind Flags) {
8884   assert(EntryInfo.Count == 0 && "expected default EntryInfo");
8885 
8886   // Update the EntryInfo with the next available count for this location.
8887   EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);
8888 
8889   // If we are emitting code for a target, the entry is already initialized;
8890   // it only has to be registered.
8891   if (OMPBuilder->Config.isTargetDevice()) {
8892     // This could happen if the device compilation is invoked standalone.
8893     if (!hasTargetRegionEntryInfo(EntryInfo)) {
8894       return;
8895     }
8896     auto &Entry = OffloadEntriesTargetRegion[EntryInfo];
8897     Entry.setAddress(Addr);
8898     Entry.setID(ID);
8899     Entry.setFlags(Flags);
8900   } else {
8901     if (Flags == OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion &&
8902         hasTargetRegionEntryInfo(EntryInfo, /*IgnoreAddressId*/ true))
8903       return;
8904     assert(!hasTargetRegionEntryInfo(EntryInfo) &&
8905            "Target region entry already registered!");
8906     OffloadEntryInfoTargetRegion Entry(OffloadingEntriesNum, Addr, ID, Flags);
8907     OffloadEntriesTargetRegion[EntryInfo] = Entry;
8908     ++OffloadingEntriesNum;
8909   }
8910   incrementTargetRegionEntryInfoCount(EntryInfo);
8911 }
8912 
8913 bool OffloadEntriesInfoManager::hasTargetRegionEntryInfo(
8914     TargetRegionEntryInfo EntryInfo, bool IgnoreAddressId) const {
8915 
8916   // Update the EntryInfo with the next available count for this location.
8917   EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);
8918 
8919   auto It = OffloadEntriesTargetRegion.find(EntryInfo);
8920   if (It == OffloadEntriesTargetRegion.end()) {
8921     return false;
8922   }
8923   // Fail if this entry is already registered.
8924   if (!IgnoreAddressId && (It->second.getAddress() || It->second.getID()))
8925     return false;
8926   return true;
8927 }
8928 
8929 void OffloadEntriesInfoManager::actOnTargetRegionEntriesInfo(
8930     const OffloadTargetRegionEntryInfoActTy &Action) {
8931   // Scan all target region entries and perform the provided action.
8932   for (const auto &It : OffloadEntriesTargetRegion) {
8933     Action(It.first, It.second);
8934   }
8935 }
8936 
8937 void OffloadEntriesInfoManager::initializeDeviceGlobalVarEntryInfo(
8938     StringRef Name, OMPTargetGlobalVarEntryKind Flags, unsigned Order) {
8939   OffloadEntriesDeviceGlobalVar.try_emplace(Name, Order, Flags);
8940   ++OffloadingEntriesNum;
8941 }
8942 
8943 void OffloadEntriesInfoManager::registerDeviceGlobalVarEntryInfo(
8944     StringRef VarName, Constant *Addr, int64_t VarSize,
8945     OMPTargetGlobalVarEntryKind Flags, GlobalValue::LinkageTypes Linkage) {
8946   if (OMPBuilder->Config.isTargetDevice()) {
8947     // This could happen if the device compilation is invoked standalone.
8948     if (!hasDeviceGlobalVarEntryInfo(VarName))
8949       return;
8950     auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
8951     if (Entry.getAddress() && hasDeviceGlobalVarEntryInfo(VarName)) {
8952       if (Entry.getVarSize() == 0) {
8953         Entry.setVarSize(VarSize);
8954         Entry.setLinkage(Linkage);
8955       }
8956       return;
8957     }
8958     Entry.setVarSize(VarSize);
8959     Entry.setLinkage(Linkage);
8960     Entry.setAddress(Addr);
8961   } else {
8962     if (hasDeviceGlobalVarEntryInfo(VarName)) {
8963       auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
8964       assert(Entry.isValid() && Entry.getFlags() == Flags &&
8965              "Entry not initialized!");
8966       if (Entry.getVarSize() == 0) {
8967         Entry.setVarSize(VarSize);
8968         Entry.setLinkage(Linkage);
8969       }
8970       return;
8971     }
8972     if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
8973       OffloadEntriesDeviceGlobalVar.try_emplace(VarName, OffloadingEntriesNum,
8974                                                 Addr, VarSize, Flags, Linkage,
8975                                                 VarName.str());
8976     else
8977       OffloadEntriesDeviceGlobalVar.try_emplace(
8978           VarName, OffloadingEntriesNum, Addr, VarSize, Flags, Linkage, "");
8979     ++OffloadingEntriesNum;
8980   }
8981 }
8982 
8983 void OffloadEntriesInfoManager::actOnDeviceGlobalVarEntriesInfo(
8984     const OffloadDeviceGlobalVarEntryInfoActTy &Action) {
8985   // Scan all device global variable entries and perform the provided action.
8986   for (const auto &E : OffloadEntriesDeviceGlobalVar)
8987     Action(E.getKey(), E.getValue());
8988 }
8989 
8990 //===----------------------------------------------------------------------===//
8991 // CanonicalLoopInfo
8992 //===----------------------------------------------------------------------===//
8993 
8994 void CanonicalLoopInfo::collectControlBlocks(
8995     SmallVectorImpl<BasicBlock *> &BBs) {
8996   // We only count those BBs as control blocks for which we do not need to
8997   // reverse the CFG, i.e. not the loop body, which can contain arbitrary control
8998   // flow. For consistency, this also means we do not add the Body block, which
8999   // is just the entry to the body code.
9000   BBs.reserve(BBs.size() + 6);
9001   BBs.append({getPreheader(), Header, Cond, Latch, Exit, getAfter()});
9002 }
9003 
9004 BasicBlock *CanonicalLoopInfo::getPreheader() const {
9005   assert(isValid() && "Requires a valid canonical loop");
9006   for (BasicBlock *Pred : predecessors(Header)) {
9007     if (Pred != Latch)
9008       return Pred;
9009   }
9010   llvm_unreachable("Missing preheader");
9011 }
9012 
9013 void CanonicalLoopInfo::setTripCount(Value *TripCount) {
9014   assert(isValid() && "Requires a valid canonical loop");
9015 
9016   Instruction *CmpI = &getCond()->front();
9017   assert(isa<CmpInst>(CmpI) && "First inst must compare IV with TripCount");
9018   CmpI->setOperand(1, TripCount);
9019 
9020 #ifndef NDEBUG
9021   assertOK();
9022 #endif
9023 }
9024 
9025 void CanonicalLoopInfo::mapIndVar(
9026     llvm::function_ref<Value *(Instruction *)> Updater) {
9027   assert(isValid() && "Requires a valid canonical loop");
9028 
9029   Instruction *OldIV = getIndVar();
9030 
9031   // Record all uses before running the updater so that uses it introduces are
9032   // excluded; uses by the CanonicalLoopInfo itself (in the Cond and Latch
9033   // blocks) to keep track of the number of iterations are also excluded.
9034   SmallVector<Use *> ReplacableUses;
9035   for (Use &U : OldIV->uses()) {
9036     auto *User = dyn_cast<Instruction>(U.getUser());
9037     if (!User)
9038       continue;
9039     if (User->getParent() == getCond())
9040       continue;
9041     if (User->getParent() == getLatch())
9042       continue;
9043     ReplacableUses.push_back(&U);
9044   }
9045 
9046   // Run the updater that may introduce new uses
9047   Value *NewIV = Updater(OldIV);
9048 
9049   // Replace the old uses with the value returned by the updater.
9050   for (Use *U : ReplacableUses)
9051     U->set(NewIV);
9052 
9053 #ifndef NDEBUG
9054   assertOK();
9055 #endif
9056 }
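// Editor's usage sketch (assumes a 32-bit IV, an in-scope IRBuilder `B`, and
// a hypothetical `Stride` value): multiply every body use of the IV:
//   CLI->mapIndVar([&](Instruction *OldIV) -> Value * {
//     return B.CreateMul(OldIV, B.getInt32(Stride), "scaled.iv");
//   });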
9057 
9058 void CanonicalLoopInfo::assertOK() const {
9059 #ifndef NDEBUG
9060   // No constraints if this object currently does not describe a loop.
9061   if (!isValid())
9062     return;
9063 
9064   BasicBlock *Preheader = getPreheader();
9065   BasicBlock *Body = getBody();
9066   BasicBlock *After = getAfter();
9067 
9068   // Verify standard control-flow we use for OpenMP loops.
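  // The shape being checked is (editor's sketch):
  //   preheader -> header -> cond --true--> body -> ... -> latch -> header
  //                            \---false--> exit -> after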
9069   assert(Preheader);
9070   assert(isa<BranchInst>(Preheader->getTerminator()) &&
9071          "Preheader must terminate with unconditional branch");
9072   assert(Preheader->getSingleSuccessor() == Header &&
9073          "Preheader must jump to header");
9074 
9075   assert(Header);
9076   assert(isa<BranchInst>(Header->getTerminator()) &&
9077          "Header must terminate with unconditional branch");
9078   assert(Header->getSingleSuccessor() == Cond &&
9079          "Header must jump to exiting block");
9080 
9081   assert(Cond);
9082   assert(Cond->getSinglePredecessor() == Header &&
9083          "Exiting block only reachable from header");
9084 
9085   assert(isa<BranchInst>(Cond->getTerminator()) &&
9086          "Exiting block must terminate with conditional branch");
9087   assert(size(successors(Cond)) == 2 &&
9088          "Exiting block must have two successors");
9089   assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(0) == Body &&
9090          "Exiting block's first successor jump to the body");
9091   assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(1) == Exit &&
9092          "Exiting block's second successor must exit the loop");
9093 
9094   assert(Body);
9095   assert(Body->getSinglePredecessor() == Cond &&
9096          "Body only reachable from exiting block");
9097   assert(!isa<PHINode>(Body->front()));
9098 
9099   assert(Latch);
9100   assert(isa<BranchInst>(Latch->getTerminator()) &&
9101          "Latch must terminate with unconditional branch");
9102   assert(Latch->getSingleSuccessor() == Header && "Latch must jump to header");
9103   // TODO: To support simple redirecting of the end of the body code that has
9104   // multiple predecessors, introduce another auxiliary basic block like preheader and after.
9105   assert(Latch->getSinglePredecessor() != nullptr);
9106   assert(!isa<PHINode>(Latch->front()));
9107 
9108   assert(Exit);
9109   assert(isa<BranchInst>(Exit->getTerminator()) &&
9110          "Exit block must terminate with unconditional branch");
9111   assert(Exit->getSingleSuccessor() == After &&
9112          "Exit block must jump to after block");
9113 
9114   assert(After);
9115   assert(After->getSinglePredecessor() == Exit &&
9116          "After block only reachable from exit block");
9117   assert(After->empty() || !isa<PHINode>(After->front()));
9118 
9119   Instruction *IndVar = getIndVar();
9120   assert(IndVar && "Canonical induction variable not found?");
9121   assert(isa<IntegerType>(IndVar->getType()) &&
9122          "Induction variable must be an integer");
9123   assert(cast<PHINode>(IndVar)->getParent() == Header &&
9124          "Induction variable must be a PHI in the loop header");
9125   assert(cast<PHINode>(IndVar)->getIncomingBlock(0) == Preheader);
9126   assert(
9127       cast<ConstantInt>(cast<PHINode>(IndVar)->getIncomingValue(0))->isZero());
9128   assert(cast<PHINode>(IndVar)->getIncomingBlock(1) == Latch);
9129 
9130   auto *NextIndVar = cast<PHINode>(IndVar)->getIncomingValue(1);
9131   assert(cast<Instruction>(NextIndVar)->getParent() == Latch);
9132   assert(cast<BinaryOperator>(NextIndVar)->getOpcode() == BinaryOperator::Add);
9133   assert(cast<BinaryOperator>(NextIndVar)->getOperand(0) == IndVar);
9134   assert(cast<ConstantInt>(cast<BinaryOperator>(NextIndVar)->getOperand(1))
9135              ->isOne());
9136 
9137   Value *TripCount = getTripCount();
9138   assert(TripCount && "Loop trip count not found?");
9139   assert(IndVar->getType() == TripCount->getType() &&
9140          "Trip count and induction variable must have the same type");
9141 
9142   auto *CmpI = cast<CmpInst>(&Cond->front());
9143   assert(CmpI->getPredicate() == CmpInst::ICMP_ULT &&
9144          "Exit condition must be a signed less-than comparison");
9145   assert(CmpI->getOperand(0) == IndVar &&
9146          "Exit condition must compare the induction variable");
9147   assert(CmpI->getOperand(1) == TripCount &&
9148          "Exit condition must compare with the trip count");
9149 #endif
9150 }
9151 
9152 void CanonicalLoopInfo::invalidate() {
9153   Header = nullptr;
9154   Cond = nullptr;
9155   Latch = nullptr;
9156   Exit = nullptr;
9157 }
9158