//===- OpenMPIRBuilder.cpp - Builder for LLVM-IR for OpenMP directives ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
///
/// This file implements the OpenMPIRBuilder class, which is used as a
/// convenient way to create LLVM instructions for OpenMP directives.
///
//===----------------------------------------------------------------------===//

#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Frontend/Offloading/Utility.h"
#include "llvm/Frontend/OpenMP/OMPGridValues.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DIBuilder.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/IR/Value.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"
#include "llvm/Transforms/Utils/LoopPeel.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"

#include <cstdint>
#include <optional>

#define DEBUG_TYPE "openmp-ir-builder"

using namespace llvm;
using namespace omp;

static cl::opt<bool>
    OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
                         cl::desc("Use optimistic attributes describing "
                                  "'as-if' properties of runtime calls."),
                         cl::init(false));

static cl::opt<double> UnrollThresholdFactor(
    "openmp-ir-builder-unroll-threshold-factor", cl::Hidden,
    cl::desc("Factor for the unroll threshold to account for code "
             "simplifications still taking place"),
    cl::init(1.5));

#ifndef NDEBUG
/// Return whether IP1 and IP2 are ambiguous, i.e. that inserting instructions
/// at position IP1 may change the meaning of IP2 or vice-versa. This is
/// because an InsertPoint stores the instruction before something is
/// inserted. For instance, if both point to the same instruction, two
/// IRBuilders alternately creating instructions will cause the instructions
/// to be interleaved.
static bool isConflictIP(IRBuilder<>::InsertPoint IP1,
                         IRBuilder<>::InsertPoint IP2) {
  if (!IP1.isSet() || !IP2.isSet())
    return false;
  return IP1.getBlock() == IP2.getBlock() && IP1.getPoint() == IP2.getPoint();
}

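/// Return whether \p SchedType is one of the ordered/unordered plus base
/// algorithm combinations the runtime accepts, with at most one monotonicity
/// modifier bit set.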
static bool isValidWorkshareLoopScheduleType(OMPScheduleType SchedType) {
  // Valid ordered/unordered and base algorithm combinations.
  switch (SchedType & ~OMPScheduleType::MonotonicityMask) {
  case OMPScheduleType::UnorderedStaticChunked:
  case OMPScheduleType::UnorderedStatic:
  case OMPScheduleType::UnorderedDynamicChunked:
  case OMPScheduleType::UnorderedGuidedChunked:
  case OMPScheduleType::UnorderedRuntime:
  case OMPScheduleType::UnorderedAuto:
  case OMPScheduleType::UnorderedTrapezoidal:
  case OMPScheduleType::UnorderedGreedy:
  case OMPScheduleType::UnorderedBalanced:
  case OMPScheduleType::UnorderedGuidedIterativeChunked:
  case OMPScheduleType::UnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::UnorderedSteal:
  case OMPScheduleType::UnorderedStaticBalancedChunked:
  case OMPScheduleType::UnorderedGuidedSimd:
  case OMPScheduleType::UnorderedRuntimeSimd:
  case OMPScheduleType::OrderedStaticChunked:
  case OMPScheduleType::OrderedStatic:
  case OMPScheduleType::OrderedDynamicChunked:
  case OMPScheduleType::OrderedGuidedChunked:
  case OMPScheduleType::OrderedRuntime:
  case OMPScheduleType::OrderedAuto:
  case OMPScheduleType::OrderdTrapezoidal:
  case OMPScheduleType::NomergeUnorderedStaticChunked:
  case OMPScheduleType::NomergeUnorderedStatic:
  case OMPScheduleType::NomergeUnorderedDynamicChunked:
  case OMPScheduleType::NomergeUnorderedGuidedChunked:
  case OMPScheduleType::NomergeUnorderedRuntime:
  case OMPScheduleType::NomergeUnorderedAuto:
  case OMPScheduleType::NomergeUnorderedTrapezoidal:
  case OMPScheduleType::NomergeUnorderedGreedy:
  case OMPScheduleType::NomergeUnorderedBalanced:
  case OMPScheduleType::NomergeUnorderedGuidedIterativeChunked:
  case OMPScheduleType::NomergeUnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::NomergeUnorderedSteal:
  case OMPScheduleType::NomergeOrderedStaticChunked:
  case OMPScheduleType::NomergeOrderedStatic:
  case OMPScheduleType::NomergeOrderedDynamicChunked:
  case OMPScheduleType::NomergeOrderedGuidedChunked:
  case OMPScheduleType::NomergeOrderedRuntime:
  case OMPScheduleType::NomergeOrderedAuto:
  case OMPScheduleType::NomergeOrderedTrapezoidal:
    break;
  default:
    return false;
  }

  // Must not set both monotonicity modifiers at the same time.
  OMPScheduleType MonotonicityFlags =
      SchedType & OMPScheduleType::MonotonicityMask;
  if (MonotonicityFlags == OMPScheduleType::MonotonicityMask)
    return false;

  return true;
}
#endif

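/// Select the target-dependent OpenMP grid values for the offloading triple
/// \p T; on AMDGPU the wavefront size is taken from \p Kernel's target
/// features.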
static const omp::GV &getGridValue(const Triple &T, Function *Kernel) {
  if (T.isAMDGPU()) {
    StringRef Features =
        Kernel->getFnAttribute("target-features").getValueAsString();
    if (Features.count("+wavefrontsize64"))
      return omp::getAMDGPUGridValues<64>();
    return omp::getAMDGPUGridValues<32>();
  }
  if (T.isNVPTX())
    return omp::NVPTXGridValues;
  if (T.isSPIRV())
    return omp::SPIRVGridValues;
  llvm_unreachable("No grid value available for this architecture!");
}

/// Determine which scheduling algorithm to use from the schedule clause
/// arguments.
static OMPScheduleType
getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier) {
  // Currently, the default schedule is static.
  switch (ClauseKind) {
  case OMP_SCHEDULE_Default:
  case OMP_SCHEDULE_Static:
    return HasChunks ? OMPScheduleType::BaseStaticChunked
                     : OMPScheduleType::BaseStatic;
  case OMP_SCHEDULE_Dynamic:
    return OMPScheduleType::BaseDynamicChunked;
  case OMP_SCHEDULE_Guided:
    return HasSimdModifier ? OMPScheduleType::BaseGuidedSimd
                           : OMPScheduleType::BaseGuidedChunked;
  case OMP_SCHEDULE_Auto:
    return llvm::omp::OMPScheduleType::BaseAuto;
  case OMP_SCHEDULE_Runtime:
    return HasSimdModifier ? OMPScheduleType::BaseRuntimeSimd
                           : OMPScheduleType::BaseRuntime;
  }
  llvm_unreachable("unhandled schedule clause argument");
}

/// Adds ordering modifier flags to schedule type.
static OMPScheduleType
getOpenMPOrderingScheduleType(OMPScheduleType BaseScheduleType,
                              bool HasOrderedClause) {
  assert((BaseScheduleType & OMPScheduleType::ModifierMask) ==
             OMPScheduleType::None &&
         "Must not have ordering or monotonicity flags already set");

  OMPScheduleType OrderingModifier = HasOrderedClause
                                         ? OMPScheduleType::ModifierOrdered
                                         : OMPScheduleType::ModifierUnordered;
  OMPScheduleType OrderingScheduleType = BaseScheduleType | OrderingModifier;

  // Unsupported combinations
  if (OrderingScheduleType ==
      (OMPScheduleType::BaseGuidedSimd | OMPScheduleType::ModifierOrdered))
    return OMPScheduleType::OrderedGuidedChunked;
  else if (OrderingScheduleType == (OMPScheduleType::BaseRuntimeSimd |
                                    OMPScheduleType::ModifierOrdered))
    return OMPScheduleType::OrderedRuntime;

  return OrderingScheduleType;
}

/// Adds monotonicity modifier flags to schedule type.
static OMPScheduleType
getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,
                                  bool HasSimdModifier, bool HasMonotonic,
                                  bool HasNonmonotonic, bool HasOrderedClause) {
  assert((ScheduleType & OMPScheduleType::MonotonicityMask) ==
             OMPScheduleType::None &&
         "Must not have monotonicity flags already set");
  assert((!HasMonotonic || !HasNonmonotonic) &&
         "Monotonic and Nonmonotonic are contradicting each other");

  if (HasMonotonic) {
    return ScheduleType | OMPScheduleType::ModifierMonotonic;
  } else if (HasNonmonotonic) {
    return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
  } else {
    // OpenMP 5.1, 2.11.4 Worksharing-Loop Construct, Description.
    // If the static schedule kind is specified or if the ordered clause is
    // specified, and if the nonmonotonic modifier is not specified, the
    // effect is as if the monotonic modifier is specified. Otherwise, unless
    // the monotonic modifier is specified, the effect is as if the
    // nonmonotonic modifier is specified.
    OMPScheduleType BaseScheduleType =
        ScheduleType & ~OMPScheduleType::ModifierMask;
    if ((BaseScheduleType == OMPScheduleType::BaseStatic) ||
        (BaseScheduleType == OMPScheduleType::BaseStaticChunked) ||
        HasOrderedClause) {
      // Monotonic is used by default in the OpenMP runtime library, so there
      // is no need to set it.
      return ScheduleType;
    } else {
      return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
    }
  }
}

/// Determine the schedule type using schedule and ordering clause arguments.
static OMPScheduleType
computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier, bool HasMonotonicModifier,
                          bool HasNonmonotonicModifier, bool HasOrderedClause) {
  OMPScheduleType BaseSchedule =
      getOpenMPBaseScheduleType(ClauseKind, HasChunks, HasSimdModifier);
  OMPScheduleType OrderedSchedule =
      getOpenMPOrderingScheduleType(BaseSchedule, HasOrderedClause);
  OMPScheduleType Result = getOpenMPMonotonicityScheduleType(
      OrderedSchedule, HasSimdModifier, HasMonotonicModifier,
      HasNonmonotonicModifier, HasOrderedClause);

  assert(isValidWorkshareLoopScheduleType(Result));
  return Result;
}

/// Make \p Source branch to \p Target.
///
/// Handles two situations:
/// * \p Source already has an unconditional branch.
/// * \p Source is a degenerate block (no terminator because the BB is
///   the current head of the IR construction).
static void redirectTo(BasicBlock *Source, BasicBlock *Target, DebugLoc DL) {
  if (Instruction *Term = Source->getTerminator()) {
    auto *Br = cast<BranchInst>(Term);
    assert(!Br->isConditional() &&
           "BB's terminator must be an unconditional branch (or degenerate)");
    BasicBlock *Succ = Br->getSuccessor(0);
    Succ->removePredecessor(Source, /*KeepOneInputPHIs=*/true);
    Br->setSuccessor(0, Target);
    return;
  }

  auto *NewBr = BranchInst::Create(Target, Source);
  NewBr->setDebugLoc(DL);
}

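/// Move the instructions following insert point \p IP into the beginning of
/// \p New and, if \p CreateBranch is set, terminate the old block with an
/// unconditional branch to \p New.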
void llvm::spliceBB(IRBuilderBase::InsertPoint IP, BasicBlock *New,
                    bool CreateBranch, DebugLoc DL) {
  assert(New->getFirstInsertionPt() == New->begin() &&
         "Target BB must not have PHI nodes");

  // Move instructions to new block.
  BasicBlock *Old = IP.getBlock();
  New->splice(New->begin(), Old, IP.getPoint(), Old->end());

  if (CreateBranch) {
    auto *NewBr = BranchInst::Create(New, Old);
    NewBr->setDebugLoc(DL);
  }
}

void llvm::spliceBB(IRBuilder<> &Builder, BasicBlock *New, bool CreateBranch) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *Old = Builder.GetInsertBlock();

  spliceBB(Builder.saveIP(), New, CreateBranch, DebugLoc);
  if (CreateBranch)
    Builder.SetInsertPoint(Old->getTerminator());
  else
    Builder.SetInsertPoint(Old);

  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
}

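/// Split the basic block at insert point \p IP, moving the remaining
/// instructions into a newly created block that is returned; optionally
/// branch from the old block to the new one.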
BasicBlock *llvm::splitBB(IRBuilderBase::InsertPoint IP, bool CreateBranch,
                          DebugLoc DL, llvm::Twine Name) {
  BasicBlock *Old = IP.getBlock();
  BasicBlock *New = BasicBlock::Create(
      Old->getContext(), Name.isTriviallyEmpty() ? Old->getName() : Name,
      Old->getParent(), Old->getNextNode());
  spliceBB(IP, New, CreateBranch, DL);
  New->replaceSuccessorsPhiUsesWith(Old, New);
  return New;
}

BasicBlock *llvm::splitBB(IRBuilderBase &Builder, bool CreateBranch,
                          llvm::Twine Name) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, DebugLoc, Name);
  if (CreateBranch)
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  else
    Builder.SetInsertPoint(Builder.GetInsertBlock());
  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
  return New;
}

BasicBlock *llvm::splitBB(IRBuilder<> &Builder, bool CreateBranch,
                          llvm::Twine Name) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, DebugLoc, Name);
  if (CreateBranch)
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  else
    Builder.SetInsertPoint(Builder.GetInsertBlock());
  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
  return New;
}

BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch,
                                    llvm::Twine Suffix) {
  BasicBlock *Old = Builder.GetInsertBlock();
  return splitBB(Builder, CreateBranch, Old->getName() + Suffix);
}

// This function creates a fake integer value and a fake use for the integer
// value. It returns the fake value created. This is useful in modeling the
// extra arguments to the outlined functions.
Value *createFakeIntVal(IRBuilderBase &Builder,
                        OpenMPIRBuilder::InsertPointTy OuterAllocaIP,
                        llvm::SmallVectorImpl<Instruction *> &ToBeDeleted,
                        OpenMPIRBuilder::InsertPointTy InnerAllocaIP,
                        const Twine &Name = "", bool AsPtr = true) {
  Builder.restoreIP(OuterAllocaIP);
  Instruction *FakeVal;
  AllocaInst *FakeValAddr =
      Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, Name + ".addr");
  ToBeDeleted.push_back(FakeValAddr);

  if (AsPtr) {
    FakeVal = FakeValAddr;
  } else {
    FakeVal =
        Builder.CreateLoad(Builder.getInt32Ty(), FakeValAddr, Name + ".val");
    ToBeDeleted.push_back(FakeVal);
  }

  // Generate a fake use of this value.
  Builder.restoreIP(InnerAllocaIP);
  Instruction *UseFakeVal;
  if (AsPtr) {
    UseFakeVal =
        Builder.CreateLoad(Builder.getInt32Ty(), FakeVal, Name + ".use");
  } else {
    UseFakeVal =
        cast<BinaryOperator>(Builder.CreateAdd(FakeVal, Builder.getInt32(10)));
  }
  ToBeDeleted.push_back(UseFakeVal);
  return FakeVal;
}

//===----------------------------------------------------------------------===//
// OpenMPIRBuilderConfig
//===----------------------------------------------------------------------===//

namespace {
LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
/// Values for bit flags for marking which requires clauses have been used.
enum OpenMPOffloadingRequiresDirFlags {
  /// flag undefined.
  OMP_REQ_UNDEFINED = 0x000,
  /// no requires directive present.
  OMP_REQ_NONE = 0x001,
  /// reverse_offload clause.
  OMP_REQ_REVERSE_OFFLOAD = 0x002,
  /// unified_address clause.
  OMP_REQ_UNIFIED_ADDRESS = 0x004,
  /// unified_shared_memory clause.
  OMP_REQ_UNIFIED_SHARED_MEMORY = 0x008,
  /// dynamic_allocators clause.
  OMP_REQ_DYNAMIC_ALLOCATORS = 0x010,
  LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/OMP_REQ_DYNAMIC_ALLOCATORS)
};

} // anonymous namespace

OpenMPIRBuilderConfig::OpenMPIRBuilderConfig()
    : RequiresFlags(OMP_REQ_UNDEFINED) {}

OpenMPIRBuilderConfig::OpenMPIRBuilderConfig(
    bool IsTargetDevice, bool IsGPU, bool OpenMPOffloadMandatory,
    bool HasRequiresReverseOffload, bool HasRequiresUnifiedAddress,
    bool HasRequiresUnifiedSharedMemory, bool HasRequiresDynamicAllocators)
    : IsTargetDevice(IsTargetDevice), IsGPU(IsGPU),
      OpenMPOffloadMandatory(OpenMPOffloadMandatory),
      RequiresFlags(OMP_REQ_UNDEFINED) {
  if (HasRequiresReverseOffload)
    RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
  if (HasRequiresUnifiedAddress)
    RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
  if (HasRequiresUnifiedSharedMemory)
    RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
  if (HasRequiresDynamicAllocators)
    RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
}

bool OpenMPIRBuilderConfig::hasRequiresReverseOffload() const {
  return RequiresFlags & OMP_REQ_REVERSE_OFFLOAD;
}

bool OpenMPIRBuilderConfig::hasRequiresUnifiedAddress() const {
  return RequiresFlags & OMP_REQ_UNIFIED_ADDRESS;
}

bool OpenMPIRBuilderConfig::hasRequiresUnifiedSharedMemory() const {
  return RequiresFlags & OMP_REQ_UNIFIED_SHARED_MEMORY;
}

bool OpenMPIRBuilderConfig::hasRequiresDynamicAllocators() const {
  return RequiresFlags & OMP_REQ_DYNAMIC_ALLOCATORS;
}

int64_t OpenMPIRBuilderConfig::getRequiresFlags() const {
  return hasRequiresFlags() ? RequiresFlags
                            : static_cast<int64_t>(OMP_REQ_NONE);
}

void OpenMPIRBuilderConfig::setHasRequiresReverseOffload(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
  else
    RequiresFlags &= ~OMP_REQ_REVERSE_OFFLOAD;
}

void OpenMPIRBuilderConfig::setHasRequiresUnifiedAddress(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
  else
    RequiresFlags &= ~OMP_REQ_UNIFIED_ADDRESS;
}

void OpenMPIRBuilderConfig::setHasRequiresUnifiedSharedMemory(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
  else
    RequiresFlags &= ~OMP_REQ_UNIFIED_SHARED_MEMORY;
}

void OpenMPIRBuilderConfig::setHasRequiresDynamicAllocators(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
  else
    RequiresFlags &= ~OMP_REQ_DYNAMIC_ALLOCATORS;
}

//===----------------------------------------------------------------------===//
// OpenMPIRBuilder
//===----------------------------------------------------------------------===//

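/// Assemble the argument vector expected by __tgt_target_kernel from the
/// kernel argument struct, padding the num_teams/num_threads values out to
/// three dimensions.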
void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
                                          IRBuilderBase &Builder,
                                          SmallVector<Value *> &ArgsVector) {
  Value *Version = Builder.getInt32(OMP_KERNEL_ARG_VERSION);
  Value *PointerNum = Builder.getInt32(KernelArgs.NumTargetItems);
  auto Int32Ty = Type::getInt32Ty(Builder.getContext());
  constexpr const size_t MaxDim = 3;
  Value *ZeroArray = Constant::getNullValue(ArrayType::get(Int32Ty, MaxDim));
  Value *Flags = Builder.getInt64(KernelArgs.HasNoWait);

  assert(!KernelArgs.NumTeams.empty() && !KernelArgs.NumThreads.empty());

  Value *NumTeams3D =
      Builder.CreateInsertValue(ZeroArray, KernelArgs.NumTeams[0], {0});
  Value *NumThreads3D =
      Builder.CreateInsertValue(ZeroArray, KernelArgs.NumThreads[0], {0});
  for (unsigned I :
       seq<unsigned>(1, std::min(KernelArgs.NumTeams.size(), MaxDim)))
    NumTeams3D =
        Builder.CreateInsertValue(NumTeams3D, KernelArgs.NumTeams[I], {I});
  for (unsigned I :
       seq<unsigned>(1, std::min(KernelArgs.NumThreads.size(), MaxDim)))
    NumThreads3D =
        Builder.CreateInsertValue(NumThreads3D, KernelArgs.NumThreads[I], {I});

  ArgsVector = {Version,
                PointerNum,
                KernelArgs.RTArgs.BasePointersArray,
                KernelArgs.RTArgs.PointersArray,
                KernelArgs.RTArgs.SizesArray,
                KernelArgs.RTArgs.MapTypesArray,
                KernelArgs.RTArgs.MapNamesArray,
                KernelArgs.RTArgs.MappersArray,
                KernelArgs.NumIterations,
                Flags,
                NumTeams3D,
                NumThreads3D,
                KernelArgs.DynCGGroupMem};
}

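/// Attach the attribute sets listed in OMPKinds.def for runtime function
/// \p FnID to the declaration \p Fn, mapping integer sign/zero-extension
/// attributes to the forms the target expects.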
void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
  LLVMContext &Ctx = Fn.getContext();

  // Get the function's current attributes.
  auto Attrs = Fn.getAttributes();
  auto FnAttrs = Attrs.getFnAttrs();
  auto RetAttrs = Attrs.getRetAttrs();
  SmallVector<AttributeSet, 4> ArgAttrs;
  for (size_t ArgNo = 0; ArgNo < Fn.arg_size(); ++ArgNo)
    ArgAttrs.emplace_back(Attrs.getParamAttrs(ArgNo));

  // Add AS to FnAS while taking special care with integer extensions.
  auto addAttrSet = [&](AttributeSet &FnAS, const AttributeSet &AS,
                        bool Param = true) -> void {
    bool HasSignExt = AS.hasAttribute(Attribute::SExt);
    bool HasZeroExt = AS.hasAttribute(Attribute::ZExt);
    if (HasSignExt || HasZeroExt) {
      assert(AS.getNumAttributes() == 1 &&
             "Currently not handling extension attr combined with others.");
      if (Param) {
        if (auto AK = TargetLibraryInfo::getExtAttrForI32Param(T, HasSignExt))
          FnAS = FnAS.addAttribute(Ctx, AK);
      } else if (auto AK =
                     TargetLibraryInfo::getExtAttrForI32Return(T, HasSignExt))
        FnAS = FnAS.addAttribute(Ctx, AK);
    } else {
      FnAS = FnAS.addAttributes(Ctx, AS);
    }
  };

#define OMP_ATTRS_SET(VarName, AttrSet) AttributeSet VarName = AttrSet;
#include "llvm/Frontend/OpenMP/OMPKinds.def"

  // Add attributes to the function declaration.
  switch (FnID) {
#define OMP_RTL_ATTRS(Enum, FnAttrSet, RetAttrSet, ArgAttrSets)                \
  case Enum:                                                                   \
    FnAttrs = FnAttrs.addAttributes(Ctx, FnAttrSet);                           \
    addAttrSet(RetAttrs, RetAttrSet, /*Param*/ false);                         \
    for (size_t ArgNo = 0; ArgNo < ArgAttrSets.size(); ++ArgNo)                \
      addAttrSet(ArgAttrs[ArgNo], ArgAttrSets[ArgNo]);                         \
    Fn.setAttributes(AttributeList::get(Ctx, FnAttrs, RetAttrs, ArgAttrs));    \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    // Attributes are optional.
    break;
  }
}

FunctionCallee
OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
  FunctionType *FnTy = nullptr;
  Function *Fn = nullptr;

  // Try to find the declaration in the module first.
  switch (FnID) {
#define OMP_RTL(Enum, Str, IsVarArg, ReturnType, ...)                          \
  case Enum:                                                                   \
    FnTy = FunctionType::get(ReturnType, ArrayRef<Type *>{__VA_ARGS__},        \
                             IsVarArg);                                        \
    Fn = M.getFunction(Str);                                                   \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  if (!Fn) {
    // Create a new declaration if we need one.
    switch (FnID) {
#define OMP_RTL(Enum, Str, ...)                                                \
  case Enum:                                                                   \
    Fn = Function::Create(FnTy, GlobalValue::ExternalLinkage, Str, M);         \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
    }

    // Add information if the runtime function takes a callback function.
    if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) {
      if (!Fn->hasMetadata(LLVMContext::MD_callback)) {
        LLVMContext &Ctx = Fn->getContext();
        MDBuilder MDB(Ctx);
        // Annotate the callback behavior of the runtime function:
        //  - The callback callee is argument number 2 (microtask).
        //  - The first two arguments of the callback callee are unknown (-1).
        //  - All variadic arguments to the runtime function are passed to the
        //    callback callee.
        Fn->addMetadata(
            LLVMContext::MD_callback,
            *MDNode::get(Ctx, {MDB.createCallbackEncoding(
                                  2, {-1, -1}, /* VarArgsArePassed */ true)}));
      }
    }

    LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
    addAttributes(FnID, *Fn);

  } else {
    LLVM_DEBUG(dbgs() << "Found OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
  }

  assert(Fn && "Failed to create OpenMP runtime function");

  return {FnTy, Fn};
}

Function *OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID) {
  FunctionCallee RTLFn = getOrCreateRuntimeFunction(M, FnID);
  auto *Fn = dyn_cast<llvm::Function>(RTLFn.getCallee());
  assert(Fn && "Failed to create OpenMP runtime function pointer");
  return Fn;
}

void OpenMPIRBuilder::initialize() { initializeTypes(M); }

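/// Move allocas with a constant array size from all non-entry blocks of
/// \p Function into its entry block, so later passes see canonical, static
/// allocas.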
static void raiseUserConstantDataAllocasToEntryBlock(IRBuilderBase &Builder,
                                                     Function *Function) {
  BasicBlock &EntryBlock = Function->getEntryBlock();
  BasicBlock::iterator MoveLocInst = EntryBlock.getFirstNonPHIIt();

  // Loop over blocks looking for constant allocas, skipping the entry block
  // as any allocas there are already in the desired location.
  for (auto Block = std::next(Function->begin(), 1); Block != Function->end();
       Block++) {
    for (auto Inst = Block->getReverseIterator()->begin();
         Inst != Block->getReverseIterator()->end();) {
      if (auto *AllocaInst = dyn_cast_if_present<llvm::AllocaInst>(Inst)) {
        Inst++;
        if (!isa<ConstantData>(AllocaInst->getArraySize()))
          continue;
        AllocaInst->moveBeforePreserving(MoveLocInst);
      } else {
        Inst++;
      }
    }
  }
}

void OpenMPIRBuilder::finalize(Function *Fn) {
  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  SmallVector<OutlineInfo, 16> DeferredOutlines;
  for (OutlineInfo &OI : OutlineInfos) {
    // Skip functions that have not been finalized yet; this may happen with
    // nested function generation.
    if (Fn && OI.getFunction() != Fn) {
      DeferredOutlines.push_back(OI);
      continue;
    }

    ParallelRegionBlockSet.clear();
    Blocks.clear();
    OI.collectBlocks(ParallelRegionBlockSet, Blocks);

    Function *OuterFn = OI.getFunction();
    CodeExtractorAnalysisCache CEAC(*OuterFn);
    // If we generate code for the target device, we need to allocate the
    // struct for aggregate params in the device default alloca address space.
    // The OpenMP runtime requires that the params of the extracted functions
    // are passed as zero address space pointers. This flag ensures that the
    // CodeExtractor generates correct code for extracted functions which are
    // used by the OpenMP runtime.
    bool ArgsInZeroAddressSpace = Config.isTargetDevice();
    CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
                            /* AggregateArgs */ true,
                            /* BlockFrequencyInfo */ nullptr,
                            /* BranchProbabilityInfo */ nullptr,
                            /* AssumptionCache */ nullptr,
                            /* AllowVarArgs */ true,
                            /* AllowAlloca */ true,
                            /* AllocaBlock*/ OI.OuterAllocaBB,
                            /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);

    LLVM_DEBUG(dbgs() << "Before outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
                      << " Exit: " << OI.ExitBB->getName() << "\n");
    assert(Extractor.isEligible() &&
           "Expected OpenMP outlining to be possible!");

    for (auto *V : OI.ExcludeArgsFromAggregate)
      Extractor.excludeArgFromAggregate(V);

    Function *OutlinedFn = Extractor.extractCodeRegion(CEAC);

    // Forward target-cpu, target-features attributes to the outlined function.
    auto TargetCpuAttr = OuterFn->getFnAttribute("target-cpu");
    if (TargetCpuAttr.isStringAttribute())
      OutlinedFn->addFnAttr(TargetCpuAttr);

    auto TargetFeaturesAttr = OuterFn->getFnAttribute("target-features");
    if (TargetFeaturesAttr.isStringAttribute())
      OutlinedFn->addFnAttr(TargetFeaturesAttr);

    LLVM_DEBUG(dbgs() << "After outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << " Outlined function: " << *OutlinedFn << "\n");
    assert(OutlinedFn->getReturnType()->isVoidTy() &&
           "OpenMP outlined functions should not return a value!");

    // For compatibility with the clang CG we move the outlined function after
    // the one with the parallel region.
    OutlinedFn->removeFromParent();
    M.getFunctionList().insertAfter(OuterFn->getIterator(), OutlinedFn);

    // Remove the artificial entry introduced by the extractor right away; we
    // made our own entry block after all.
    {
      BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
      assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
      assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
      // Move instructions from the to-be-deleted ArtificialEntry to the entry
      // basic block of the parallel region. CodeExtractor generates
      // instructions to unwrap the aggregate argument and may sink
      // allocas/bitcasts for values that are solely used in the outlined
      // region and do not escape.
      assert(!ArtificialEntry.empty() &&
             "Expected instructions to add in the outlined region entry");
      for (BasicBlock::reverse_iterator It = ArtificialEntry.rbegin(),
                                        End = ArtificialEntry.rend();
           It != End;) {
        Instruction &I = *It;
        It++;

        if (I.isTerminator()) {
          // Absorb any debug value that the terminator may have.
          if (OI.EntryBB->getTerminator())
            OI.EntryBB->getTerminator()->adoptDbgRecords(
                &ArtificialEntry, I.getIterator(), false);
          continue;
        }

        I.moveBeforePreserving(*OI.EntryBB, OI.EntryBB->getFirstInsertionPt());
      }

      OI.EntryBB->moveBefore(&ArtificialEntry);
      ArtificialEntry.eraseFromParent();
    }
    assert(&OutlinedFn->getEntryBlock() == OI.EntryBB);
    assert(OutlinedFn && OutlinedFn->hasNUses(1));

    // Run a user callback, e.g. to add attributes.
    if (OI.PostOutlineCB)
      OI.PostOutlineCB(*OutlinedFn);
  }

  // Remove work items that have been completed.
  OutlineInfos = std::move(DeferredOutlines);

  // The createTarget functions embed user-written code into the target
  // region, which may inject allocas that need to be moved to the entry
  // block of our target or risk malformed optimisations by later passes.
  // This is only relevant for the device pass, which appears to be a little
  // more delicate when it comes to optimisations (however, we do not block
  // on that here; it is up to whoever inserts into the list to do so).
  // This notably has to occur after the OutlinedInfo candidates have been
  // extracted, so we have an end product that will not be implicitly
  // adversely affected by any raises unless intentionally appended to the
  // list.
  // NOTE: This only does so for ConstantData. It could be extended to
  // ConstantExprs with further effort; however, they should largely be
  // folded by the time they get here. Extending it to runtime-defined or
  // read+writeable allocation sizes would be non-trivial (we would need to
  // factor in movement of any stores to variables the allocation size
  // depends on, as well as the usual loads; otherwise it will yield the
  // wrong result after movement) and is likely more suitable as an LLVM
  // optimisation pass.
  for (Function *F : ConstantAllocaRaiseCandidates)
    raiseUserConstantDataAllocasToEntryBlock(Builder, F);

  EmitMetadataErrorReportFunctionTy &&ErrorReportFn =
      [](EmitMetadataErrorKind Kind,
         const TargetRegionEntryInfo &EntryInfo) -> void {
    errs() << "Error of kind: " << Kind
           << " when emitting offload entries and metadata during "
              "OMPIRBuilder finalization\n";
  };

  if (!OffloadInfoManager.empty())
    createOffloadEntriesAndInfoMetadata(ErrorReportFn);

  if (Config.EmitLLVMUsedMetaInfo.value_or(false)) {
    std::vector<WeakTrackingVH> LLVMCompilerUsed = {
        M.getGlobalVariable("__openmp_nvptx_data_transfer_temporary_storage")};
    emitUsed("llvm.compiler.used", LLVMCompilerUsed);
  }

  IsFinalized = true;
}

bool OpenMPIRBuilder::isFinalized() { return IsFinalized; }

OpenMPIRBuilder::~OpenMPIRBuilder() {
  assert(OutlineInfos.empty() && "There must be no outstanding outlinings");
}

GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
  IntegerType *I32Ty = Type::getInt32Ty(M.getContext());
  auto *GV =
      new GlobalVariable(M, I32Ty,
                         /* isConstant = */ true, GlobalValue::WeakODRLinkage,
                         ConstantInt::get(I32Ty, Value), Name);
  GV->setVisibility(GlobalValue::HiddenVisibility);

  return GV;
}

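/// Emit a global array (e.g. llvm.compiler.used) in the "llvm.metadata"
/// section holding pointers to the values in \p List so they are kept alive
/// through optimization.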
void OpenMPIRBuilder::emitUsed(StringRef Name, ArrayRef<WeakTrackingVH> List) {
  if (List.empty())
    return;

  // Convert List to what ConstantArray needs.
  SmallVector<Constant *, 8> UsedArray;
  UsedArray.resize(List.size());
  for (unsigned I = 0, E = List.size(); I != E; ++I)
    UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
        cast<Constant>(&*List[I]), Builder.getPtrTy());

  if (UsedArray.empty())
    return;
  ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());

  auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
                                ConstantArray::get(ATy, UsedArray), Name);

  GV->setSection("llvm.metadata");
}

GlobalVariable *
OpenMPIRBuilder::emitKernelExecutionMode(StringRef KernelName,
                                         OMPTgtExecModeFlags Mode) {
  auto *Int8Ty = Builder.getInt8Ty();
  auto *GVMode = new GlobalVariable(
      M, Int8Ty, /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
      ConstantInt::get(Int8Ty, Mode), Twine(KernelName, "_exec_mode"));
  GVMode->setVisibility(GlobalVariable::ProtectedVisibility);
  return GVMode;
}

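/// Return (or lazily create) the ident_t source-location descriptor for
/// \p SrcLocStr with the given flags, reusing an existing global with the
/// same initializer when possible.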
Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
                                            uint32_t SrcLocStrSize,
                                            IdentFlag LocFlags,
                                            unsigned Reserve2Flags) {
  // Enable "C-mode".
  LocFlags |= OMP_IDENT_FLAG_KMPC;

  Constant *&Ident =
      IdentMap[{SrcLocStr, uint64_t(LocFlags) << 31 | Reserve2Flags}];
  if (!Ident) {
    Constant *I32Null = ConstantInt::getNullValue(Int32);
    Constant *IdentData[] = {I32Null,
                             ConstantInt::get(Int32, uint32_t(LocFlags)),
                             ConstantInt::get(Int32, Reserve2Flags),
                             ConstantInt::get(Int32, SrcLocStrSize), SrcLocStr};
    Constant *Initializer =
        ConstantStruct::get(OpenMPIRBuilder::Ident, IdentData);

    // Look for existing encoding of the location + flags, not needed but
    // minimizes the difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.globals())
      if (GV.getValueType() == OpenMPIRBuilder::Ident && GV.hasInitializer())
        if (GV.getInitializer() == Initializer)
          Ident = &GV;

    if (!Ident) {
      auto *GV = new GlobalVariable(
          M, OpenMPIRBuilder::Ident,
          /* isConstant = */ true, GlobalValue::PrivateLinkage, Initializer, "",
          nullptr, GlobalValue::NotThreadLocal,
          M.getDataLayout().getDefaultGlobalsAddressSpace());
      GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
      GV->setAlignment(Align(8));
      Ident = GV;
    }
  }

  return ConstantExpr::getPointerBitCastOrAddrSpaceCast(Ident, IdentPtr);
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef LocStr,
                                                uint32_t &SrcLocStrSize) {
  SrcLocStrSize = LocStr.size();
  Constant *&SrcLocStr = SrcLocStrMap[LocStr];
  if (!SrcLocStr) {
    Constant *Initializer =
        ConstantDataArray::getString(M.getContext(), LocStr);

    // Look for existing encoding of the location, not needed but minimizes the
    // difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.globals())
      if (GV.isConstant() && GV.hasInitializer() &&
          GV.getInitializer() == Initializer)
        return SrcLocStr = ConstantExpr::getPointerCast(&GV, Int8Ptr);

    SrcLocStr = Builder.CreateGlobalString(LocStr, /* Name */ "",
                                           /* AddressSpace */ 0, &M);
  }
  return SrcLocStr;
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef FunctionName,
                                                StringRef FileName,
                                                unsigned Line, unsigned Column,
                                                uint32_t &SrcLocStrSize) {
  SmallString<128> Buffer;
  Buffer.push_back(';');
  Buffer.append(FileName);
  Buffer.push_back(';');
  Buffer.append(FunctionName);
  Buffer.push_back(';');
  Buffer.append(std::to_string(Line));
  Buffer.push_back(';');
  Buffer.append(std::to_string(Column));
  Buffer.push_back(';');
  Buffer.push_back(';');
  return getOrCreateSrcLocStr(Buffer.str(), SrcLocStrSize);
}

Constant *
OpenMPIRBuilder::getOrCreateDefaultSrcLocStr(uint32_t &SrcLocStrSize) {
  StringRef UnknownLoc = ";unknown;unknown;0;0;;";
  return getOrCreateSrcLocStr(UnknownLoc, SrcLocStrSize);
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(DebugLoc DL,
                                                uint32_t &SrcLocStrSize,
                                                Function *F) {
  DILocation *DIL = DL.get();
  if (!DIL)
    return getOrCreateDefaultSrcLocStr(SrcLocStrSize);
  StringRef FileName = M.getName();
  if (DIFile *DIF = DIL->getFile())
    if (std::optional<StringRef> Source = DIF->getSource())
      FileName = *Source;
  StringRef Function = DIL->getScope()->getSubprogram()->getName();
  if (Function.empty() && F)
    Function = F->getName();
  return getOrCreateSrcLocStr(Function, FileName, DIL->getLine(),
                              DIL->getColumn(), SrcLocStrSize);
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(const LocationDescription &Loc,
                                                uint32_t &SrcLocStrSize) {
  return getOrCreateSrcLocStr(Loc.DL, SrcLocStrSize,
                              Loc.IP.getBlock()->getParent());
}

Value *OpenMPIRBuilder::getOrCreateThreadID(Value *Ident) {
  return Builder.CreateCall(
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num), Ident,
      "omp_global_thread_num");
}

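/// Emit an (implicit or explicit) OpenMP barrier at the given location. In a
/// cancellable parallel region this becomes a call to __kmpc_cancel_barrier;
/// otherwise the emitted IR is roughly:
///
///   %tid = call i32 @__kmpc_global_thread_num(ptr @loc)
///   call void @__kmpc_barrier(ptr @loc, i32 %tid)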
OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createBarrier(const LocationDescription &Loc, Directive Kind,
                               bool ForceSimpleCall, bool CheckCancelFlag) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // Build call __kmpc_cancel_barrier(loc, thread_id) or
  //            __kmpc_barrier(loc, thread_id);

  IdentFlag BarrierLocFlags;
  switch (Kind) {
  case OMPD_for:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_FOR;
    break;
  case OMPD_sections:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SECTIONS;
    break;
  case OMPD_single:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SINGLE;
    break;
  case OMPD_barrier:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_EXPL;
    break;
  default:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL;
    break;
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Args[] = {
      getOrCreateIdent(SrcLocStr, SrcLocStrSize, BarrierLocFlags),
      getOrCreateThreadID(getOrCreateIdent(SrcLocStr, SrcLocStrSize))};

  // If we are in a cancellable parallel region, barriers are cancellation
  // points.
  // TODO: Check why we would force simple calls or to ignore the cancel flag.
  bool UseCancelBarrier =
      !ForceSimpleCall && isLastFinalizationInfoCancellable(OMPD_parallel);

  Value *Result =
      Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
                             UseCancelBarrier ? OMPRTL___kmpc_cancel_barrier
                                              : OMPRTL___kmpc_barrier),
                         Args);

  if (UseCancelBarrier && CheckCancelFlag)
    if (Error Err = emitCancelationCheckImpl(Result, OMPD_parallel))
      return Err;

  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createCancel(const LocationDescription &Loc,
                              Value *IfCondition,
                              omp::Directive CanceledDirective) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // LLVM utilities like blocks with terminators.
  auto *UI = Builder.CreateUnreachable();

  Instruction *ThenTI = UI, *ElseTI = nullptr;
  if (IfCondition)
    SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI);
  Builder.SetInsertPoint(ThenTI);

  Value *CancelKind = nullptr;
  switch (CanceledDirective) {
#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value)                       \
  case DirectiveEnum:                                                          \
    CancelKind = Builder.getInt32(Value);                                      \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    llvm_unreachable("Unknown cancel kind!");
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
  Value *Result = Builder.CreateCall(
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_cancel), Args);
  auto ExitCB = [this, CanceledDirective, Loc](InsertPointTy IP) -> Error {
    if (CanceledDirective == OMPD_parallel) {
      IRBuilder<>::InsertPointGuard IPG(Builder);
      Builder.restoreIP(IP);
      return createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                           omp::Directive::OMPD_unknown,
                           /* ForceSimpleCall */ false,
                           /* CheckCancelFlag */ false)
          .takeError();
    }
    return Error::success();
  };

  // The actual cancel logic is shared with others, e.g., cancel_barriers.
  if (Error Err = emitCancelationCheckImpl(Result, CanceledDirective, ExitCB))
    return Err;

  // Update the insertion point and remove the terminator we introduced.
  Builder.SetInsertPoint(UI->getParent());
  UI->eraseFromParent();

  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointOrErrorTy
OpenMPIRBuilder::createCancellationPoint(const LocationDescription &Loc,
                                         omp::Directive CanceledDirective) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // LLVM utilities like blocks with terminators.
  auto *UI = Builder.CreateUnreachable();
  Builder.SetInsertPoint(UI);

  Value *CancelKind = nullptr;
  switch (CanceledDirective) {
#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value)                       \
  case DirectiveEnum:                                                          \
    CancelKind = Builder.getInt32(Value);                                      \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    llvm_unreachable("Unknown cancel kind!");
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
  Value *Result = Builder.CreateCall(
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_cancellationpoint), Args);
  auto ExitCB = [this, CanceledDirective, Loc](InsertPointTy IP) -> Error {
    if (CanceledDirective == OMPD_parallel) {
      IRBuilder<>::InsertPointGuard IPG(Builder);
      Builder.restoreIP(IP);
      return createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                           omp::Directive::OMPD_unknown,
                           /* ForceSimpleCall */ false,
                           /* CheckCancelFlag */ false)
          .takeError();
    }
    return Error::success();
  };

  // The actual cancel logic is shared with others, e.g., cancel_barriers.
  if (Error Err = emitCancelationCheckImpl(Result, CanceledDirective, ExitCB))
    return Err;

  // Update the insertion point and remove the terminator we introduced.
  Builder.SetInsertPoint(UI->getParent());
  UI->eraseFromParent();

  return Builder.saveIP();
}

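/// Emit the call to __tgt_target_kernel: the kernel arguments are stored into
/// a stack-allocated struct and the runtime's status code is handed back in
/// \p Return.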
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
    const LocationDescription &Loc, InsertPointTy AllocaIP, Value *&Return,
    Value *Ident, Value *DeviceID, Value *NumTeams, Value *NumThreads,
    Value *HostPtr, ArrayRef<Value *> KernelArgs) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  Builder.restoreIP(AllocaIP);
  auto *KernelArgsPtr =
      Builder.CreateAlloca(OpenMPIRBuilder::KernelArgs, nullptr, "kernel_args");
  Builder.restoreIP(Loc.IP);

  for (unsigned I = 0, Size = KernelArgs.size(); I != Size; ++I) {
    llvm::Value *Arg =
        Builder.CreateStructGEP(OpenMPIRBuilder::KernelArgs, KernelArgsPtr, I);
    Builder.CreateAlignedStore(
        KernelArgs[I], Arg,
        M.getDataLayout().getPrefTypeAlign(KernelArgs[I]->getType()));
  }

  SmallVector<Value *> OffloadingArgs{Ident,      DeviceID, NumTeams,
                                      NumThreads, HostPtr,  KernelArgsPtr};

  Return = Builder.CreateCall(
      getOrCreateRuntimeFunction(M, OMPRTL___tgt_target_kernel),
      OffloadingArgs);

  return Builder.saveIP();
}

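/// Emit the control flow that launches an offloaded kernel and, when the
/// runtime reports failure (a non-zero return value), runs the host fallback
/// produced by \p EmitTargetCallFallbackCB.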
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitKernelLaunch(
    const LocationDescription &Loc, Value *OutlinedFnID,
    EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
    Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  Builder.restoreIP(Loc.IP);
  // On top of the arrays that were filled up, the target offloading call
  // takes as arguments the device id as well as the host pointer. The host
  // pointer is used by the runtime library to identify the current target
  // region, so it only has to be unique and not necessarily point to
  // anything. It could be the pointer to the outlined function that
  // implements the target region, but we aren't using that so that the
  // compiler doesn't need to keep that, and could therefore inline the host
  // function if proven worthwhile during optimization.

  // From this point on, we need to have an ID of the target region defined.
  assert(OutlinedFnID && "Invalid outlined function ID!");
  (void)OutlinedFnID;

  // Return value of the runtime offloading call.
  Value *Return = nullptr;

  // Arguments for the target kernel.
  SmallVector<Value *> ArgsVector;
  getKernelArgsVector(Args, Builder, ArgsVector);

  // The target region is an outlined function launched by the runtime
  // via calls to __tgt_target_kernel().
  //
  // Note that on the host and CPU targets, the runtime implementation of
  // these calls simply call the outlined function without forking threads.
  // The outlined functions themselves have runtime calls to
  // __kmpc_fork_teams() and __kmpc_fork() for this purpose, codegen'd by
  // the compiler in emitTeamsCall() and emitParallelCall().
  //
  // In contrast, on the NVPTX target, the implementation of
  // __tgt_target_teams() launches a GPU kernel with the requested number
  // of teams and threads so no additional calls to the runtime are required.
  // Check the error code and execute the host version if required.
  Builder.restoreIP(emitTargetKernel(
      Builder, AllocaIP, Return, RTLoc, DeviceID, Args.NumTeams.front(),
      Args.NumThreads.front(), OutlinedFnID, ArgsVector));

  BasicBlock *OffloadFailedBlock =
      BasicBlock::Create(Builder.getContext(), "omp_offload.failed");
  BasicBlock *OffloadContBlock =
      BasicBlock::Create(Builder.getContext(), "omp_offload.cont");
  Value *Failed = Builder.CreateIsNotNull(Return);
  Builder.CreateCondBr(Failed, OffloadFailedBlock, OffloadContBlock);

  auto CurFn = Builder.GetInsertBlock()->getParent();
  emitBlock(OffloadFailedBlock, CurFn);
  InsertPointOrErrorTy AfterIP = EmitTargetCallFallbackCB(Builder.saveIP());
  if (!AfterIP)
    return AfterIP.takeError();
  Builder.restoreIP(*AfterIP);
  emitBranch(OffloadContBlock);
  emitBlock(OffloadContBlock, CurFn, /*IsFinished=*/true);
  return Builder.saveIP();
}

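/// Emit a branch on \p CancelFlag: a non-zero flag jumps to a cancellation
/// block that runs the optional \p ExitCB and the current finalization
/// callback, while a zero flag continues in a fresh continuation block.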
Error OpenMPIRBuilder::emitCancelationCheckImpl(
    Value *CancelFlag, omp::Directive CanceledDirective,
    FinalizeCallbackTy ExitCB) {
  assert(isLastFinalizationInfoCancellable(CanceledDirective) &&
         "Unexpected cancellation!");

  // For a cancel barrier we create two new blocks.
  BasicBlock *BB = Builder.GetInsertBlock();
  BasicBlock *NonCancellationBlock;
  if (Builder.GetInsertPoint() == BB->end()) {
    // TODO: This branch will not be needed once we have moved to the
    // OpenMPIRBuilder codegen completely.
    NonCancellationBlock = BasicBlock::Create(
        BB->getContext(), BB->getName() + ".cont", BB->getParent());
  } else {
    NonCancellationBlock = SplitBlock(BB, &*Builder.GetInsertPoint());
    BB->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(BB);
  }
  BasicBlock *CancellationBlock = BasicBlock::Create(
      BB->getContext(), BB->getName() + ".cncl", BB->getParent());

  // Jump to them based on the return value.
  Value *Cmp = Builder.CreateIsNull(CancelFlag);
  Builder.CreateCondBr(Cmp, NonCancellationBlock, CancellationBlock,
                       /* TODO weight */ nullptr, nullptr);

  // From the cancellation block we finalize all variables and go to the
  // post finalization block that is known to the FiniCB callback.
  Builder.SetInsertPoint(CancellationBlock);
  if (ExitCB)
    if (Error Err = ExitCB(Builder.saveIP()))
      return Err;
  auto &FI = FinalizationStack.back();
  if (Error Err = FI.FiniCB(Builder.saveIP()))
    return Err;

  // The continuation block is where code generation continues.
  Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
  return Error::success();
}

1290 // Callback used to create OpenMP runtime calls to support
1291 // omp parallel clause for the device.
1292 // We need to use this callback to replace call to the OutlinedFn in OuterFn
1293 // by the call to the OpenMP DeviceRTL runtime function (kmpc_parallel_51)
targetParallelCallback(OpenMPIRBuilder * OMPIRBuilder,Function & OutlinedFn,Function * OuterFn,BasicBlock * OuterAllocaBB,Value * Ident,Value * IfCondition,Value * NumThreads,Instruction * PrivTID,AllocaInst * PrivTIDAddr,Value * ThreadID,const SmallVector<Instruction *,4> & ToBeDeleted)1294 static void targetParallelCallback(
1295 OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn, Function *OuterFn,
1296 BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
1297 Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
1298 Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
1299 // Add some known attributes.
1300 IRBuilder<> &Builder = OMPIRBuilder->Builder;
1301 OutlinedFn.addParamAttr(0, Attribute::NoAlias);
1302 OutlinedFn.addParamAttr(1, Attribute::NoAlias);
1303 OutlinedFn.addParamAttr(0, Attribute::NoUndef);
1304 OutlinedFn.addParamAttr(1, Attribute::NoUndef);
1305 OutlinedFn.addFnAttr(Attribute::NoUnwind);
1306
1307 assert(OutlinedFn.arg_size() >= 2 &&
1308 "Expected at least tid and bounded tid as arguments");
1309 unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1310
1311 CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
1312 assert(CI && "Expected call instruction to outlined function");
1313 CI->getParent()->setName("omp_parallel");
1314
1315 Builder.SetInsertPoint(CI);
1316 Type *PtrTy = OMPIRBuilder->VoidPtr;
1317 Value *NullPtrValue = Constant::getNullValue(PtrTy);
1318
1319 // Add alloca for kernel args
1320 OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP();
1321 Builder.SetInsertPoint(OuterAllocaBB, OuterAllocaBB->getFirstInsertionPt());
1322 AllocaInst *ArgsAlloca =
1323 Builder.CreateAlloca(ArrayType::get(PtrTy, NumCapturedVars));
1324 Value *Args = ArgsAlloca;
1325 // Add address space cast if array for storing arguments is not allocated
1326 // in address space 0
1327 if (ArgsAlloca->getAddressSpace())
1328 Args = Builder.CreatePointerCast(ArgsAlloca, PtrTy);
1329 Builder.restoreIP(CurrentIP);
1330
1331 // Store captured vars which are used by kmpc_parallel_51
1332 for (unsigned Idx = 0; Idx < NumCapturedVars; Idx++) {
1333 Value *V = *(CI->arg_begin() + 2 + Idx);
1334 Value *StoreAddress = Builder.CreateConstInBoundsGEP2_64(
1335 ArrayType::get(PtrTy, NumCapturedVars), Args, 0, Idx);
1336 Builder.CreateStore(V, StoreAddress);
1337 }
1338
1339 Value *Cond =
1340 IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32)
1341 : Builder.getInt32(1);
1342
1343 // Build kmpc_parallel_51 call
1344 Value *Parallel51CallArgs[] = {
1345 /* identifier*/ Ident,
1346 /* global thread num*/ ThreadID,
1347 /* if expression */ Cond,
1348 /* number of threads */ NumThreads ? NumThreads : Builder.getInt32(-1),
1349 /* Proc bind */ Builder.getInt32(-1),
1350 /* outlined function */ &OutlinedFn,
1351 /* wrapper function */ NullPtrValue,
1352       /* arguments of the outlined function */ Args,
1353 /* number of arguments */ Builder.getInt64(NumCapturedVars)};
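  // For illustration only (SSA names are ad hoc), the call emitted below
  // looks roughly like:
  //
  //   call void @__kmpc_parallel_51(ptr %ident, i32 %gtid, i32 %cond, i32 -1,
  //                                 i32 -1, ptr @outlined.omp_par, ptr null,
  //                                 ptr %args, i64 <NumCapturedVars>)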
1354
1355 FunctionCallee RTLFn =
1356 OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_parallel_51);
1357
1358 Builder.CreateCall(RTLFn, Parallel51CallArgs);
1359
1360 LLVM_DEBUG(dbgs() << "With kmpc_parallel_51 placed: "
1361 << *Builder.GetInsertBlock()->getParent() << "\n");
1362
1363 // Initialize the local TID stack location with the argument value.
1364 Builder.SetInsertPoint(PrivTID);
1365 Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
1366 Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
1367 PrivTIDAddr);
1368
1369 // Remove redundant call to the outlined function.
1370 CI->eraseFromParent();
1371
1372 for (Instruction *I : ToBeDeleted) {
1373 I->eraseFromParent();
1374 }
1375 }
1376
1377 // Callback used to create the OpenMP runtime calls that support the
1378 // omp parallel directive on the host.
1379 // We use this callback to replace the call to the OutlinedFn in OuterFn
1380 // with a call to the OpenMP host runtime function (__kmpc_fork_call[_if]).
1381 static void
1382 hostParallelCallback(OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn,
1383 Function *OuterFn, Value *Ident, Value *IfCondition,
1384 Instruction *PrivTID, AllocaInst *PrivTIDAddr,
1385 const SmallVector<Instruction *, 4> &ToBeDeleted) {
1386 IRBuilder<> &Builder = OMPIRBuilder->Builder;
1387 FunctionCallee RTLFn;
1388 if (IfCondition) {
1389 RTLFn =
1390 OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call_if);
1391 } else {
1392 RTLFn =
1393 OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call);
1394 }
1395 if (auto *F = dyn_cast<Function>(RTLFn.getCallee())) {
1396 if (!F->hasMetadata(LLVMContext::MD_callback)) {
1397 LLVMContext &Ctx = F->getContext();
1398 MDBuilder MDB(Ctx);
1399 // Annotate the callback behavior of the __kmpc_fork_call:
1400 // - The callback callee is argument number 2 (microtask).
1401 // - The first two arguments of the callback callee are unknown (-1).
1402 // - All variadic arguments to the __kmpc_fork_call are passed to the
1403 // callback callee.
1404 F->addMetadata(LLVMContext::MD_callback,
1405 *MDNode::get(Ctx, {MDB.createCallbackEncoding(
1406 2, {-1, -1},
1407 /* VarArgsArePassed */ true)}));
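      // The IR produced here is roughly the following (metadata ids are
      // illustrative):
      //
      //   declare !callback !0 void @__kmpc_fork_call(ptr, i32, ptr, ...)
      //   !0 = !{!1}
      //   !1 = !{i64 2, i64 -1, i64 -1, i1 true}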
1408 }
1409 }
1410 // Add some known attributes.
1411 OutlinedFn.addParamAttr(0, Attribute::NoAlias);
1412 OutlinedFn.addParamAttr(1, Attribute::NoAlias);
1413 OutlinedFn.addFnAttr(Attribute::NoUnwind);
1414
1415 assert(OutlinedFn.arg_size() >= 2 &&
1416 "Expected at least tid and bounded tid as arguments");
1417 unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1418
1419 CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
1420 CI->getParent()->setName("omp_parallel");
1421 Builder.SetInsertPoint(CI);
1422
1423 // Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
1424 Value *ForkCallArgs[] = {Ident, Builder.getInt32(NumCapturedVars),
1425 &OutlinedFn};
1426
1427 SmallVector<Value *, 16> RealArgs;
1428 RealArgs.append(std::begin(ForkCallArgs), std::end(ForkCallArgs));
1429 if (IfCondition) {
1430 Value *Cond = Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32);
1431 RealArgs.push_back(Cond);
1432 }
1433 RealArgs.append(CI->arg_begin() + /* tid & bound tid */ 2, CI->arg_end());
1434
1435   // __kmpc_fork_call_if always expects a void ptr as the last argument.
1436   // If there are no arguments, pass a null pointer.
1437 auto PtrTy = OMPIRBuilder->VoidPtr;
1438 if (IfCondition && NumCapturedVars == 0) {
1439 Value *NullPtrValue = Constant::getNullValue(PtrTy);
1440 RealArgs.push_back(NullPtrValue);
1441 }
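  // For illustration (SSA names are ad hoc), the emitted call is roughly:
  //
  //   call void @__kmpc_fork_call(ptr %ident, i32 %n, ptr @outlined, ...)
  //
  // or, with an if clause:
  //
  //   call void @__kmpc_fork_call_if(ptr %ident, i32 %n, ptr @outlined,
  //                                  i32 %cond, ...)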
1442
1443 Builder.CreateCall(RTLFn, RealArgs);
1444
1445 LLVM_DEBUG(dbgs() << "With fork_call placed: "
1446 << *Builder.GetInsertBlock()->getParent() << "\n");
1447
1448 // Initialize the local TID stack location with the argument value.
1449 Builder.SetInsertPoint(PrivTID);
1450 Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
1451 Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
1452 PrivTIDAddr);
1453
1454 // Remove redundant call to the outlined function.
1455 CI->eraseFromParent();
1456
1457 for (Instruction *I : ToBeDeleted) {
1458 I->eraseFromParent();
1459 }
1460 }
1461
1462 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
1463 const LocationDescription &Loc, InsertPointTy OuterAllocaIP,
1464 BodyGenCallbackTy BodyGenCB, PrivatizeCallbackTy PrivCB,
1465 FinalizeCallbackTy FiniCB, Value *IfCondition, Value *NumThreads,
1466 omp::ProcBindKind ProcBind, bool IsCancellable) {
1467 assert(!isConflictIP(Loc.IP, OuterAllocaIP) && "IPs must not be ambiguous");
1468
1469 if (!updateToLocation(Loc))
1470 return Loc.IP;
1471
1472 uint32_t SrcLocStrSize;
1473 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1474 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1475 Value *ThreadID = getOrCreateThreadID(Ident);
1476   // If we generate code for the target device, we need to allocate the
1477   // struct for aggregate params in the device default alloca address space.
1478   // The OpenMP runtime requires that the params of the extracted functions
1479   // are passed as zero address space pointers. This flag ensures that
1480   // extracted function arguments are declared in address space zero.
1481 bool ArgsInZeroAddressSpace = Config.isTargetDevice();
1482
1483 // Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
1484 // only if we compile for host side.
1485 if (NumThreads && !Config.isTargetDevice()) {
1486 Value *Args[] = {
1487 Ident, ThreadID,
1488 Builder.CreateIntCast(NumThreads, Int32, /*isSigned*/ false)};
1489 Builder.CreateCall(
1490 getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_threads), Args);
1491 }
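  // For illustration (SSA names are ad hoc), the push above is roughly:
  //
  //   call void @__kmpc_push_num_threads(ptr %ident, i32 %gtid, i32 %nt)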
1492
1493 if (ProcBind != OMP_PROC_BIND_default) {
1494 // Build call __kmpc_push_proc_bind(&Ident, global_tid, proc_bind)
1495 Value *Args[] = {
1496 Ident, ThreadID,
1497 ConstantInt::get(Int32, unsigned(ProcBind), /*isSigned=*/true)};
1498 Builder.CreateCall(
1499 getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_proc_bind), Args);
1500 }
1501
1502 BasicBlock *InsertBB = Builder.GetInsertBlock();
1503 Function *OuterFn = InsertBB->getParent();
1504
1505 // Save the outer alloca block because the insertion iterator may get
1506 // invalidated and we still need this later.
1507 BasicBlock *OuterAllocaBlock = OuterAllocaIP.getBlock();
1508
1509 // Vector to remember instructions we used only during the modeling but which
1510 // we want to delete at the end.
1511 SmallVector<Instruction *, 4> ToBeDeleted;
1512
1513 // Change the location to the outer alloca insertion point to create and
1514 // initialize the allocas we pass into the parallel region.
1515 InsertPointTy NewOuter(OuterAllocaBlock, OuterAllocaBlock->begin());
1516 Builder.restoreIP(NewOuter);
1517 AllocaInst *TIDAddrAlloca = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
1518 AllocaInst *ZeroAddrAlloca =
1519 Builder.CreateAlloca(Int32, nullptr, "zero.addr");
1520 Instruction *TIDAddr = TIDAddrAlloca;
1521 Instruction *ZeroAddr = ZeroAddrAlloca;
1522 if (ArgsInZeroAddressSpace && M.getDataLayout().getAllocaAddrSpace() != 0) {
1523 // Add additional casts to enforce pointers in zero address space
1524 TIDAddr = new AddrSpaceCastInst(
1525         TIDAddrAlloca, PointerType::get(M.getContext(), 0), "tid.addr.ascast");
1526 TIDAddr->insertAfter(TIDAddrAlloca->getIterator());
1527 ToBeDeleted.push_back(TIDAddr);
1528 ZeroAddr = new AddrSpaceCastInst(ZeroAddrAlloca,
1529                                      PointerType::get(M.getContext(), 0),
1530 "zero.addr.ascast");
1531 ZeroAddr->insertAfter(ZeroAddrAlloca->getIterator());
1532 ToBeDeleted.push_back(ZeroAddr);
1533 }
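  // For illustration, on a target where allocas live in address space 5 (e.g.
  // AMDGPU), the casts above yield roughly:
  //
  //   %tid.addr = alloca i32, align 4, addrspace(5)
  //   %tid.addr.ascast = addrspacecast ptr addrspace(5) %tid.addr to ptr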
1534
1535 // We only need TIDAddr and ZeroAddr for modeling purposes to get the
1536 // associated arguments in the outlined function, so we delete them later.
1537 ToBeDeleted.push_back(TIDAddrAlloca);
1538 ToBeDeleted.push_back(ZeroAddrAlloca);
1539
1540   // Create an artificial insertion point that will also ensure the blocks we
1541   // are about to split are not degenerate.
1542 auto *UI = new UnreachableInst(Builder.getContext(), InsertBB);
1543
1544 BasicBlock *EntryBB = UI->getParent();
1545 BasicBlock *PRegEntryBB = EntryBB->splitBasicBlock(UI, "omp.par.entry");
1546 BasicBlock *PRegBodyBB = PRegEntryBB->splitBasicBlock(UI, "omp.par.region");
1547 BasicBlock *PRegPreFiniBB =
1548 PRegBodyBB->splitBasicBlock(UI, "omp.par.pre_finalize");
1549 BasicBlock *PRegExitBB = PRegPreFiniBB->splitBasicBlock(UI, "omp.par.exit");
1550
1551 auto FiniCBWrapper = [&](InsertPointTy IP) {
1552 // Hide "open-ended" blocks from the given FiniCB by setting the right jump
1553 // target to the region exit block.
1554 if (IP.getBlock()->end() == IP.getPoint()) {
1555 IRBuilder<>::InsertPointGuard IPG(Builder);
1556 Builder.restoreIP(IP);
1557 Instruction *I = Builder.CreateBr(PRegExitBB);
1558 IP = InsertPointTy(I->getParent(), I->getIterator());
1559 }
1560 assert(IP.getBlock()->getTerminator()->getNumSuccessors() == 1 &&
1561 IP.getBlock()->getTerminator()->getSuccessor(0) == PRegExitBB &&
1562 "Unexpected insertion point for finalization call!");
1563 return FiniCB(IP);
1564 };
1565
1566 FinalizationStack.push_back({FiniCBWrapper, OMPD_parallel, IsCancellable});
1567
1568 // Generate the privatization allocas in the block that will become the entry
1569 // of the outlined function.
1570 Builder.SetInsertPoint(PRegEntryBB->getTerminator());
1571 InsertPointTy InnerAllocaIP = Builder.saveIP();
1572
1573 AllocaInst *PrivTIDAddr =
1574 Builder.CreateAlloca(Int32, nullptr, "tid.addr.local");
1575 Instruction *PrivTID = Builder.CreateLoad(Int32, PrivTIDAddr, "tid");
1576
1577 // Add some fake uses for OpenMP provided arguments.
1578 ToBeDeleted.push_back(Builder.CreateLoad(Int32, TIDAddr, "tid.addr.use"));
1579 Instruction *ZeroAddrUse =
1580 Builder.CreateLoad(Int32, ZeroAddr, "zero.addr.use");
1581 ToBeDeleted.push_back(ZeroAddrUse);
1582
1583 // EntryBB
1584 // |
1585 // V
1586 // PRegionEntryBB <- Privatization allocas are placed here.
1587 // |
1588 // V
1589   // PRegionBodyBB <- BodyGen is invoked here.
1590 // |
1591 // V
1592 // PRegPreFiniBB <- The block we will start finalization from.
1593 // |
1594 // V
1595 // PRegionExitBB <- A common exit to simplify block collection.
1596 //
1597
1598 LLVM_DEBUG(dbgs() << "Before body codegen: " << *OuterFn << "\n");
1599
1600 // Let the caller create the body.
1601 assert(BodyGenCB && "Expected body generation callback!");
1602 InsertPointTy CodeGenIP(PRegBodyBB, PRegBodyBB->begin());
1603 if (Error Err = BodyGenCB(InnerAllocaIP, CodeGenIP))
1604 return Err;
1605
1606 LLVM_DEBUG(dbgs() << "After body codegen: " << *OuterFn << "\n");
1607
1608 OutlineInfo OI;
1609 if (Config.isTargetDevice()) {
1610 // Generate OpenMP target specific runtime call
1611 OI.PostOutlineCB = [=, ToBeDeletedVec =
1612 std::move(ToBeDeleted)](Function &OutlinedFn) {
1613 targetParallelCallback(this, OutlinedFn, OuterFn, OuterAllocaBlock, Ident,
1614 IfCondition, NumThreads, PrivTID, PrivTIDAddr,
1615 ThreadID, ToBeDeletedVec);
1616 };
1617 } else {
1618 // Generate OpenMP host runtime call
1619 OI.PostOutlineCB = [=, ToBeDeletedVec =
1620 std::move(ToBeDeleted)](Function &OutlinedFn) {
1621 hostParallelCallback(this, OutlinedFn, OuterFn, Ident, IfCondition,
1622 PrivTID, PrivTIDAddr, ToBeDeletedVec);
1623 };
1624 }
1625
1626 OI.OuterAllocaBB = OuterAllocaBlock;
1627 OI.EntryBB = PRegEntryBB;
1628 OI.ExitBB = PRegExitBB;
1629
1630 SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
1631 SmallVector<BasicBlock *, 32> Blocks;
1632 OI.collectBlocks(ParallelRegionBlockSet, Blocks);
1633
1634 CodeExtractorAnalysisCache CEAC(*OuterFn);
1635 CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
1636 /* AggregateArgs */ false,
1637 /* BlockFrequencyInfo */ nullptr,
1638 /* BranchProbabilityInfo */ nullptr,
1639 /* AssumptionCache */ nullptr,
1640 /* AllowVarArgs */ true,
1641 /* AllowAlloca */ true,
1642 /* AllocationBlock */ OuterAllocaBlock,
1643 /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
1644
1645 // Find inputs to, outputs from the code region.
1646 BasicBlock *CommonExit = nullptr;
1647 SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
1648 Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
1649
1650 Extractor.findInputsOutputs(Inputs, Outputs, SinkingCands,
1651 /*CollectGlobalInputs=*/true);
1652
1653 Inputs.remove_if([&](Value *I) {
1654 if (auto *GV = dyn_cast_if_present<GlobalVariable>(I))
1655 return GV->getValueType() == OpenMPIRBuilder::Ident;
1656
1657 return false;
1658 });
1659
1660 LLVM_DEBUG(dbgs() << "Before privatization: " << *OuterFn << "\n");
1661
1662 FunctionCallee TIDRTLFn =
1663 getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num);
1664
1665 auto PrivHelper = [&](Value &V) -> Error {
1666 if (&V == TIDAddr || &V == ZeroAddr) {
1667 OI.ExcludeArgsFromAggregate.push_back(&V);
1668 return Error::success();
1669 }
1670
1671 SetVector<Use *> Uses;
1672 for (Use &U : V.uses())
1673 if (auto *UserI = dyn_cast<Instruction>(U.getUser()))
1674 if (ParallelRegionBlockSet.count(UserI->getParent()))
1675 Uses.insert(&U);
1676
1677 // __kmpc_fork_call expects extra arguments as pointers. If the input
1678 // already has a pointer type, everything is fine. Otherwise, store the
1679 // value onto stack and load it back inside the to-be-outlined region. This
1680 // will ensure only the pointer will be passed to the function.
1681 // FIXME: if there are more than 15 trailing arguments, they must be
1682 // additionally packed in a struct.
1683 Value *Inner = &V;
1684 if (!V.getType()->isPointerTy()) {
1685 IRBuilder<>::InsertPointGuard Guard(Builder);
1686 LLVM_DEBUG(llvm::dbgs() << "Forwarding input as pointer: " << V << "\n");
1687
1688 Builder.restoreIP(OuterAllocaIP);
1689 Value *Ptr =
1690 Builder.CreateAlloca(V.getType(), nullptr, V.getName() + ".reloaded");
1691
1692 // Store to stack at end of the block that currently branches to the entry
1693 // block of the to-be-outlined region.
1694 Builder.SetInsertPoint(InsertBB,
1695 InsertBB->getTerminator()->getIterator());
1696 Builder.CreateStore(&V, Ptr);
1697
1698 // Load back next to allocations in the to-be-outlined region.
1699 Builder.restoreIP(InnerAllocaIP);
1700 Inner = Builder.CreateLoad(V.getType(), Ptr);
1701 }
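    // For illustration, forwarding a non-pointer input %x of type i32 this way
    // yields roughly:
    //
    //   %x.reloaded = alloca i32           ; at OuterAllocaIP
    //   store i32 %x, ptr %x.reloaded      ; before entering the region
    //   %0 = load i32, ptr %x.reloaded     ; at InnerAllocaIP, used as Inner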
1702
1703 Value *ReplacementValue = nullptr;
1704 CallInst *CI = dyn_cast<CallInst>(&V);
1705 if (CI && CI->getCalledFunction() == TIDRTLFn.getCallee()) {
1706 ReplacementValue = PrivTID;
1707 } else {
1708 InsertPointOrErrorTy AfterIP =
1709 PrivCB(InnerAllocaIP, Builder.saveIP(), V, *Inner, ReplacementValue);
1710 if (!AfterIP)
1711 return AfterIP.takeError();
1712 Builder.restoreIP(*AfterIP);
1713 InnerAllocaIP = {
1714 InnerAllocaIP.getBlock(),
1715 InnerAllocaIP.getBlock()->getTerminator()->getIterator()};
1716
1717 assert(ReplacementValue &&
1718 "Expected copy/create callback to set replacement value!");
1719 if (ReplacementValue == &V)
1720 return Error::success();
1721 }
1722
1723 for (Use *UPtr : Uses)
1724 UPtr->set(ReplacementValue);
1725
1726 return Error::success();
1727 };
1728
1729   // Reset the inner alloca insertion point as it will be used for loading the
1730   // values wrapped into pointers before passing them into the to-be-outlined
1731   // region. Configure it to insert immediately after the fake use of the zero
1732   // address so that the reloaded values are available in the generated body
1733   // and so that the OpenMP-related values (thread ID and zero address
1734   // pointers) remain leading in the argument list.
1735 InnerAllocaIP = IRBuilder<>::InsertPoint(
1736 ZeroAddrUse->getParent(), ZeroAddrUse->getNextNode()->getIterator());
1737
1738 // Reset the outer alloca insertion point to the entry of the relevant block
1739 // in case it was invalidated.
1740 OuterAllocaIP = IRBuilder<>::InsertPoint(
1741 OuterAllocaBlock, OuterAllocaBlock->getFirstInsertionPt());
1742
1743 for (Value *Input : Inputs) {
1744 LLVM_DEBUG(dbgs() << "Captured input: " << *Input << "\n");
1745 if (Error Err = PrivHelper(*Input))
1746 return Err;
1747 }
1748 LLVM_DEBUG({
1749 for (Value *Output : Outputs)
1750 LLVM_DEBUG(dbgs() << "Captured output: " << *Output << "\n");
1751 });
1752 assert(Outputs.empty() &&
1753 "OpenMP outlining should not produce live-out values!");
1754
1755 LLVM_DEBUG(dbgs() << "After privatization: " << *OuterFn << "\n");
1756 LLVM_DEBUG({
1757 for (auto *BB : Blocks)
1758 dbgs() << " PBR: " << BB->getName() << "\n";
1759 });
1760
1761   // Adjust the finalization stack, verify the adjustment, and call the
1762   // finalize function a last time to finalize values between the pre-fini
1763   // block and the exit block if we left the parallel region "the normal way".
1764 auto FiniInfo = FinalizationStack.pop_back_val();
1765 (void)FiniInfo;
1766 assert(FiniInfo.DK == OMPD_parallel &&
1767 "Unexpected finalization stack state!");
1768
1769 Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator();
1770
1771 InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
1772 if (Error Err = FiniCB(PreFiniIP))
1773 return Err;
1774
1775 // Register the outlined info.
1776 addOutlineInfo(std::move(OI));
1777
1778 InsertPointTy AfterIP(UI->getParent(), UI->getParent()->end());
1779 UI->eraseFromParent();
1780
1781 return AfterIP;
1782 }
1783
1784 void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) {
1785 // Build call void __kmpc_flush(ident_t *loc)
1786 uint32_t SrcLocStrSize;
1787 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1788 Value *Args[] = {getOrCreateIdent(SrcLocStr, SrcLocStrSize)};
1789
1790 Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_flush), Args);
1791 }
1792
1793 void OpenMPIRBuilder::createFlush(const LocationDescription &Loc) {
1794 if (!updateToLocation(Loc))
1795 return;
1796 emitFlush(Loc);
1797 }
1798
1799 void OpenMPIRBuilder::emitTaskwaitImpl(const LocationDescription &Loc) {
1800 // Build call kmp_int32 __kmpc_omp_taskwait(ident_t *loc, kmp_int32
1801 // global_tid);
1802 uint32_t SrcLocStrSize;
1803 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1804 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1805 Value *Args[] = {Ident, getOrCreateThreadID(Ident)};
1806
1807 // Ignore return result until untied tasks are supported.
1808 Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_taskwait),
1809 Args);
1810 }
1811
1812 void OpenMPIRBuilder::createTaskwait(const LocationDescription &Loc) {
1813 if (!updateToLocation(Loc))
1814 return;
1815 emitTaskwaitImpl(Loc);
1816 }
1817
1818 void OpenMPIRBuilder::emitTaskyieldImpl(const LocationDescription &Loc) {
1819 // Build call __kmpc_omp_taskyield(loc, thread_id, 0);
1820 uint32_t SrcLocStrSize;
1821 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1822 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1823 Constant *I32Null = ConstantInt::getNullValue(Int32);
1824 Value *Args[] = {Ident, getOrCreateThreadID(Ident), I32Null};
1825
1826 Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_taskyield),
1827 Args);
1828 }
1829
1830 void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
1831 if (!updateToLocation(Loc))
1832 return;
1833 emitTaskyieldImpl(Loc);
1834 }
1835
1836 // Processes the dependencies in Dependencies and does the following
1837 // - Allocates space on the stack of an array of DependInfo objects
1838 // - Populates each DependInfo object with relevant information of
1839 // the corresponding dependence.
1840 // - All code is inserted in the entry block of the current function.
1841 static Value *emitTaskDependencies(
1842 OpenMPIRBuilder &OMPBuilder,
1843 const SmallVectorImpl<OpenMPIRBuilder::DependData> &Dependencies) {
1844 // Early return if we have no dependencies to process
1845 if (Dependencies.empty())
1846 return nullptr;
1847
1848 // Given a vector of DependData objects, in this function we create an
1849 // array on the stack that holds kmp_dep_info objects corresponding
1850 // to each dependency. This is then passed to the OpenMP runtime.
1851   // For example, if there are 'n' dependencies then the following pseudo
1852   // code is generated. Assume the first dependence is on a variable 'a'.
1853   //
1854   // \code{c}
1855   // DepArray = alloc(n x sizeof(kmp_depend_info));
1856   // idx = 0;
1857   // DepArray[idx].base_addr = ptrtoint(&a);
1858   // DepArray[idx].len = 8;
1859   // DepArray[idx].flags = Dep.DepKind; /* (See OMPConstants.h for DepKind) */
1860 // ++idx;
1861 // DepArray[idx].base_addr = ...;
1862 // \endcode
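  //
  // Each element written below thus mirrors the runtime's kmp_depend_info
  // layout, roughly { i64 base_addr, i64 len, i8 flags }; see the
  // RTLDependInfoFields indices used in the loop.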
1863
1864 IRBuilderBase &Builder = OMPBuilder.Builder;
1865 Type *DependInfo = OMPBuilder.DependInfo;
1866 Module &M = OMPBuilder.M;
1867
1868 Value *DepArray = nullptr;
1869 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
1870 Builder.SetInsertPoint(
1871 OldIP.getBlock()->getParent()->getEntryBlock().getTerminator());
1872
1873 Type *DepArrayTy = ArrayType::get(DependInfo, Dependencies.size());
1874 DepArray = Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");
1875
1876 Builder.restoreIP(OldIP);
1877
1878 for (const auto &[DepIdx, Dep] : enumerate(Dependencies)) {
1879 Value *Base =
1880 Builder.CreateConstInBoundsGEP2_64(DepArrayTy, DepArray, 0, DepIdx);
1881 // Store the pointer to the variable
1882 Value *Addr = Builder.CreateStructGEP(
1883 DependInfo, Base,
1884 static_cast<unsigned int>(RTLDependInfoFields::BaseAddr));
1885 Value *DepValPtr = Builder.CreatePtrToInt(Dep.DepVal, Builder.getInt64Ty());
1886 Builder.CreateStore(DepValPtr, Addr);
1887 // Store the size of the variable
1888 Value *Size = Builder.CreateStructGEP(
1889 DependInfo, Base, static_cast<unsigned int>(RTLDependInfoFields::Len));
1890 Builder.CreateStore(
1891 Builder.getInt64(M.getDataLayout().getTypeStoreSize(Dep.DepValueType)),
1892 Size);
1893 // Store the dependency kind
1894 Value *Flags = Builder.CreateStructGEP(
1895 DependInfo, Base,
1896 static_cast<unsigned int>(RTLDependInfoFields::Flags));
1897 Builder.CreateStore(
1898 ConstantInt::get(Builder.getInt8Ty(),
1899 static_cast<unsigned int>(Dep.DepKind)),
1900 Flags);
1901 }
1902 return DepArray;
1903 }
1904
1905 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
1906 const LocationDescription &Loc, InsertPointTy AllocaIP,
1907 BodyGenCallbackTy BodyGenCB, bool Tied, Value *Final, Value *IfCondition,
1908 SmallVector<DependData> Dependencies, bool Mergeable, Value *EventHandle,
1909 Value *Priority) {
1910
1911 if (!updateToLocation(Loc))
1912 return InsertPointTy();
1913
1914 uint32_t SrcLocStrSize;
1915 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1916 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1917 // The current basic block is split into four basic blocks. After outlining,
1918 // they will be mapped as follows:
1919 // ```
1920 // def current_fn() {
1921 // current_basic_block:
1922 // br label %task.exit
1923 // task.exit:
1924 // ; instructions after task
1925 // }
1926 // def outlined_fn() {
1927 // task.alloca:
1928 // br label %task.body
1929 // task.body:
1930 // ret void
1931 // }
1932 // ```
1933 BasicBlock *TaskExitBB = splitBB(Builder, /*CreateBranch=*/true, "task.exit");
1934 BasicBlock *TaskBodyBB = splitBB(Builder, /*CreateBranch=*/true, "task.body");
1935 BasicBlock *TaskAllocaBB =
1936 splitBB(Builder, /*CreateBranch=*/true, "task.alloca");
1937
1938 InsertPointTy TaskAllocaIP =
1939 InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
1940 InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
1941 if (Error Err = BodyGenCB(TaskAllocaIP, TaskBodyIP))
1942 return Err;
1943
1944 OutlineInfo OI;
1945 OI.EntryBB = TaskAllocaBB;
1946 OI.OuterAllocaBB = AllocaIP.getBlock();
1947 OI.ExitBB = TaskExitBB;
1948
1949 // Add the thread ID argument.
1950 SmallVector<Instruction *, 4> ToBeDeleted;
1951 OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
1952 Builder, AllocaIP, ToBeDeleted, TaskAllocaIP, "global.tid", false));
1953
1954 OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
1955 Mergeable, Priority, EventHandle, TaskAllocaBB,
1956 ToBeDeleted](Function &OutlinedFn) mutable {
1957 // Replace the Stale CI by appropriate RTL function call.
1958 assert(OutlinedFn.hasOneUse() &&
1959 "there must be a single user for the outlined function");
1960 CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
1961
1962 // HasShareds is true if any variables are captured in the outlined region,
1963 // false otherwise.
1964 bool HasShareds = StaleCI->arg_size() > 1;
1965 Builder.SetInsertPoint(StaleCI);
1966
1967 // Gather the arguments for emitting the runtime call for
1968 // @__kmpc_omp_task_alloc
1969 Function *TaskAllocFn =
1970 getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);
1971
1972     // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID) for the task
1973     // alloc call.
1974 Value *ThreadID = getOrCreateThreadID(Ident);
1975
1976 // Argument - `flags`
1977 // Task is tied iff (Flags & 1) == 1.
1978 // Task is untied iff (Flags & 1) == 0.
1979 // Task is final iff (Flags & 2) == 2.
1980 // Task is not final iff (Flags & 2) == 0.
1981 // Task is mergeable iff (Flags & 4) == 4.
1982 // Task is not mergeable iff (Flags & 4) == 0.
1983 // Task is priority iff (Flags & 32) == 32.
1984 // Task is not priority iff (Flags & 32) == 0.
1985 // TODO: Handle the other flags.
1986 Value *Flags = Builder.getInt32(Tied);
1987 if (Final) {
1988 Value *FinalFlag =
1989 Builder.CreateSelect(Final, Builder.getInt32(2), Builder.getInt32(0));
1990 Flags = Builder.CreateOr(FinalFlag, Flags);
1991 }
1992
1993 if (Mergeable)
1994 Flags = Builder.CreateOr(Builder.getInt32(4), Flags);
1995 if (Priority)
1996 Flags = Builder.CreateOr(Builder.getInt32(32), Flags);
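    // For example, a tied (1), mergeable (4) task with a priority (32) ends up
    // with Flags == 37; a `final` expression contributes bit 2 dynamically via
    // the select above.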
1997
1998 // Argument - `sizeof_kmp_task_t` (TaskSize)
1999 // Tasksize refers to the size in bytes of kmp_task_t data structure
2000 // including private vars accessed in task.
2001 // TODO: add kmp_task_t_with_privates (privates)
2002 Value *TaskSize = Builder.getInt64(
2003 divideCeil(M.getDataLayout().getTypeSizeInBits(Task), 8));
2004
2005 // Argument - `sizeof_shareds` (SharedsSize)
2006 // SharedsSize refers to the shareds array size in the kmp_task_t data
2007 // structure.
2008 Value *SharedsSize = Builder.getInt64(0);
2009 if (HasShareds) {
2010 AllocaInst *ArgStructAlloca =
2011 dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
2012 assert(ArgStructAlloca &&
2013 "Unable to find the alloca instruction corresponding to arguments "
2014 "for extracted function");
2015 StructType *ArgStructType =
2016 dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
2017 assert(ArgStructType && "Unable to find struct type corresponding to "
2018 "arguments for extracted function");
2019 SharedsSize =
2020 Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
2021 }
2022 // Emit the @__kmpc_omp_task_alloc runtime call
2023 // The runtime call returns a pointer to an area where the task captured
2024 // variables must be copied before the task is run (TaskData)
2025 CallInst *TaskData = Builder.CreateCall(
2026 TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
2027 /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
2028 /*task_func=*/&OutlinedFn});
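    // For illustration (the sizes are ad hoc), the allocation is roughly:
    //
    //   %task = call ptr @__kmpc_omp_task_alloc(ptr %ident, i32 %gtid,
    //               i32 %flags, i64 40, i64 %sizeof_shareds, ptr @outlined)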
2029
2030 // Emit detach clause initialization.
2031 // evt = (typeof(evt))__kmpc_task_allow_completion_event(loc, tid,
2032 // task_descriptor);
2033 if (EventHandle) {
2034 Function *TaskDetachFn = getOrCreateRuntimeFunctionPtr(
2035 OMPRTL___kmpc_task_allow_completion_event);
2036 llvm::Value *EventVal =
2037 Builder.CreateCall(TaskDetachFn, {Ident, ThreadID, TaskData});
2038 llvm::Value *EventHandleAddr =
2039 Builder.CreatePointerBitCastOrAddrSpaceCast(EventHandle,
2040 Builder.getPtrTy(0));
2041 EventVal = Builder.CreatePtrToInt(EventVal, Builder.getInt64Ty());
2042 Builder.CreateStore(EventVal, EventHandleAddr);
2043 }
2044     // Copy the arguments for the outlined function.
2045 if (HasShareds) {
2046 Value *Shareds = StaleCI->getArgOperand(1);
2047 Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
2048 Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData);
2049 Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
2050 SharedsSize);
2051 }
2052
2053 if (Priority) {
2054 //
2055 // The return type of "__kmpc_omp_task_alloc" is "kmp_task_t *",
2056 // we populate the priority information into the "kmp_task_t" here
2057 //
2058 // The struct "kmp_task_t" definition is available in kmp.h
2059 // kmp_task_t = { shareds, routine, part_id, data1, data2 }
2060 // data2 is used for priority
2061 //
2062 Type *Int32Ty = Builder.getInt32Ty();
2063 Constant *Zero = ConstantInt::get(Int32Ty, 0);
2064 // kmp_task_t* => { ptr }
2065 Type *TaskPtr = StructType::get(VoidPtr);
2066 Value *TaskGEP =
2067 Builder.CreateInBoundsGEP(TaskPtr, TaskData, {Zero, Zero});
2068 // kmp_task_t => { ptr, ptr, i32, ptr, ptr }
2069 Type *TaskStructType = StructType::get(
2070 VoidPtr, VoidPtr, Builder.getInt32Ty(), VoidPtr, VoidPtr);
2071 Value *PriorityData = Builder.CreateInBoundsGEP(
2072 TaskStructType, TaskGEP, {Zero, ConstantInt::get(Int32Ty, 4)});
2073 // kmp_cmplrdata_t => { ptr, ptr }
2074 Type *CmplrStructType = StructType::get(VoidPtr, VoidPtr);
2075 Value *CmplrData = Builder.CreateInBoundsGEP(CmplrStructType,
2076 PriorityData, {Zero, Zero});
2077 Builder.CreateStore(Priority, CmplrData);
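      // That is, with the layouts assumed above, Priority is stored into
      // ((kmp_task_t *)TaskData)->data2 (field index 4 of kmp_task_t).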
2078 }
2079
2080 Value *DepArray = emitTaskDependencies(*this, Dependencies);
2081
2082 // In the presence of the `if` clause, the following IR is generated:
2083 // ...
2084 // %data = call @__kmpc_omp_task_alloc(...)
2085 // br i1 %if_condition, label %then, label %else
2086 // then:
2087 // call @__kmpc_omp_task(...)
2088 // br label %exit
2089 // else:
2090 // ;; Wait for resolution of dependencies, if any, before
2091 // ;; beginning the task
2092 // call @__kmpc_omp_wait_deps(...)
2093 // call @__kmpc_omp_task_begin_if0(...)
2094 // call @outlined_fn(...)
2095 // call @__kmpc_omp_task_complete_if0(...)
2096 // br label %exit
2097 // exit:
2098 // ...
2099 if (IfCondition) {
2100 // `SplitBlockAndInsertIfThenElse` requires the block to have a
2101 // terminator.
2102 splitBB(Builder, /*CreateBranch=*/true, "if.end");
2103 Instruction *IfTerminator =
2104 Builder.GetInsertPoint()->getParent()->getTerminator();
2105 Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
2106 Builder.SetInsertPoint(IfTerminator);
2107 SplitBlockAndInsertIfThenElse(IfCondition, IfTerminator, &ThenTI,
2108 &ElseTI);
2109 Builder.SetInsertPoint(ElseTI);
2110
2111 if (Dependencies.size()) {
2112 Function *TaskWaitFn =
2113 getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_wait_deps);
2114 Builder.CreateCall(
2115 TaskWaitFn,
2116 {Ident, ThreadID, Builder.getInt32(Dependencies.size()), DepArray,
2117 ConstantInt::get(Builder.getInt32Ty(), 0),
2118 ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
2119 }
2120 Function *TaskBeginFn =
2121 getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0);
2122 Function *TaskCompleteFn =
2123 getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
2124 Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
2125 CallInst *CI = nullptr;
2126 if (HasShareds)
2127 CI = Builder.CreateCall(&OutlinedFn, {ThreadID, TaskData});
2128 else
2129 CI = Builder.CreateCall(&OutlinedFn, {ThreadID});
2130 CI->setDebugLoc(StaleCI->getDebugLoc());
2131 Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
2132 Builder.SetInsertPoint(ThenTI);
2133 }
2134
2135 if (Dependencies.size()) {
2136 Function *TaskFn =
2137 getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
2138 Builder.CreateCall(
2139 TaskFn,
2140 {Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()),
2141 DepArray, ConstantInt::get(Builder.getInt32Ty(), 0),
2142 ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
2143
2144 } else {
2145 // Emit the @__kmpc_omp_task runtime call to spawn the task
2146 Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
2147 Builder.CreateCall(TaskFn, {Ident, ThreadID, TaskData});
2148 }
2149
2150 StaleCI->eraseFromParent();
2151
2152 Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin());
2153 if (HasShareds) {
2154 LoadInst *Shareds = Builder.CreateLoad(VoidPtr, OutlinedFn.getArg(1));
2155 OutlinedFn.getArg(1)->replaceUsesWithIf(
2156 Shareds, [Shareds](Use &U) { return U.getUser() != Shareds; });
2157 }
2158
2159 for (Instruction *I : llvm::reverse(ToBeDeleted))
2160 I->eraseFromParent();
2161 };
2162
2163 addOutlineInfo(std::move(OI));
2164 Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin());
2165
2166 return Builder.saveIP();
2167 }
2168
2169 OpenMPIRBuilder::InsertPointOrErrorTy
2170 OpenMPIRBuilder::createTaskgroup(const LocationDescription &Loc,
2171 InsertPointTy AllocaIP,
2172 BodyGenCallbackTy BodyGenCB) {
2173 if (!updateToLocation(Loc))
2174 return InsertPointTy();
2175
2176 uint32_t SrcLocStrSize;
2177 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
2178 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2179 Value *ThreadID = getOrCreateThreadID(Ident);
2180
2181 // Emit the @__kmpc_taskgroup runtime call to start the taskgroup
2182 Function *TaskgroupFn =
2183 getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_taskgroup);
2184 Builder.CreateCall(TaskgroupFn, {Ident, ThreadID});
2185
2186 BasicBlock *TaskgroupExitBB = splitBB(Builder, true, "taskgroup.exit");
2187 if (Error Err = BodyGenCB(AllocaIP, Builder.saveIP()))
2188 return Err;
2189
2190 Builder.SetInsertPoint(TaskgroupExitBB);
2191 // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup
2192 Function *EndTaskgroupFn =
2193 getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_taskgroup);
2194 Builder.CreateCall(EndTaskgroupFn, {Ident, ThreadID});
2195
2196 return Builder.saveIP();
2197 }
2198
2199 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createSections(
2200 const LocationDescription &Loc, InsertPointTy AllocaIP,
2201 ArrayRef<StorableBodyGenCallbackTy> SectionCBs, PrivatizeCallbackTy PrivCB,
2202 FinalizeCallbackTy FiniCB, bool IsCancellable, bool IsNowait) {
2203 assert(!isConflictIP(AllocaIP, Loc.IP) && "Dedicated IP allocas required");
2204
2205 if (!updateToLocation(Loc))
2206 return Loc.IP;
2207
2208   // FiniCBWrapper needs to create a branch to the loop finalization block, but
2209   // that block may not have been created yet when this callback runs.
2210 SmallVector<BranchInst *> CancellationBranches;
2211 auto FiniCBWrapper = [&](InsertPointTy IP) {
2212 if (IP.getBlock()->end() != IP.getPoint())
2213 return FiniCB(IP);
2214     // This must be done; otherwise, any nested constructs using
2215     // FinalizeOMPRegion will fail because that function requires the
2216     // finalization basic block to have a terminator, which was already removed
2217     // by EmitOMPRegionBody. IP is currently at the cancellation block.
2218 BranchInst *DummyBranch = Builder.CreateBr(IP.getBlock());
2219 IP = InsertPointTy(DummyBranch->getParent(), DummyBranch->getIterator());
2220 CancellationBranches.push_back(DummyBranch);
2221 return FiniCB(IP);
2222 };
2223
2224 FinalizationStack.push_back({FiniCBWrapper, OMPD_sections, IsCancellable});
2225
2226 // Each section is emitted as a switch case
2227 // Each finalization callback is handled from clang.EmitOMPSectionDirective()
2228 // -> OMP.createSection() which generates the IR for each section
2229 // Iterate through all sections and emit a switch construct:
2230 // switch (IV) {
2231 // case 0:
2232 // <SectionStmt[0]>;
2233 // break;
2234 // ...
2235 // case <NumSection> - 1:
2236 // <SectionStmt[<NumSection> - 1]>;
2237 // break;
2238 // }
2239 // ...
2240 // section_loop.after:
2241 // <FiniCB>;
2242 auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, Value *IndVar) -> Error {
2243 Builder.restoreIP(CodeGenIP);
2244 BasicBlock *Continue =
2245 splitBBWithSuffix(Builder, /*CreateBranch=*/false, ".sections.after");
2246 Function *CurFn = Continue->getParent();
2247 SwitchInst *SwitchStmt = Builder.CreateSwitch(IndVar, Continue);
2248
2249 unsigned CaseNumber = 0;
2250 for (auto SectionCB : SectionCBs) {
2251 BasicBlock *CaseBB = BasicBlock::Create(
2252 M.getContext(), "omp_section_loop.body.case", CurFn, Continue);
2253 SwitchStmt->addCase(Builder.getInt32(CaseNumber), CaseBB);
2254 Builder.SetInsertPoint(CaseBB);
2255 BranchInst *CaseEndBr = Builder.CreateBr(Continue);
2256 if (Error Err = SectionCB(InsertPointTy(), {CaseEndBr->getParent(),
2257 CaseEndBr->getIterator()}))
2258 return Err;
2259 CaseNumber++;
2260 }
2261     // Remove the existing terminator from the body BB since there can be no
2262     // terminators after a switch/case.
2263 return Error::success();
2264 };
2265   // Loop body ends here.
2266   // LowerBound, UpperBound, and Stride for createCanonicalLoop.
2267 Type *I32Ty = Type::getInt32Ty(M.getContext());
2268 Value *LB = ConstantInt::get(I32Ty, 0);
2269 Value *UB = ConstantInt::get(I32Ty, SectionCBs.size());
2270 Value *ST = ConstantInt::get(I32Ty, 1);
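  // E.g., for three sections this creates a canonical loop over IV = 0, 1, 2,
  // and each IV value selects one switch case emitted above.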
2271 Expected<CanonicalLoopInfo *> LoopInfo = createCanonicalLoop(
2272 Loc, LoopBodyGenCB, LB, UB, ST, true, false, AllocaIP, "section_loop");
2273 if (!LoopInfo)
2274 return LoopInfo.takeError();
2275
2276 InsertPointOrErrorTy WsloopIP =
2277 applyStaticWorkshareLoop(Loc.DL, *LoopInfo, AllocaIP,
2278 WorksharingLoopType::ForStaticLoop, !IsNowait);
2279 if (!WsloopIP)
2280 return WsloopIP.takeError();
2281 InsertPointTy AfterIP = *WsloopIP;
2282
2283 BasicBlock *LoopFini = AfterIP.getBlock()->getSinglePredecessor();
2284 assert(LoopFini && "Bad structure of static workshare loop finalization");
2285
2286 // Apply the finalization callback in LoopAfterBB
2287 auto FiniInfo = FinalizationStack.pop_back_val();
2288 assert(FiniInfo.DK == OMPD_sections &&
2289 "Unexpected finalization stack state!");
2290 if (FinalizeCallbackTy &CB = FiniInfo.FiniCB) {
2291 Builder.restoreIP(AfterIP);
2292 BasicBlock *FiniBB =
2293 splitBBWithSuffix(Builder, /*CreateBranch=*/true, "sections.fini");
2294 if (Error Err = CB(Builder.saveIP()))
2295 return Err;
2296 AfterIP = {FiniBB, FiniBB->begin()};
2297 }
2298
2299 // Now we can fix the dummy branch to point to the right place
2300 for (BranchInst *DummyBranch : CancellationBranches) {
2301 assert(DummyBranch->getNumSuccessors() == 1);
2302 DummyBranch->setSuccessor(0, LoopFini);
2303 }
2304
2305 return AfterIP;
2306 }
2307
2308 OpenMPIRBuilder::InsertPointOrErrorTy
2309 OpenMPIRBuilder::createSection(const LocationDescription &Loc,
2310 BodyGenCallbackTy BodyGenCB,
2311 FinalizeCallbackTy FiniCB) {
2312 if (!updateToLocation(Loc))
2313 return Loc.IP;
2314
2315 auto FiniCBWrapper = [&](InsertPointTy IP) {
2316 if (IP.getBlock()->end() != IP.getPoint())
2317 return FiniCB(IP);
2318     // This must be done; otherwise, any nested constructs using
2319     // FinalizeOMPRegion will fail because that function requires the
2320     // finalization basic block to have a terminator, which was already
2321     // removed by EmitOMPRegionBody. IP is currently at the cancellation
2322     // block. We need to backtrack to the condition block to fetch the exit
2323     // block and create a branch from the cancellation block to the exit
2324     // block.
2325 IRBuilder<>::InsertPointGuard IPG(Builder);
2326 Builder.restoreIP(IP);
2327 auto *CaseBB = Loc.IP.getBlock();
2328 auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
2329 auto *ExitBB = CondBB->getTerminator()->getSuccessor(1);
2330 Instruction *I = Builder.CreateBr(ExitBB);
2331 IP = InsertPointTy(I->getParent(), I->getIterator());
2332 return FiniCB(IP);
2333 };
2334
2335 Directive OMPD = Directive::OMPD_sections;
2336 // Since we are using Finalization Callback here, HasFinalize
2337 // and IsCancellable have to be true
2338 return EmitOMPInlinedRegion(OMPD, nullptr, nullptr, BodyGenCB, FiniCBWrapper,
2339 /*Conditional*/ false, /*hasFinalize*/ true,
2340 /*IsCancellable*/ true);
2341 }
2342
2343 static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
2344 BasicBlock::iterator IT(I);
2345 IT++;
2346 return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT);
2347 }
2348
2349 Value *OpenMPIRBuilder::getGPUThreadID() {
2350 return Builder.CreateCall(
2351 getOrCreateRuntimeFunction(M,
2352 OMPRTL___kmpc_get_hardware_thread_id_in_block),
2353 {});
2354 }
2355
2356 Value *OpenMPIRBuilder::getGPUWarpSize() {
2357 return Builder.CreateCall(
2358 getOrCreateRuntimeFunction(M, OMPRTL___kmpc_get_warp_size), {});
2359 }
2360
2361 Value *OpenMPIRBuilder::getNVPTXWarpID() {
2362 unsigned LaneIDBits = Log2_32(Config.getGridValue().GV_Warp_Size);
2363 return Builder.CreateAShr(getGPUThreadID(), LaneIDBits, "nvptx_warp_id");
2364 }
2365
2366 Value *OpenMPIRBuilder::getNVPTXLaneID() {
2367 unsigned LaneIDBits = Log2_32(Config.getGridValue().GV_Warp_Size);
2368 assert(LaneIDBits < 32 && "Invalid LaneIDBits size in NVPTX device.");
2369 unsigned LaneIDMask = ~0u >> (32u - LaneIDBits);
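  // E.g., with a warp size of 32, LaneIDBits is 5 and LaneIDMask is 0x1f, so
  // lane_id = tid & 31 (getNVPTXWarpID above computes warp_id = tid >> 5).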
2370 return Builder.CreateAnd(getGPUThreadID(), Builder.getInt32(LaneIDMask),
2371 "nvptx_lane_id");
2372 }
2373
2374 Value *OpenMPIRBuilder::castValueToType(InsertPointTy AllocaIP, Value *From,
2375 Type *ToType) {
2376 Type *FromType = From->getType();
2377 uint64_t FromSize = M.getDataLayout().getTypeStoreSize(FromType);
2378 uint64_t ToSize = M.getDataLayout().getTypeStoreSize(ToType);
2379 assert(FromSize > 0 && "From size must be greater than zero");
2380 assert(ToSize > 0 && "To size must be greater than zero");
2381 if (FromType == ToType)
2382 return From;
2383 if (FromSize == ToSize)
2384 return Builder.CreateBitCast(From, ToType);
2385 if (ToType->isIntegerTy() && FromType->isIntegerTy())
2386 return Builder.CreateIntCast(From, ToType, /*isSigned*/ true);
2387 InsertPointTy SaveIP = Builder.saveIP();
2388 Builder.restoreIP(AllocaIP);
2389 Value *CastItem = Builder.CreateAlloca(ToType);
2390 Builder.restoreIP(SaveIP);
2391
2392 Value *ValCastItem = Builder.CreatePointerBitCastOrAddrSpaceCast(
2393 CastItem, Builder.getPtrTy(0));
2394 Builder.CreateStore(From, ValCastItem);
2395 return Builder.CreateLoad(ToType, CastItem);
2396 }
2397
2398 Value *OpenMPIRBuilder::createRuntimeShuffleFunction(InsertPointTy AllocaIP,
2399 Value *Element,
2400 Type *ElementType,
2401 Value *Offset) {
2402 uint64_t Size = M.getDataLayout().getTypeStoreSize(ElementType);
2403 assert(Size <= 8 && "Unsupported bitwidth in shuffle instruction");
2404
2405 // Cast all types to 32- or 64-bit values before calling shuffle routines.
2406 Type *CastTy = Builder.getIntNTy(Size <= 4 ? 32 : 64);
2407 Value *ElemCast = castValueToType(AllocaIP, Element, CastTy);
2408 Value *WarpSize =
2409 Builder.CreateIntCast(getGPUWarpSize(), Builder.getInt16Ty(), true);
2410 Function *ShuffleFunc = getOrCreateRuntimeFunctionPtr(
2411 Size <= 4 ? RuntimeFunction::OMPRTL___kmpc_shuffle_int32
2412 : RuntimeFunction::OMPRTL___kmpc_shuffle_int64);
2413 Value *WarpSizeCast =
2414 Builder.CreateIntCast(WarpSize, Builder.getInt16Ty(), /*isSigned=*/true);
2415 Value *ShuffleCall =
2416 Builder.CreateCall(ShuffleFunc, {ElemCast, Offset, WarpSizeCast});
2417 return castValueToType(AllocaIP, ShuffleCall, CastTy);
2418 }
2419
2420 void OpenMPIRBuilder::shuffleAndStore(InsertPointTy AllocaIP, Value *SrcAddr,
2421 Value *DstAddr, Type *ElemType,
2422 Value *Offset, Type *ReductionArrayTy) {
2423 uint64_t Size = M.getDataLayout().getTypeStoreSize(ElemType);
2424 // Create the loop over the big sized data.
2425 // ptr = (void*)Elem;
2426 // ptrEnd = (void*) Elem + 1;
2427 // Step = 8;
2428 // while (ptr + Step < ptrEnd)
2429 // shuffle((int64_t)*ptr);
2430 // Step = 4;
2431 // while (ptr + Step < ptrEnd)
2432 // shuffle((int32_t)*ptr);
2433 // ...
2434 Type *IndexTy = Builder.getIndexTy(
2435 M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
2436 Value *ElemPtr = DstAddr;
2437 Value *Ptr = SrcAddr;
2438 for (unsigned IntSize = 8; IntSize >= 1; IntSize /= 2) {
2439 if (Size < IntSize)
2440 continue;
2441 Type *IntType = Builder.getIntNTy(IntSize * 8);
2442 Ptr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2443 Ptr, Builder.getPtrTy(0), Ptr->getName() + ".ascast");
2444 Value *SrcAddrGEP =
2445 Builder.CreateGEP(ElemType, SrcAddr, {ConstantInt::get(IndexTy, 1)});
2446 ElemPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2447 ElemPtr, Builder.getPtrTy(0), ElemPtr->getName() + ".ascast");
2448
2449 Function *CurFunc = Builder.GetInsertBlock()->getParent();
2450 if ((Size / IntSize) > 1) {
2451 Value *PtrEnd = Builder.CreatePointerBitCastOrAddrSpaceCast(
2452 SrcAddrGEP, Builder.getPtrTy());
2453 BasicBlock *PreCondBB =
2454 BasicBlock::Create(M.getContext(), ".shuffle.pre_cond");
2455 BasicBlock *ThenBB = BasicBlock::Create(M.getContext(), ".shuffle.then");
2456 BasicBlock *ExitBB = BasicBlock::Create(M.getContext(), ".shuffle.exit");
2457 BasicBlock *CurrentBB = Builder.GetInsertBlock();
2458 emitBlock(PreCondBB, CurFunc);
2459 PHINode *PhiSrc =
2460 Builder.CreatePHI(Ptr->getType(), /*NumReservedValues=*/2);
2461 PhiSrc->addIncoming(Ptr, CurrentBB);
2462 PHINode *PhiDest =
2463 Builder.CreatePHI(ElemPtr->getType(), /*NumReservedValues=*/2);
2464 PhiDest->addIncoming(ElemPtr, CurrentBB);
2465 Ptr = PhiSrc;
2466 ElemPtr = PhiDest;
2467 Value *PtrDiff = Builder.CreatePtrDiff(
2468 Builder.getInt8Ty(), PtrEnd,
2469 Builder.CreatePointerBitCastOrAddrSpaceCast(Ptr, Builder.getPtrTy()));
2470 Builder.CreateCondBr(
2471 Builder.CreateICmpSGT(PtrDiff, Builder.getInt64(IntSize - 1)), ThenBB,
2472 ExitBB);
2473 emitBlock(ThenBB, CurFunc);
2474 Value *Res = createRuntimeShuffleFunction(
2475 AllocaIP,
2476 Builder.CreateAlignedLoad(
2477 IntType, Ptr, M.getDataLayout().getPrefTypeAlign(ElemType)),
2478 IntType, Offset);
2479 Builder.CreateAlignedStore(Res, ElemPtr,
2480 M.getDataLayout().getPrefTypeAlign(ElemType));
2481 Value *LocalPtr =
2482 Builder.CreateGEP(IntType, Ptr, {ConstantInt::get(IndexTy, 1)});
2483 Value *LocalElemPtr =
2484 Builder.CreateGEP(IntType, ElemPtr, {ConstantInt::get(IndexTy, 1)});
2485 PhiSrc->addIncoming(LocalPtr, ThenBB);
2486 PhiDest->addIncoming(LocalElemPtr, ThenBB);
2487 emitBranch(PreCondBB);
2488 emitBlock(ExitBB, CurFunc);
2489 } else {
2490 Value *Res = createRuntimeShuffleFunction(
2491 AllocaIP, Builder.CreateLoad(IntType, Ptr), IntType, Offset);
2492 if (ElemType->isIntegerTy() && ElemType->getScalarSizeInBits() <
2493 Res->getType()->getScalarSizeInBits())
2494 Res = Builder.CreateTrunc(Res, ElemType);
2495 Builder.CreateStore(Res, ElemPtr);
2496 Ptr = Builder.CreateGEP(IntType, Ptr, {ConstantInt::get(IndexTy, 1)});
2497 ElemPtr =
2498 Builder.CreateGEP(IntType, ElemPtr, {ConstantInt::get(IndexTy, 1)});
2499 }
2500 Size = Size % IntSize;
2501 }
2502 }
2503
2504 void OpenMPIRBuilder::emitReductionListCopy(
2505 InsertPointTy AllocaIP, CopyAction Action, Type *ReductionArrayTy,
2506 ArrayRef<ReductionInfo> ReductionInfos, Value *SrcBase, Value *DestBase,
2507 CopyOptionsTy CopyOptions) {
2508 Type *IndexTy = Builder.getIndexTy(
2509 M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
2510 Value *RemoteLaneOffset = CopyOptions.RemoteLaneOffset;
2511
2512   // Iterate, element by element, through the source Reduce list and
2513   // make a copy.
2514 for (auto En : enumerate(ReductionInfos)) {
2515 const ReductionInfo &RI = En.value();
2516 Value *SrcElementAddr = nullptr;
2517 Value *DestElementAddr = nullptr;
2518 Value *DestElementPtrAddr = nullptr;
2519 // Should we shuffle in an element from a remote lane?
2520 bool ShuffleInElement = false;
2521 // Set to true to update the pointer in the dest Reduce list to a
2522 // newly created element.
2523 bool UpdateDestListPtr = false;
2524
2525 // Step 1.1: Get the address for the src element in the Reduce list.
2526 Value *SrcElementPtrAddr = Builder.CreateInBoundsGEP(
2527 ReductionArrayTy, SrcBase,
2528 {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
2529 SrcElementAddr = Builder.CreateLoad(Builder.getPtrTy(), SrcElementPtrAddr);
2530
2531 // Step 1.2: Create a temporary to store the element in the destination
2532 // Reduce list.
2533 DestElementPtrAddr = Builder.CreateInBoundsGEP(
2534 ReductionArrayTy, DestBase,
2535 {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
2536 switch (Action) {
2537 case CopyAction::RemoteLaneToThread: {
2538 InsertPointTy CurIP = Builder.saveIP();
2539 Builder.restoreIP(AllocaIP);
2540 AllocaInst *DestAlloca = Builder.CreateAlloca(RI.ElementType, nullptr,
2541 ".omp.reduction.element");
2542 DestAlloca->setAlignment(
2543 M.getDataLayout().getPrefTypeAlign(RI.ElementType));
2544 DestElementAddr = DestAlloca;
2545 DestElementAddr =
2546 Builder.CreateAddrSpaceCast(DestElementAddr, Builder.getPtrTy(),
2547 DestElementAddr->getName() + ".ascast");
2548 Builder.restoreIP(CurIP);
2549 ShuffleInElement = true;
2550 UpdateDestListPtr = true;
2551 break;
2552 }
2553 case CopyAction::ThreadCopy: {
2554 DestElementAddr =
2555 Builder.CreateLoad(Builder.getPtrTy(), DestElementPtrAddr);
2556 break;
2557 }
2558 }
2559
2560 // Now that all active lanes have read the element in the
2561 // Reduce list, shuffle over the value from the remote lane.
2562 if (ShuffleInElement) {
2563 shuffleAndStore(AllocaIP, SrcElementAddr, DestElementAddr, RI.ElementType,
2564 RemoteLaneOffset, ReductionArrayTy);
2565 } else {
2566 switch (RI.EvaluationKind) {
2567 case EvalKind::Scalar: {
2568 Value *Elem = Builder.CreateLoad(RI.ElementType, SrcElementAddr);
2569 // Store the source element value to the dest element address.
2570 Builder.CreateStore(Elem, DestElementAddr);
2571 break;
2572 }
2573 case EvalKind::Complex: {
2574 Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
2575 RI.ElementType, SrcElementAddr, 0, 0, ".realp");
2576 Value *SrcReal = Builder.CreateLoad(
2577 RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
2578 Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
2579 RI.ElementType, SrcElementAddr, 0, 1, ".imagp");
2580 Value *SrcImg = Builder.CreateLoad(
2581 RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");
2582
2583 Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
2584 RI.ElementType, DestElementAddr, 0, 0, ".realp");
2585 Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
2586 RI.ElementType, DestElementAddr, 0, 1, ".imagp");
2587 Builder.CreateStore(SrcReal, DestRealPtr);
2588 Builder.CreateStore(SrcImg, DestImgPtr);
2589 break;
2590 }
2591 case EvalKind::Aggregate: {
2592 Value *SizeVal = Builder.getInt64(
2593 M.getDataLayout().getTypeStoreSize(RI.ElementType));
2594 Builder.CreateMemCpy(
2595 DestElementAddr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
2596 SrcElementAddr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
2597 SizeVal, false);
2598 break;
2599 }
2600 };
2601 }
2602
2603 // Step 3.1: Modify reference in dest Reduce list as needed.
2604 // Modifying the reference in Reduce list to point to the newly
2605 // created element. The element is live in the current function
2606 // scope and that of functions it invokes (i.e., reduce_function).
2607 // RemoteReduceData[i] = (void*)&RemoteElem
2608 if (UpdateDestListPtr) {
2609 Value *CastDestAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2610 DestElementAddr, Builder.getPtrTy(),
2611 DestElementAddr->getName() + ".ascast");
2612 Builder.CreateStore(CastDestAddr, DestElementPtrAddr);
2613 }
2614 }
2615 }
2616
2617 Expected<Function *> OpenMPIRBuilder::emitInterWarpCopyFunction(
2618 const LocationDescription &Loc, ArrayRef<ReductionInfo> ReductionInfos,
2619 AttributeList FuncAttrs) {
2620 IRBuilder<>::InsertPointGuard IPG(Builder);
2621 LLVMContext &Ctx = M.getContext();
2622 FunctionType *FuncTy = FunctionType::get(
2623 Builder.getVoidTy(), {Builder.getPtrTy(), Builder.getInt32Ty()},
2624 /* IsVarArg */ false);
2625 Function *WcFunc =
2626 Function::Create(FuncTy, GlobalVariable::InternalLinkage,
2627 "_omp_reduction_inter_warp_copy_func", &M);
2628 WcFunc->setAttributes(FuncAttrs);
2629 WcFunc->addParamAttr(0, Attribute::NoUndef);
2630 WcFunc->addParamAttr(1, Attribute::NoUndef);
2631 BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", WcFunc);
2632 Builder.SetInsertPoint(EntryBB);
2633 Builder.SetCurrentDebugLocation(llvm::DebugLoc());
2634
2635 // ReduceList: thread local Reduce list.
2636 // At the stage of the computation when this function is called, partially
2637 // aggregated values reside in the first lane of every active warp.
2638 Argument *ReduceListArg = WcFunc->getArg(0);
2639 // NumWarps: number of warps active in the parallel region. This could
2640 // be smaller than 32 (max warps in a CTA) for partial block reduction.
2641 Argument *NumWarpsArg = WcFunc->getArg(1);
2642
2643 // This array is used as a medium to transfer, one reduce element at a time,
2644 // the data from the first lane of every warp to lanes in the first warp
2645 // in order to perform the final step of a reduction in a parallel region
2646 // (reduction across warps). The array is placed in NVPTX __shared__ memory
2647 // for reduced latency, as well as to have a distinct copy for concurrently
2648   // executing target regions. The array is declared with weak linkage so
2649   // as to be shared across compilation units.
2650 StringRef TransferMediumName =
2651 "__openmp_nvptx_data_transfer_temporary_storage";
2652 GlobalVariable *TransferMedium = M.getGlobalVariable(TransferMediumName);
2653 unsigned WarpSize = Config.getGridValue().GV_Warp_Size;
2654 ArrayType *ArrayTy = ArrayType::get(Builder.getInt32Ty(), WarpSize);
2655 if (!TransferMedium) {
2656 TransferMedium = new GlobalVariable(
2657 M, ArrayTy, /*isConstant=*/false, GlobalVariable::WeakAnyLinkage,
2658 UndefValue::get(ArrayTy), TransferMediumName,
2659 /*InsertBefore=*/nullptr, GlobalVariable::NotThreadLocal,
2660 /*AddressSpace=*/3);
2661 }
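  // For illustration, with a warp size of 32 the global created above is
  // roughly:
  //
  //   @__openmp_nvptx_data_transfer_temporary_storage =
  //       weak addrspace(3) global [32 x i32] undef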
2662
2663 // Get the CUDA thread id of the current OpenMP thread on the GPU.
2664 Value *GPUThreadID = getGPUThreadID();
2665 // nvptx_lane_id = nvptx_id % warpsize
2666 Value *LaneID = getNVPTXLaneID();
2667 // nvptx_warp_id = nvptx_id / warpsize
2668 Value *WarpID = getNVPTXWarpID();
2669
2670 InsertPointTy AllocaIP =
2671 InsertPointTy(Builder.GetInsertBlock(),
2672 Builder.GetInsertBlock()->getFirstInsertionPt());
2673 Type *Arg0Type = ReduceListArg->getType();
2674 Type *Arg1Type = NumWarpsArg->getType();
2675 Builder.restoreIP(AllocaIP);
2676 AllocaInst *ReduceListAlloca = Builder.CreateAlloca(
2677 Arg0Type, nullptr, ReduceListArg->getName() + ".addr");
2678 AllocaInst *NumWarpsAlloca =
2679 Builder.CreateAlloca(Arg1Type, nullptr, NumWarpsArg->getName() + ".addr");
2680 Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2681 ReduceListAlloca, Arg0Type, ReduceListAlloca->getName() + ".ascast");
2682 Value *NumWarpsAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2683 NumWarpsAlloca, Builder.getPtrTy(0),
2684 NumWarpsAlloca->getName() + ".ascast");
2685 Builder.CreateStore(ReduceListArg, ReduceListAddrCast);
2686 Builder.CreateStore(NumWarpsArg, NumWarpsAddrCast);
2687 AllocaIP = getInsertPointAfterInstr(NumWarpsAlloca);
2688 InsertPointTy CodeGenIP =
2689 getInsertPointAfterInstr(&Builder.GetInsertBlock()->back());
2690 Builder.restoreIP(CodeGenIP);
2691
2692 Value *ReduceList =
2693 Builder.CreateLoad(Builder.getPtrTy(), ReduceListAddrCast);
2694
2695 for (auto En : enumerate(ReductionInfos)) {
2696 //
2697 // Warp master copies reduce element to transfer medium in __shared__
2698 // memory.
2699 //
2700 const ReductionInfo &RI = En.value();
2701 unsigned RealTySize = M.getDataLayout().getTypeAllocSize(RI.ElementType);
2702 for (unsigned TySize = 4; TySize > 0 && RealTySize > 0; TySize /= 2) {
2703 Type *CType = Builder.getIntNTy(TySize * 8);
2704
2705 unsigned NumIters = RealTySize / TySize;
2706 if (NumIters == 0)
2707 continue;
2708 Value *Cnt = nullptr;
2709 Value *CntAddr = nullptr;
2710 BasicBlock *PrecondBB = nullptr;
2711 BasicBlock *ExitBB = nullptr;
2712 if (NumIters > 1) {
2713 CodeGenIP = Builder.saveIP();
2714 Builder.restoreIP(AllocaIP);
2715 CntAddr =
2716 Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, ".cnt.addr");
2717
2718 CntAddr = Builder.CreateAddrSpaceCast(CntAddr, Builder.getPtrTy(),
2719 CntAddr->getName() + ".ascast");
2720 Builder.restoreIP(CodeGenIP);
2721 Builder.CreateStore(Constant::getNullValue(Builder.getInt32Ty()),
2722 CntAddr,
2723 /*Volatile=*/false);
2724 PrecondBB = BasicBlock::Create(Ctx, "precond");
2725 ExitBB = BasicBlock::Create(Ctx, "exit");
2726 BasicBlock *BodyBB = BasicBlock::Create(Ctx, "body");
2727 emitBlock(PrecondBB, Builder.GetInsertBlock()->getParent());
2728 Cnt = Builder.CreateLoad(Builder.getInt32Ty(), CntAddr,
2729 /*Volatile=*/false);
2730 Value *Cmp = Builder.CreateICmpULT(
2731 Cnt, ConstantInt::get(Builder.getInt32Ty(), NumIters));
2732 Builder.CreateCondBr(Cmp, BodyBB, ExitBB);
2733 emitBlock(BodyBB, Builder.GetInsertBlock()->getParent());
2734 }
2735
2736 // kmpc_barrier.
2737 InsertPointOrErrorTy BarrierIP1 =
2738 createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
2739 omp::Directive::OMPD_unknown,
2740 /* ForceSimpleCall */ false,
2741 /* CheckCancelFlag */ true);
2742 if (!BarrierIP1)
2743 return BarrierIP1.takeError();
2744 BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then");
2745 BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else");
2746 BasicBlock *MergeBB = BasicBlock::Create(Ctx, "ifcont");
2747
2748 // if (lane_id == 0)
2749 Value *IsWarpMaster = Builder.CreateIsNull(LaneID, "warp_master");
2750 Builder.CreateCondBr(IsWarpMaster, ThenBB, ElseBB);
2751 emitBlock(ThenBB, Builder.GetInsertBlock()->getParent());
2752
2753 // Reduce element = LocalReduceList[i]
2754 auto *RedListArrayTy =
2755 ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
2756 Type *IndexTy = Builder.getIndexTy(
2757 M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
2758 Value *ElemPtrPtr =
2759 Builder.CreateInBoundsGEP(RedListArrayTy, ReduceList,
2760 {ConstantInt::get(IndexTy, 0),
2761 ConstantInt::get(IndexTy, En.index())});
2762 // elemptr = ((CopyType*)(elemptrptr)) + I
2763 Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
2764 if (NumIters > 1)
2765 ElemPtr = Builder.CreateGEP(Builder.getInt32Ty(), ElemPtr, Cnt);
2766
2767 // Get pointer to location in transfer medium.
2768 // MediumPtr = &medium[warp_id]
2769 Value *MediumPtr = Builder.CreateInBoundsGEP(
2770 ArrayTy, TransferMedium, {Builder.getInt64(0), WarpID});
2771 // elem = *elemptr
2772 //*MediumPtr = elem
2773 Value *Elem = Builder.CreateLoad(CType, ElemPtr);
2774 // Store the source element value to the dest element address.
2775 Builder.CreateStore(Elem, MediumPtr,
2776 /*IsVolatile*/ true);
2777 Builder.CreateBr(MergeBB);
2778
2779 // else
2780 emitBlock(ElseBB, Builder.GetInsertBlock()->getParent());
2781 Builder.CreateBr(MergeBB);
2782
2783 // endif
2784 emitBlock(MergeBB, Builder.GetInsertBlock()->getParent());
2785 InsertPointOrErrorTy BarrierIP2 =
2786 createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
2787 omp::Directive::OMPD_unknown,
2788 /* ForceSimpleCall */ false,
2789 /* CheckCancelFlag */ true);
2790 if (!BarrierIP2)
2791 return BarrierIP2.takeError();
2792
2793 // Warp 0 copies reduce element from transfer medium
2794 BasicBlock *W0ThenBB = BasicBlock::Create(Ctx, "then");
2795 BasicBlock *W0ElseBB = BasicBlock::Create(Ctx, "else");
2796 BasicBlock *W0MergeBB = BasicBlock::Create(Ctx, "ifcont");
2797
2798 Value *NumWarpsVal =
2799 Builder.CreateLoad(Builder.getInt32Ty(), NumWarpsAddrCast);
2800 // Up to 32 threads in warp 0 are active.
2801 Value *IsActiveThread =
2802 Builder.CreateICmpULT(GPUThreadID, NumWarpsVal, "is_active_thread");
2803 Builder.CreateCondBr(IsActiveThread, W0ThenBB, W0ElseBB);
2804
2805 emitBlock(W0ThenBB, Builder.GetInsertBlock()->getParent());
2806
2807 // SrcMediumPtr = &medium[tid]
2808 // SrcMediumVal = *SrcMediumPtr
2809 Value *SrcMediumPtrVal = Builder.CreateInBoundsGEP(
2810 ArrayTy, TransferMedium, {Builder.getInt64(0), GPUThreadID});
2811 // TargetElemPtr = (CopyType*)(SrcDataAddr[i]) + I
2812 Value *TargetElemPtrPtr =
2813 Builder.CreateInBoundsGEP(RedListArrayTy, ReduceList,
2814 {ConstantInt::get(IndexTy, 0),
2815 ConstantInt::get(IndexTy, En.index())});
2816 Value *TargetElemPtrVal =
2817 Builder.CreateLoad(Builder.getPtrTy(), TargetElemPtrPtr);
2818 Value *TargetElemPtr = TargetElemPtrVal;
2819 if (NumIters > 1)
2820 TargetElemPtr =
2821 Builder.CreateGEP(Builder.getInt32Ty(), TargetElemPtr, Cnt);
2822
2823 // *TargetElemPtr = SrcMediumVal;
2824 Value *SrcMediumValue =
2825 Builder.CreateLoad(CType, SrcMediumPtrVal, /*IsVolatile*/ true);
2826 Builder.CreateStore(SrcMediumValue, TargetElemPtr);
2827 Builder.CreateBr(W0MergeBB);
2828
2829 emitBlock(W0ElseBB, Builder.GetInsertBlock()->getParent());
2830 Builder.CreateBr(W0MergeBB);
2831
2832 emitBlock(W0MergeBB, Builder.GetInsertBlock()->getParent());
2833
2834 if (NumIters > 1) {
2835 Cnt = Builder.CreateNSWAdd(
2836 Cnt, ConstantInt::get(Builder.getInt32Ty(), /*V=*/1));
2837 Builder.CreateStore(Cnt, CntAddr, /*Volatile=*/false);
2838
2839 auto *CurFn = Builder.GetInsertBlock()->getParent();
2840 emitBranch(PrecondBB);
2841 emitBlock(ExitBB, CurFn);
2842 }
2843 RealTySize %= TySize;
2844 }
2845 }
2846
2847 Builder.CreateRetVoid();
2848
2849 return WcFunc;
2850 }
2851
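/// Emit "_omp_reduction_shuffle_and_reduce_func", the helper for the
/// intra-warp step of a GPU reduction. The generated function takes
/// (void *ReduceList, int16_t LaneId, int16_t RemoteLaneOffset,
/// int16_t AlgoVer); it copies the reduce list of a remote lane into a
/// stack-local copy and then, depending on AlgoVer (see the predicate
/// description in the body), either reduces the remote values into the local
/// list or overwrites the local list with them.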
2852 Function *OpenMPIRBuilder::emitShuffleAndReduceFunction(
2853 ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
2854 AttributeList FuncAttrs) {
2855 LLVMContext &Ctx = M.getContext();
2856 IRBuilder<>::InsertPointGuard IPG(Builder);
2857 FunctionType *FuncTy =
2858 FunctionType::get(Builder.getVoidTy(),
2859 {Builder.getPtrTy(), Builder.getInt16Ty(),
2860 Builder.getInt16Ty(), Builder.getInt16Ty()},
2861 /* IsVarArg */ false);
2862 Function *SarFunc =
2863 Function::Create(FuncTy, GlobalVariable::InternalLinkage,
2864 "_omp_reduction_shuffle_and_reduce_func", &M);
2865 SarFunc->setAttributes(FuncAttrs);
2866 SarFunc->addParamAttr(0, Attribute::NoUndef);
2867 SarFunc->addParamAttr(1, Attribute::NoUndef);
2868 SarFunc->addParamAttr(2, Attribute::NoUndef);
2869 SarFunc->addParamAttr(3, Attribute::NoUndef);
2870 SarFunc->addParamAttr(1, Attribute::SExt);
2871 SarFunc->addParamAttr(2, Attribute::SExt);
2872 SarFunc->addParamAttr(3, Attribute::SExt);
2873 BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", SarFunc);
2874 Builder.SetInsertPoint(EntryBB);
2875 Builder.SetCurrentDebugLocation(llvm::DebugLoc());
2876
2877 // Thread local Reduce list used to host the values of data to be reduced.
2878 Argument *ReduceListArg = SarFunc->getArg(0);
2879 // Current lane id; could be logical.
2880 Argument *LaneIDArg = SarFunc->getArg(1);
2881 // Offset of the remote source lane relative to the current lane.
2882 Argument *RemoteLaneOffsetArg = SarFunc->getArg(2);
2883 // Algorithm version. This is expected to be known at compile time.
2884 Argument *AlgoVerArg = SarFunc->getArg(3);
2885
2886 Type *ReduceListArgType = ReduceListArg->getType();
2887 Type *LaneIDArgType = LaneIDArg->getType();
2888 Type *LaneIDArgPtrType = Builder.getPtrTy(0);
2889 Value *ReduceListAlloca = Builder.CreateAlloca(
2890 ReduceListArgType, nullptr, ReduceListArg->getName() + ".addr");
2891 Value *LaneIdAlloca = Builder.CreateAlloca(LaneIDArgType, nullptr,
2892 LaneIDArg->getName() + ".addr");
2893 Value *RemoteLaneOffsetAlloca = Builder.CreateAlloca(
2894 LaneIDArgType, nullptr, RemoteLaneOffsetArg->getName() + ".addr");
2895 Value *AlgoVerAlloca = Builder.CreateAlloca(LaneIDArgType, nullptr,
2896 AlgoVerArg->getName() + ".addr");
2897 ArrayType *RedListArrayTy =
2898 ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
2899
2900 // Create a local thread-private variable to host the Reduce list
2901 // from a remote lane.
2902 Instruction *RemoteReductionListAlloca = Builder.CreateAlloca(
2903 RedListArrayTy, nullptr, ".omp.reduction.remote_reduce_list");
2904
2905 Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2906 ReduceListAlloca, ReduceListArgType,
2907 ReduceListAlloca->getName() + ".ascast");
2908 Value *LaneIdAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2909 LaneIdAlloca, LaneIDArgPtrType, LaneIdAlloca->getName() + ".ascast");
2910 Value *RemoteLaneOffsetAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2911 RemoteLaneOffsetAlloca, LaneIDArgPtrType,
2912 RemoteLaneOffsetAlloca->getName() + ".ascast");
2913 Value *AlgoVerAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2914 AlgoVerAlloca, LaneIDArgPtrType, AlgoVerAlloca->getName() + ".ascast");
2915 Value *RemoteListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
2916 RemoteReductionListAlloca, Builder.getPtrTy(),
2917 RemoteReductionListAlloca->getName() + ".ascast");
2918
2919 Builder.CreateStore(ReduceListArg, ReduceListAddrCast);
2920 Builder.CreateStore(LaneIDArg, LaneIdAddrCast);
2921 Builder.CreateStore(RemoteLaneOffsetArg, RemoteLaneOffsetAddrCast);
2922 Builder.CreateStore(AlgoVerArg, AlgoVerAddrCast);
2923
2924 Value *ReduceList = Builder.CreateLoad(ReduceListArgType, ReduceListAddrCast);
2925 Value *LaneId = Builder.CreateLoad(LaneIDArgType, LaneIdAddrCast);
2926 Value *RemoteLaneOffset =
2927 Builder.CreateLoad(LaneIDArgType, RemoteLaneOffsetAddrCast);
2928 Value *AlgoVer = Builder.CreateLoad(LaneIDArgType, AlgoVerAddrCast);
2929
2930 InsertPointTy AllocaIP = getInsertPointAfterInstr(RemoteReductionListAlloca);
2931
2932 // This loop iterates through the list of reduce elements and copies,
2933 // element by element, from a remote lane in the warp to RemoteReduceList,
2934 // hosted on the thread's stack.
2935 emitReductionListCopy(
2936 AllocaIP, CopyAction::RemoteLaneToThread, RedListArrayTy, ReductionInfos,
2937 ReduceList, RemoteListAddrCast, {RemoteLaneOffset, nullptr, nullptr});
2938
2939 // The actions to be performed on the Remote Reduce list are dependent
2940 // on the algorithm version.
2941 //
2942 // if (AlgoVer==0) || (AlgoVer==1 && (LaneId < Offset)) || (AlgoVer==2 &&
2943 // LaneId % 2 == 0 && Offset > 0):
2944 // do the reduction value aggregation
2945 //
2946 // The thread local variable Reduce list is mutated in place to host the
2947 // reduced data, which is the aggregated value produced from local and
2948 // remote lanes.
2949 //
2950 // Note that AlgoVer is expected to be a constant integer known at compile
2951 // time.
2952 // When AlgoVer==0, the first conjunction evaluates to true, making
2953 // the entire predicate true at compile time.
2954 // When AlgoVer==1, only the second part of the second conjunction needs
2955 // to be evaluated at runtime; the other conjunctions evaluate to false
2956 // at compile time.
2957 // When AlgoVer==2, only the second part of the third conjunction needs
2958 // to be evaluated at runtime; the other conjunctions evaluate to false
2959 // at compile time.
2960 Value *CondAlgo0 = Builder.CreateIsNull(AlgoVer);
2961 Value *Algo1 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(1));
2962 Value *LaneComp = Builder.CreateICmpULT(LaneId, RemoteLaneOffset);
2963 Value *CondAlgo1 = Builder.CreateAnd(Algo1, LaneComp);
2964 Value *Algo2 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(2));
2965 Value *LaneIdAnd1 = Builder.CreateAnd(LaneId, Builder.getInt16(1));
2966 Value *LaneIdComp = Builder.CreateIsNull(LaneIdAnd1);
2967 Value *Algo2AndLaneIdComp = Builder.CreateAnd(Algo2, LaneIdComp);
2968 Value *RemoteOffsetComp =
2969 Builder.CreateICmpSGT(RemoteLaneOffset, Builder.getInt16(0));
2970 Value *CondAlgo2 = Builder.CreateAnd(Algo2AndLaneIdComp, RemoteOffsetComp);
2971 Value *CA0OrCA1 = Builder.CreateOr(CondAlgo0, CondAlgo1);
2972 Value *CondReduce = Builder.CreateOr(CA0OrCA1, CondAlgo2);
2973
2974 BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then");
2975 BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else");
2976 BasicBlock *MergeBB = BasicBlock::Create(Ctx, "ifcont");
2977
2978 Builder.CreateCondBr(CondReduce, ThenBB, ElseBB);
2979 emitBlock(ThenBB, Builder.GetInsertBlock()->getParent());
2980 Value *LocalReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2981 ReduceList, Builder.getPtrTy());
2982 Value *RemoteReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
2983 RemoteListAddrCast, Builder.getPtrTy());
2984 Builder.CreateCall(ReduceFn, {LocalReduceListPtr, RemoteReduceListPtr})
2985 ->addFnAttr(Attribute::NoUnwind);
2986 Builder.CreateBr(MergeBB);
2987
2988 emitBlock(ElseBB, Builder.GetInsertBlock()->getParent());
2989 Builder.CreateBr(MergeBB);
2990
2991 emitBlock(MergeBB, Builder.GetInsertBlock()->getParent());
2992
2993 // if (AlgoVer==1 && (LaneId >= Offset)) copy Remote Reduce list to local
2994 // Reduce list.
2995 Algo1 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(1));
2996 Value *LaneIdGtOffset = Builder.CreateICmpUGE(LaneId, RemoteLaneOffset);
2997 Value *CondCopy = Builder.CreateAnd(Algo1, LaneIdGtOffset);
2998
2999 BasicBlock *CpyThenBB = BasicBlock::Create(Ctx, "then");
3000 BasicBlock *CpyElseBB = BasicBlock::Create(Ctx, "else");
3001 BasicBlock *CpyMergeBB = BasicBlock::Create(Ctx, "ifcont");
3002 Builder.CreateCondBr(CondCopy, CpyThenBB, CpyElseBB);
3003
3004 emitBlock(CpyThenBB, Builder.GetInsertBlock()->getParent());
3005 emitReductionListCopy(AllocaIP, CopyAction::ThreadCopy, RedListArrayTy,
3006 ReductionInfos, RemoteListAddrCast, ReduceList);
3007 Builder.CreateBr(CpyMergeBB);
3008
3009 emitBlock(CpyElseBB, Builder.GetInsertBlock()->getParent());
3010 Builder.CreateBr(CpyMergeBB);
3011
3012 emitBlock(CpyMergeBB, Builder.GetInsertBlock()->getParent());
3013
3014 Builder.CreateRetVoid();
3015
3016 return SarFunc;
3017 }
3018
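/// Emit "_omp_reduction_list_to_global_copy_func", which copies a
/// thread-local reduce list into the global team-reduction buffer, i.e.
/// Buffer[Idx].Elem_i = *ReduceList[i] for every reduction element, with
/// scalar, complex, and aggregate evaluation kinds handled separately.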
3019 Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
3020 ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
3021 AttributeList FuncAttrs) {
3022 IRBuilder<>::InsertPointGuard IPG(Builder);
3023 LLVMContext &Ctx = M.getContext();
3024 FunctionType *FuncTy = FunctionType::get(
3025 Builder.getVoidTy(),
3026 {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
3027 /* IsVarArg */ false);
3028 Function *LtGCFunc =
3029 Function::Create(FuncTy, GlobalVariable::InternalLinkage,
3030 "_omp_reduction_list_to_global_copy_func", &M);
3031 LtGCFunc->setAttributes(FuncAttrs);
3032 LtGCFunc->addParamAttr(0, Attribute::NoUndef);
3033 LtGCFunc->addParamAttr(1, Attribute::NoUndef);
3034 LtGCFunc->addParamAttr(2, Attribute::NoUndef);
3035
3036 BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
3037 Builder.SetInsertPoint(EntryBlock);
3038 Builder.SetCurrentDebugLocation(llvm::DebugLoc());
3039
3040 // Buffer: global reduction buffer.
3041 Argument *BufferArg = LtGCFunc->getArg(0);
3042 // Idx: index of the buffer.
3043 Argument *IdxArg = LtGCFunc->getArg(1);
3044 // ReduceList: thread local Reduce list.
3045 Argument *ReduceListArg = LtGCFunc->getArg(2);
3046
3047 Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
3048 BufferArg->getName() + ".addr");
3049 Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
3050 IdxArg->getName() + ".addr");
3051 Value *ReduceListArgAlloca = Builder.CreateAlloca(
3052 Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
3053 Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3054 BufferArgAlloca, Builder.getPtrTy(),
3055 BufferArgAlloca->getName() + ".ascast");
3056 Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3057 IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
3058 Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3059 ReduceListArgAlloca, Builder.getPtrTy(),
3060 ReduceListArgAlloca->getName() + ".ascast");
3061
3062 Builder.CreateStore(BufferArg, BufferArgAddrCast);
3063 Builder.CreateStore(IdxArg, IdxArgAddrCast);
3064 Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
3065
3066 Value *LocalReduceList =
3067 Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
3068 Value *BufferArgVal =
3069 Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
3070 Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
3071 Type *IndexTy = Builder.getIndexTy(
3072 M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3073 for (auto En : enumerate(ReductionInfos)) {
3074 const ReductionInfo &RI = En.value();
3075 auto *RedListArrayTy =
3076 ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
3077 // Reduce element = LocalReduceList[i]
3078 Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
3079 RedListArrayTy, LocalReduceList,
3080 {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3081 // elemptr = ((CopyType*)(elemptrptr)) + I
3082 Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
3083
3084 // Global = Buffer.VD[Idx];
3085 Value *BufferVD =
3086 Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferArgVal, Idxs);
3087 Value *GlobVal = Builder.CreateConstInBoundsGEP2_32(
3088 ReductionsBufferTy, BufferVD, 0, En.index());
3089
3090 switch (RI.EvaluationKind) {
3091 case EvalKind::Scalar: {
3092 Value *TargetElement = Builder.CreateLoad(RI.ElementType, ElemPtr);
3093 Builder.CreateStore(TargetElement, GlobVal);
3094 break;
3095 }
3096 case EvalKind::Complex: {
3097 Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
3098 RI.ElementType, ElemPtr, 0, 0, ".realp");
3099 Value *SrcReal = Builder.CreateLoad(
3100 RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
3101 Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
3102 RI.ElementType, ElemPtr, 0, 1, ".imagp");
3103 Value *SrcImg = Builder.CreateLoad(
3104 RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");
3105
3106 Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
3107 RI.ElementType, GlobVal, 0, 0, ".realp");
3108 Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
3109 RI.ElementType, GlobVal, 0, 1, ".imagp");
3110 Builder.CreateStore(SrcReal, DestRealPtr);
3111 Builder.CreateStore(SrcImg, DestImgPtr);
3112 break;
3113 }
3114 case EvalKind::Aggregate: {
3115 Value *SizeVal =
3116 Builder.getInt64(M.getDataLayout().getTypeStoreSize(RI.ElementType));
3117 Builder.CreateMemCpy(
3118 GlobVal, M.getDataLayout().getPrefTypeAlign(RI.ElementType), ElemPtr,
3119 M.getDataLayout().getPrefTypeAlign(RI.ElementType), SizeVal, false);
3120 break;
3121 }
3122 }
3123 }
3124
3125 Builder.CreateRetVoid();
3126 return LtGCFunc;
3127 }
3128
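/// Emit "_omp_reduction_list_to_global_reduce_func", which reduces a
/// thread-local reduce list into the global team-reduction buffer. It builds
/// a temporary reduce list whose entries point into Buffer[Idx] and then
/// calls reduce_function(GlobalReduceList, ReduceList).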
3129 Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
3130 ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
3131 Type *ReductionsBufferTy, AttributeList FuncAttrs) {
3132 IRBuilder<>::InsertPointGuard IPG(Builder);
3133 LLVMContext &Ctx = M.getContext();
3134 FunctionType *FuncTy = FunctionType::get(
3135 Builder.getVoidTy(),
3136 {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
3137 /* IsVarArg */ false);
3138 Function *LtGRFunc =
3139 Function::Create(FuncTy, GlobalVariable::InternalLinkage,
3140 "_omp_reduction_list_to_global_reduce_func", &M);
3141 LtGRFunc->setAttributes(FuncAttrs);
3142 LtGRFunc->addParamAttr(0, Attribute::NoUndef);
3143 LtGRFunc->addParamAttr(1, Attribute::NoUndef);
3144 LtGRFunc->addParamAttr(2, Attribute::NoUndef);
3145
3146 BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGRFunc);
3147 Builder.SetInsertPoint(EntryBlock);
3148 Builder.SetCurrentDebugLocation(llvm::DebugLoc());
3149
3150 // Buffer: global reduction buffer.
3151 Argument *BufferArg = LtGRFunc->getArg(0);
3152 // Idx: index of the buffer.
3153 Argument *IdxArg = LtGRFunc->getArg(1);
3154 // ReduceList: thread local Reduce list.
3155 Argument *ReduceListArg = LtGRFunc->getArg(2);
3156
3157 Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
3158 BufferArg->getName() + ".addr");
3159 Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
3160 IdxArg->getName() + ".addr");
3161 Value *ReduceListArgAlloca = Builder.CreateAlloca(
3162 Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
3163 auto *RedListArrayTy =
3164 ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
3165
3166 // 1. Build a list of reduction variables.
3167 // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
3168 Value *LocalReduceList =
3169 Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");
3170
3171 Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3172 BufferArgAlloca, Builder.getPtrTy(),
3173 BufferArgAlloca->getName() + ".ascast");
3174 Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3175 IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
3176 Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3177 ReduceListArgAlloca, Builder.getPtrTy(),
3178 ReduceListArgAlloca->getName() + ".ascast");
3179 Value *LocalReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3180 LocalReduceList, Builder.getPtrTy(),
3181 LocalReduceList->getName() + ".ascast");
3182
3183 Builder.CreateStore(BufferArg, BufferArgAddrCast);
3184 Builder.CreateStore(IdxArg, IdxArgAddrCast);
3185 Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
3186
3187 Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
3188 Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
3189 Type *IndexTy = Builder.getIndexTy(
3190 M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3191 for (auto En : enumerate(ReductionInfos)) {
3192 Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
3193 RedListArrayTy, LocalReduceListAddrCast,
3194 {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3195 Value *BufferVD =
3196 Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
3197 // Global = Buffer.VD[Idx];
3198 Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
3199 ReductionsBufferTy, BufferVD, 0, En.index());
3200 Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
3201 }
3202
3203 // Call reduce_function(GlobalReduceList, ReduceList)
3204 Value *ReduceList =
3205 Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
3206 Builder.CreateCall(ReduceFn, {LocalReduceListAddrCast, ReduceList})
3207 ->addFnAttr(Attribute::NoUnwind);
3208 Builder.CreateRetVoid();
3209 return LtGRFunc;
3210 }
3211
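/// Emit "_omp_reduction_global_to_list_copy_func", the inverse of the
/// list-to-global copy helper: it copies every element of Buffer[Idx] in the
/// global team-reduction buffer back into the thread-local reduce list.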
3212 Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
3213 ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
3214 AttributeList FuncAttrs) {
3215 IRBuilder<>::InsertPointGuard IPG(Builder);
3216 LLVMContext &Ctx = M.getContext();
3217 FunctionType *FuncTy = FunctionType::get(
3218 Builder.getVoidTy(),
3219 {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
3220 /* IsVarArg */ false);
3221 Function *LtGCFunc =
3222 Function::Create(FuncTy, GlobalVariable::InternalLinkage,
3223 "_omp_reduction_global_to_list_copy_func", &M);
3224 LtGCFunc->setAttributes(FuncAttrs);
3225 LtGCFunc->addParamAttr(0, Attribute::NoUndef);
3226 LtGCFunc->addParamAttr(1, Attribute::NoUndef);
3227 LtGCFunc->addParamAttr(2, Attribute::NoUndef);
3228
3229 BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
3230 Builder.SetInsertPoint(EntryBlock);
3231 Builder.SetCurrentDebugLocation(llvm::DebugLoc());
3232
3233 // Buffer: global reduction buffer.
3234 Argument *BufferArg = LtGCFunc->getArg(0);
3235 // Idx: index of the buffer.
3236 Argument *IdxArg = LtGCFunc->getArg(1);
3237 // ReduceList: thread local Reduce list.
3238 Argument *ReduceListArg = LtGCFunc->getArg(2);
3239
3240 Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
3241 BufferArg->getName() + ".addr");
3242 Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
3243 IdxArg->getName() + ".addr");
3244 Value *ReduceListArgAlloca = Builder.CreateAlloca(
3245 Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
3246 Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3247 BufferArgAlloca, Builder.getPtrTy(),
3248 BufferArgAlloca->getName() + ".ascast");
3249 Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3250 IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
3251 Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3252 ReduceListArgAlloca, Builder.getPtrTy(),
3253 ReduceListArgAlloca->getName() + ".ascast");
3254 Builder.CreateStore(BufferArg, BufferArgAddrCast);
3255 Builder.CreateStore(IdxArg, IdxArgAddrCast);
3256 Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
3257
3258 Value *LocalReduceList =
3259 Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
3260 Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
3261 Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
3262 Type *IndexTy = Builder.getIndexTy(
3263 M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3264 for (auto En : enumerate(ReductionInfos)) {
3265 const OpenMPIRBuilder::ReductionInfo &RI = En.value();
3266 auto *RedListArrayTy =
3267 ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
3268 // Reduce element = LocalReduceList[i]
3269 Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
3270 RedListArrayTy, LocalReduceList,
3271 {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3272 // elemptr = ((CopyType*)(elemptrptr)) + I
3273 Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
3274 // Global = Buffer.VD[Idx];
3275 Value *BufferVD =
3276 Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
3277 Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
3278 ReductionsBufferTy, BufferVD, 0, En.index());
3279
3280 switch (RI.EvaluationKind) {
3281 case EvalKind::Scalar: {
3282 Value *TargetElement = Builder.CreateLoad(RI.ElementType, GlobValPtr);
3283 Builder.CreateStore(TargetElement, ElemPtr);
3284 break;
3285 }
3286 case EvalKind::Complex: {
3287 Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
3288 RI.ElementType, GlobValPtr, 0, 0, ".realp");
3289 Value *SrcReal = Builder.CreateLoad(
3290 RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
3291 Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
3292 RI.ElementType, GlobValPtr, 0, 1, ".imagp");
3293 Value *SrcImg = Builder.CreateLoad(
3294 RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");
3295
3296 Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
3297 RI.ElementType, ElemPtr, 0, 0, ".realp");
3298 Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
3299 RI.ElementType, ElemPtr, 0, 1, ".imagp");
3300 Builder.CreateStore(SrcReal, DestRealPtr);
3301 Builder.CreateStore(SrcImg, DestImgPtr);
3302 break;
3303 }
3304 case EvalKind::Aggregate: {
3305 Value *SizeVal =
3306 Builder.getInt64(M.getDataLayout().getTypeStoreSize(RI.ElementType));
3307 Builder.CreateMemCpy(
3308 ElemPtr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
3309 GlobValPtr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
3310 SizeVal, false);
3311 break;
3312 }
3313 }
3314 }
3315
3316 Builder.CreateRetVoid();
3317 return LtGCFunc;
3318 }
3319
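/// Emit "_omp_reduction_global_to_list_reduce_func", which reduces the
/// global team-reduction buffer into a thread-local reduce list. It builds a
/// temporary reduce list whose entries point into Buffer[Idx] and then calls
/// reduce_function(ReduceList, GlobalReduceList).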
3320 Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
3321 ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
3322 Type *ReductionsBufferTy, AttributeList FuncAttrs) {
3323 IRBuilder<>::InsertPointGuard IPG(Builder);
3324 LLVMContext &Ctx = M.getContext();
3325 auto *FuncTy = FunctionType::get(
3326 Builder.getVoidTy(),
3327 {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
3328 /* IsVarArg */ false);
3329 Function *LtGRFunc =
3330 Function::Create(FuncTy, GlobalVariable::InternalLinkage,
3331 "_omp_reduction_global_to_list_reduce_func", &M);
3332 LtGRFunc->setAttributes(FuncAttrs);
3333 LtGRFunc->addParamAttr(0, Attribute::NoUndef);
3334 LtGRFunc->addParamAttr(1, Attribute::NoUndef);
3335 LtGRFunc->addParamAttr(2, Attribute::NoUndef);
3336
3337 BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGRFunc);
3338 Builder.SetInsertPoint(EntryBlock);
3339 Builder.SetCurrentDebugLocation(llvm::DebugLoc());
3340
3341 // Buffer: global reduction buffer.
3342 Argument *BufferArg = LtGRFunc->getArg(0);
3343 // Idx: index of the buffer.
3344 Argument *IdxArg = LtGRFunc->getArg(1);
3345 // ReduceList: thread local Reduce list.
3346 Argument *ReduceListArg = LtGRFunc->getArg(2);
3347
3348 Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
3349 BufferArg->getName() + ".addr");
3350 Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
3351 IdxArg->getName() + ".addr");
3352 Value *ReduceListArgAlloca = Builder.CreateAlloca(
3353 Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
3354 ArrayType *RedListArrayTy =
3355 ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
3356
3357 // 1. Build a list of reduction variables.
3358 // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
3359 Value *LocalReduceList =
3360 Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");
3361
3362 Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3363 BufferArgAlloca, Builder.getPtrTy(),
3364 BufferArgAlloca->getName() + ".ascast");
3365 Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3366 IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
3367 Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3368 ReduceListArgAlloca, Builder.getPtrTy(),
3369 ReduceListArgAlloca->getName() + ".ascast");
3370 Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
3371 LocalReduceList, Builder.getPtrTy(),
3372 LocalReduceList->getName() + ".ascast");
3373
3374 Builder.CreateStore(BufferArg, BufferArgAddrCast);
3375 Builder.CreateStore(IdxArg, IdxArgAddrCast);
3376 Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);
3377
3378 Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
3379 Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
3380 Type *IndexTy = Builder.getIndexTy(
3381 M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3382 for (auto En : enumerate(ReductionInfos)) {
3383 Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
3384 RedListArrayTy, ReductionList,
3385 {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3386 // Global = Buffer.VD[Idx];
3387 Value *BufferVD =
3388 Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
3389 Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
3390 ReductionsBufferTy, BufferVD, 0, En.index());
3391 Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
3392 }
3393
3394 // Call reduce_function(ReduceList, GlobalReduceList)
3395 Value *ReduceList =
3396 Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
3397 Builder.CreateCall(ReduceFn, {ReduceList, ReductionList})
3398 ->addFnAttr(Attribute::NoUnwind);
3399 Builder.CreateRetVoid();
3400 return LtGRFunc;
3401 }
3402
3403 std::string OpenMPIRBuilder::getReductionFuncName(StringRef Name) const {
3404 std::string Suffix =
3405 createPlatformSpecificName({"omp", "reduction", "reduction_func"});
3406 return (Name + Suffix).str();
3407 }
3408
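/// Create the outlined elementwise reduction function used by the GPU
/// reduction helpers. The generated function takes two type-erased pointer
/// arrays (LHS and RHS); for every reduction element it extracts the LHS/RHS
/// pointers and either records them for the Clang callback fixup below or
/// invokes the reduction generator directly and stores the result through
/// the LHS pointer.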
3409 Expected<Function *> OpenMPIRBuilder::createReductionFunction(
3410 StringRef ReducerName, ArrayRef<ReductionInfo> ReductionInfos,
3411 ReductionGenCBKind ReductionGenCBKind, AttributeList FuncAttrs) {
3412 IRBuilder<>::InsertPointGuard IPG(Builder);
3413 auto *FuncTy = FunctionType::get(Builder.getVoidTy(),
3414 {Builder.getPtrTy(), Builder.getPtrTy()},
3415 /* IsVarArg */ false);
3416 std::string Name = getReductionFuncName(ReducerName);
3417 Function *ReductionFunc =
3418 Function::Create(FuncTy, GlobalVariable::InternalLinkage, Name, &M);
3419 ReductionFunc->setAttributes(FuncAttrs);
3420 ReductionFunc->addParamAttr(0, Attribute::NoUndef);
3421 ReductionFunc->addParamAttr(1, Attribute::NoUndef);
3422 BasicBlock *EntryBB =
3423 BasicBlock::Create(M.getContext(), "entry", ReductionFunc);
3424 Builder.SetInsertPoint(EntryBB);
3425 Builder.SetCurrentDebugLocation(llvm::DebugLoc());
3426
3427 // Spill the argument pointers to allocas (normalizing address spaces)
3428 // before extracting the LHS/RHS element pointers.
3429 Value *LHSArrayPtr = nullptr;
3430 Value *RHSArrayPtr = nullptr;
3431 Argument *Arg0 = ReductionFunc->getArg(0);
3432 Argument *Arg1 = ReductionFunc->getArg(1);
3433 Type *Arg0Type = Arg0->getType();
3434 Type *Arg1Type = Arg1->getType();
3435
3436 Value *LHSAlloca =
3437 Builder.CreateAlloca(Arg0Type, nullptr, Arg0->getName() + ".addr");
3438 Value *RHSAlloca =
3439 Builder.CreateAlloca(Arg1Type, nullptr, Arg1->getName() + ".addr");
3440 Value *LHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3441 LHSAlloca, Arg0Type, LHSAlloca->getName() + ".ascast");
3442 Value *RHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
3443 RHSAlloca, Arg1Type, RHSAlloca->getName() + ".ascast");
3444 Builder.CreateStore(Arg0, LHSAddrCast);
3445 Builder.CreateStore(Arg1, RHSAddrCast);
3446 LHSArrayPtr = Builder.CreateLoad(Arg0Type, LHSAddrCast);
3447 RHSArrayPtr = Builder.CreateLoad(Arg1Type, RHSAddrCast);
3448
3449 Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
3450 Type *IndexTy = Builder.getIndexTy(
3451 M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3452 SmallVector<Value *> LHSPtrs, RHSPtrs;
3453 for (auto En : enumerate(ReductionInfos)) {
3454 const ReductionInfo &RI = En.value();
3455 Value *RHSI8PtrPtr = Builder.CreateInBoundsGEP(
3456 RedArrayTy, RHSArrayPtr,
3457 {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3458 Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
3459 Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3460 RHSI8Ptr, RI.PrivateVariable->getType(),
3461 RHSI8Ptr->getName() + ".ascast");
3462
3463 Value *LHSI8PtrPtr = Builder.CreateInBoundsGEP(
3464 RedArrayTy, LHSArrayPtr,
3465 {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3466 Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
3467 Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3468 LHSI8Ptr, RI.Variable->getType(), LHSI8Ptr->getName() + ".ascast");
3469
3470 if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
3471 LHSPtrs.emplace_back(LHSPtr);
3472 RHSPtrs.emplace_back(RHSPtr);
3473 } else {
3474 Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
3475 Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
3476 Value *Reduced;
3477 InsertPointOrErrorTy AfterIP =
3478 RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
3479 if (!AfterIP)
3480 return AfterIP.takeError();
3481 if (!Builder.GetInsertBlock())
3482 return ReductionFunc;
3483 Builder.CreateStore(Reduced, LHSPtr);
3484 }
3485 }
3486
3487 if (ReductionGenCBKind == ReductionGenCBKind::Clang)
3488 for (auto En : enumerate(ReductionInfos)) {
3489 unsigned Index = En.index();
3490 const ReductionInfo &RI = En.value();
3491 Value *LHSFixupPtr, *RHSFixupPtr;
3492 Builder.restoreIP(RI.ReductionGenClang(
3493 Builder.saveIP(), Index, &LHSFixupPtr, &RHSFixupPtr, ReductionFunc));
3494
3495 // Fix the callback code generated to use the correct Values for the LHS
3496 // and RHS
3497 LHSFixupPtr->replaceUsesWithIf(
3498 LHSPtrs[Index], [ReductionFunc](const Use &U) {
3499 return cast<Instruction>(U.getUser())->getParent()->getParent() ==
3500 ReductionFunc;
3501 });
3502 RHSFixupPtr->replaceUsesWithIf(
3503 RHSPtrs[Index], [ReductionFunc](const Use &U) {
3504 return cast<Instruction>(U.getUser())->getParent()->getParent() ==
3505 ReductionFunc;
3506 });
3507 }
3508
3509 Builder.CreateRetVoid();
3510 return ReductionFunc;
3511 }
3512
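/// Assert basic structural invariants of the given reduction descriptors;
/// this compiles to a no-op in release (NDEBUG) builds.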
3513 static void
3514 checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
3515 bool IsGPU) {
3516 for (const OpenMPIRBuilder::ReductionInfo &RI : ReductionInfos) {
3517 (void)RI;
3518 assert(RI.Variable && "expected non-null variable");
3519 assert(RI.PrivateVariable && "expected non-null private variable");
3520 assert((RI.ReductionGen || RI.ReductionGenClang) &&
3521 "expected non-null reduction generator callback");
3522 if (!IsGPU) {
3523 assert(
3524 RI.Variable->getType() == RI.PrivateVariable->getType() &&
3525 "expected variables and their private equivalents to have the same "
3526 "type");
3527 }
3528 assert(RI.Variable->getType()->isPointerTy() &&
3529 "expected variables to be pointers");
3530 }
3531 }
3532
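/// Lower reductions on the device: build the type-erased reduce list, emit
/// the helper functions above, call __kmpc_nvptx_parallel_reduce_nowait_v2
/// (or __kmpc_nvptx_teams_reduce_nowait_v2 for teams reductions), and, when
/// the runtime reports that this thread won the reduction (res == 1),
/// finalize the reduced values.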
3533 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
3534 const LocationDescription &Loc, InsertPointTy AllocaIP,
3535 InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
3536 bool IsNoWait, bool IsTeamsReduction, ReductionGenCBKind ReductionGenCBKind,
3537 std::optional<omp::GV> GridValue, unsigned ReductionBufNum,
3538 Value *SrcLocInfo) {
3539 if (!updateToLocation(Loc))
3540 return InsertPointTy();
3541 Builder.restoreIP(CodeGenIP);
3542 checkReductionInfos(ReductionInfos, /*IsGPU*/ true);
3543 LLVMContext &Ctx = M.getContext();
3544
3545 // Source location for the ident struct
3546 if (!SrcLocInfo) {
3547 uint32_t SrcLocStrSize;
3548 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3549 SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3550 }
3551
3552 if (ReductionInfos.empty())
3553 return Builder.saveIP();
3554
3555 BasicBlock *ContinuationBlock = nullptr;
3556 if (ReductionGenCBKind != ReductionGenCBKind::Clang) {
3557 // Copied code from createReductions
3558 BasicBlock *InsertBlock = Loc.IP.getBlock();
3559 ContinuationBlock =
3560 InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
3561 InsertBlock->getTerminator()->eraseFromParent();
3562 Builder.SetInsertPoint(InsertBlock, InsertBlock->end());
3563 }
3564
3565 Function *CurFunc = Builder.GetInsertBlock()->getParent();
3566 AttributeList FuncAttrs;
3567 AttrBuilder AttrBldr(Ctx);
3568 for (auto Attr : CurFunc->getAttributes().getFnAttrs())
3569 AttrBldr.addAttribute(Attr);
3570 AttrBldr.removeAttribute(Attribute::OptimizeNone);
3571 FuncAttrs = FuncAttrs.addFnAttributes(Ctx, AttrBldr);
3572
3573 CodeGenIP = Builder.saveIP();
3574 Expected<Function *> ReductionResult =
3575 createReductionFunction(Builder.GetInsertBlock()->getParent()->getName(),
3576 ReductionInfos, ReductionGenCBKind, FuncAttrs);
3577 if (!ReductionResult)
3578 return ReductionResult.takeError();
3579 Function *ReductionFunc = *ReductionResult;
3580 Builder.restoreIP(CodeGenIP);
3581
3582 // Set the grid value in the config needed for lowering later on
3583 if (GridValue.has_value())
3584 Config.setGridValue(GridValue.value());
3585 else
3586 Config.setGridValue(getGridValue(T, ReductionFunc));
3587
3588 // Build res = __kmpc_reduce{_nowait}(<gtid>, <n>, sizeof(RedList),
3589 // RedList, shuffle_reduce_func, interwarp_copy_func);
3590 // or
3591 // Build res = __kmpc_reduce_teams_nowait_simple(<loc>, <gtid>, <lck>);
3592 Value *Res;
3593
3594 // 1. Build a list of reduction variables.
3595 // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
3596 auto Size = ReductionInfos.size();
3597 Type *PtrTy = PointerType::getUnqual(Ctx);
3598 Type *RedArrayTy = ArrayType::get(PtrTy, Size);
3599 CodeGenIP = Builder.saveIP();
3600 Builder.restoreIP(AllocaIP);
3601 Value *ReductionListAlloca =
3602 Builder.CreateAlloca(RedArrayTy, nullptr, ".omp.reduction.red_list");
3603 Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
3604 ReductionListAlloca, PtrTy, ReductionListAlloca->getName() + ".ascast");
3605 Builder.restoreIP(CodeGenIP);
3606 Type *IndexTy = Builder.getIndexTy(
3607 M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3608 for (auto En : enumerate(ReductionInfos)) {
3609 const ReductionInfo &RI = En.value();
3610 Value *ElemPtr = Builder.CreateInBoundsGEP(
3611 RedArrayTy, ReductionList,
3612 {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
3613 Value *CastElem =
3614 Builder.CreatePointerBitCastOrAddrSpaceCast(RI.PrivateVariable, PtrTy);
3615 Builder.CreateStore(CastElem, ElemPtr);
3616 }
3617 CodeGenIP = Builder.saveIP();
3618 Function *SarFunc =
3619 emitShuffleAndReduceFunction(ReductionInfos, ReductionFunc, FuncAttrs);
3620 Expected<Function *> CopyResult =
3621 emitInterWarpCopyFunction(Loc, ReductionInfos, FuncAttrs);
3622 if (!CopyResult)
3623 return CopyResult.takeError();
3624 Function *WcFunc = *CopyResult;
3625 Builder.restoreIP(CodeGenIP);
3626
3627 Value *RL = Builder.CreatePointerBitCastOrAddrSpaceCast(ReductionList, PtrTy);
3628
3629 unsigned MaxDataSize = 0;
3630 SmallVector<Type *> ReductionTypeArgs;
3631 for (auto En : enumerate(ReductionInfos)) {
3632 auto Size = M.getDataLayout().getTypeStoreSize(En.value().ElementType);
3633 if (Size > MaxDataSize)
3634 MaxDataSize = Size;
3635 ReductionTypeArgs.emplace_back(En.value().ElementType);
3636 }
3637 Value *ReductionDataSize =
3638 Builder.getInt64(MaxDataSize * ReductionInfos.size());
3639 if (!IsTeamsReduction) {
3640 Value *SarFuncCast =
3641 Builder.CreatePointerBitCastOrAddrSpaceCast(SarFunc, PtrTy);
3642 Value *WcFuncCast =
3643 Builder.CreatePointerBitCastOrAddrSpaceCast(WcFunc, PtrTy);
3644 Value *Args[] = {SrcLocInfo, ReductionDataSize, RL, SarFuncCast,
3645 WcFuncCast};
3646 Function *Pv2Ptr = getOrCreateRuntimeFunctionPtr(
3647 RuntimeFunction::OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2);
3648 Res = Builder.CreateCall(Pv2Ptr, Args);
3649 } else {
3650 CodeGenIP = Builder.saveIP();
3651 StructType *ReductionsBufferTy = StructType::create(
3652 Ctx, ReductionTypeArgs, "struct._globalized_locals_ty");
3653 Function *RedFixedBufferFn = getOrCreateRuntimeFunctionPtr(
3654 RuntimeFunction::OMPRTL___kmpc_reduction_get_fixed_buffer);
3655 Function *LtGCFunc = emitListToGlobalCopyFunction(
3656 ReductionInfos, ReductionsBufferTy, FuncAttrs);
3657 Function *LtGRFunc = emitListToGlobalReduceFunction(
3658 ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
3659 Function *GtLCFunc = emitGlobalToListCopyFunction(
3660 ReductionInfos, ReductionsBufferTy, FuncAttrs);
3661 Function *GtLRFunc = emitGlobalToListReduceFunction(
3662 ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
3663 Builder.restoreIP(CodeGenIP);
3664
3665 Value *KernelTeamsReductionPtr = Builder.CreateCall(
3666 RedFixedBufferFn, {}, "_openmp_teams_reductions_buffer_$_$ptr");
3667
3668 Value *Args3[] = {SrcLocInfo,
3669 KernelTeamsReductionPtr,
3670 Builder.getInt32(ReductionBufNum),
3671 ReductionDataSize,
3672 RL,
3673 SarFunc,
3674 WcFunc,
3675 LtGCFunc,
3676 LtGRFunc,
3677 GtLCFunc,
3678 GtLRFunc};
3679
3680 Function *TeamsReduceFn = getOrCreateRuntimeFunctionPtr(
3681 RuntimeFunction::OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2);
3682 Res = Builder.CreateCall(TeamsReduceFn, Args3);
3683 }
3684
3685 // 5. Build if (res == 1)
3686 BasicBlock *ExitBB = BasicBlock::Create(Ctx, ".omp.reduction.done");
3687 BasicBlock *ThenBB = BasicBlock::Create(Ctx, ".omp.reduction.then");
3688 Value *Cond = Builder.CreateICmpEQ(Res, Builder.getInt32(1));
3689 Builder.CreateCondBr(Cond, ThenBB, ExitBB);
3690
3691 // 6. Build then branch: where we have reduced values in the master
3692 // thread in each team.
3693 // __kmpc_end_reduce{_nowait}(<gtid>);
3694 // break;
3695 emitBlock(ThenBB, CurFunc);
3696
3697 // TODO: Add emission of __kmpc_end_reduce{_nowait}(<gtid>);
3698 for (auto En : enumerate(ReductionInfos)) {
3699 const ReductionInfo &RI = En.value();
3700 Value *LHS = RI.Variable;
3701 Value *RHS =
3702 Builder.CreatePointerBitCastOrAddrSpaceCast(RI.PrivateVariable, PtrTy);
3703
3704 if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
3705 Value *LHSPtr, *RHSPtr;
3706 Builder.restoreIP(RI.ReductionGenClang(Builder.saveIP(), En.index(),
3707 &LHSPtr, &RHSPtr, CurFunc));
3708
3709 // Fix the callback code generated to use the correct Values for the LHS
3710 // and RHS
3711 LHSPtr->replaceUsesWithIf(LHS, [ReductionFunc](const Use &U) {
3712 return cast<Instruction>(U.getUser())->getParent()->getParent() ==
3713 ReductionFunc;
3714 });
3715 RHSPtr->replaceUsesWithIf(RHS, [ReductionFunc](const Use &U) {
3716 return cast<Instruction>(U.getUser())->getParent()->getParent() ==
3717 ReductionFunc;
3718 });
3719 } else {
3720 Value *LHSValue = Builder.CreateLoad(RI.ElementType, LHS, "final.lhs");
3721 Value *RHSValue = Builder.CreateLoad(RI.ElementType, RHS, "final.rhs");
3722 Value *Reduced;
3723 InsertPointOrErrorTy AfterIP =
3724 RI.ReductionGen(Builder.saveIP(), RHSValue, LHSValue, Reduced);
3725 if (!AfterIP)
3726 return AfterIP.takeError();
3727 Builder.CreateStore(Reduced, LHS, false);
3728 }
3729 }
3730 emitBlock(ExitBB, CurFunc);
3731 if (ContinuationBlock) {
3732 Builder.CreateBr(ContinuationBlock);
3733 Builder.SetInsertPoint(ContinuationBlock);
3734 }
3735 Config.setEmitLLVMUsed();
3736
3737 return Builder.saveIP();
3738 }
3739
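/// Create a fresh, empty ".omp.reduction.func" of type void (ptr, ptr) with
/// internal linkage; its body is filled in later by
/// populateReductionFunction.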
3740 static Function *getFreshReductionFunc(Module &M) {
3741 Type *VoidTy = Type::getVoidTy(M.getContext());
3742 Type *Int8PtrTy = PointerType::getUnqual(M.getContext());
3743 auto *FuncTy =
3744 FunctionType::get(VoidTy, {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ false);
3745 return Function::Create(FuncTy, GlobalVariable::InternalLinkage,
3746 ".omp.reduction.func", &M);
3747 }
3748
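/// Fill in the body of an outlined reduction function: for every reduction,
/// load the partial values through the type-erased LHS/RHS pointer arrays,
/// apply the elementwise reduction generator, and, unless the variable is
/// by-ref, store the result back through the LHS pointer. On GPUs the
/// argument pointers are first spilled to allocas to normalize their address
/// spaces.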
3749 static Error populateReductionFunction(
3750 Function *ReductionFunc,
3751 ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
3752 IRBuilder<> &Builder, ArrayRef<bool> IsByRef, bool IsGPU) {
3753 IRBuilder<>::InsertPointGuard IPG(Builder);
3754 Module *Module = ReductionFunc->getParent();
3755 BasicBlock *ReductionFuncBlock =
3756 BasicBlock::Create(Module->getContext(), "", ReductionFunc);
3757 Builder.SetInsertPoint(ReductionFuncBlock);
3758 Builder.SetCurrentDebugLocation(llvm::DebugLoc());
3759 Value *LHSArrayPtr = nullptr;
3760 Value *RHSArrayPtr = nullptr;
3761 if (IsGPU) {
3762 // Spill the argument pointers to allocas (normalizing address spaces)
3763 // before extracting the LHS/RHS element pointers.
3764 //
3765 Argument *Arg0 = ReductionFunc->getArg(0);
3766 Argument *Arg1 = ReductionFunc->getArg(1);
3767 Type *Arg0Type = Arg0->getType();
3768 Type *Arg1Type = Arg1->getType();
3769
3770 Value *LHSAlloca =
3771 Builder.CreateAlloca(Arg0Type, nullptr, Arg0->getName() + ".addr");
3772 Value *RHSAlloca =
3773 Builder.CreateAlloca(Arg1Type, nullptr, Arg1->getName() + ".addr");
3774 Value *LHSAddrCast =
3775 Builder.CreatePointerBitCastOrAddrSpaceCast(LHSAlloca, Arg0Type);
3776 Value *RHSAddrCast =
3777 Builder.CreatePointerBitCastOrAddrSpaceCast(RHSAlloca, Arg1Type);
3778 Builder.CreateStore(Arg0, LHSAddrCast);
3779 Builder.CreateStore(Arg1, RHSAddrCast);
3780 LHSArrayPtr = Builder.CreateLoad(Arg0Type, LHSAddrCast);
3781 RHSArrayPtr = Builder.CreateLoad(Arg1Type, RHSAddrCast);
3782 } else {
3783 LHSArrayPtr = ReductionFunc->getArg(0);
3784 RHSArrayPtr = ReductionFunc->getArg(1);
3785 }
3786
3787 unsigned NumReductions = ReductionInfos.size();
3788 Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), NumReductions);
3789
3790 for (auto En : enumerate(ReductionInfos)) {
3791 const OpenMPIRBuilder::ReductionInfo &RI = En.value();
3792 Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
3793 RedArrayTy, LHSArrayPtr, 0, En.index());
3794 Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
3795 Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3796 LHSI8Ptr, RI.Variable->getType());
3797 Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
3798 Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
3799 RedArrayTy, RHSArrayPtr, 0, En.index());
3800 Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
3801 Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3802 RHSI8Ptr, RI.PrivateVariable->getType());
3803 Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
3804 Value *Reduced;
3805 OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
3806 RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
3807 if (!AfterIP)
3808 return AfterIP.takeError();
3809
3810 Builder.restoreIP(*AfterIP);
3811 // TODO: Consider flagging an error.
3812 if (!Builder.GetInsertBlock())
3813 return Error::success();
3814
3815 // The store is inside the reduction region when using by-ref.
3816 if (!IsByRef[En.index()])
3817 Builder.CreateStore(Reduced, LHSPtr);
3818 }
3819 Builder.CreateRetVoid();
3820 return Error::success();
3821 }
3822
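/// Lower reductions on the host: dispatch to the GPU path when targeting a
/// device, otherwise emit the __kmpc_reduce{_nowait} protocol with its
/// non-atomic (case 1) and atomic (case 2) arms plus the outlined reduction
/// function that the runtime invokes.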
3823 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductions(
3824 const LocationDescription &Loc, InsertPointTy AllocaIP,
3825 ArrayRef<ReductionInfo> ReductionInfos, ArrayRef<bool> IsByRef,
3826 bool IsNoWait, bool IsTeamsReduction) {
3827 assert(ReductionInfos.size() == IsByRef.size());
3828 if (Config.isGPU())
3829 return createReductionsGPU(Loc, AllocaIP, Builder.saveIP(), ReductionInfos,
3830 IsNoWait, IsTeamsReduction);
3831
3832 checkReductionInfos(ReductionInfos, /*IsGPU*/ false);
3833
3834 if (!updateToLocation(Loc))
3835 return InsertPointTy();
3836
3837 if (ReductionInfos.empty())
3838 return Builder.saveIP();
3839
3840 BasicBlock *InsertBlock = Loc.IP.getBlock();
3841 BasicBlock *ContinuationBlock =
3842 InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
3843 InsertBlock->getTerminator()->eraseFromParent();
3844
3845 // Create and populate array of type-erased pointers to private reduction
3846 // values.
3847 unsigned NumReductions = ReductionInfos.size();
3848 Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), NumReductions);
3849 Builder.SetInsertPoint(AllocaIP.getBlock()->getTerminator());
3850 Value *RedArray = Builder.CreateAlloca(RedArrayTy, nullptr, "red.array");
3851
3852 Builder.SetInsertPoint(InsertBlock, InsertBlock->end());
3853
3854 for (auto En : enumerate(ReductionInfos)) {
3855 unsigned Index = En.index();
3856 const ReductionInfo &RI = En.value();
3857 Value *RedArrayElemPtr = Builder.CreateConstInBoundsGEP2_64(
3858 RedArrayTy, RedArray, 0, Index, "red.array.elem." + Twine(Index));
3859 Builder.CreateStore(RI.PrivateVariable, RedArrayElemPtr);
3860 }
3861
3862 // Emit a call to the runtime function that orchestrates the reduction.
3863 // Declare the reduction function in the process.
3864 Type *IndexTy = Builder.getIndexTy(
3865 M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
3866 Function *Func = Builder.GetInsertBlock()->getParent();
3867 Module *Module = Func->getParent();
3868 uint32_t SrcLocStrSize;
3869 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3870 bool CanGenerateAtomic = all_of(ReductionInfos, [](const ReductionInfo &RI) {
3871 return RI.AtomicReductionGen;
3872 });
3873 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize,
3874 CanGenerateAtomic
3875 ? IdentFlag::OMP_IDENT_FLAG_ATOMIC_REDUCE
3876 : IdentFlag(0));
3877 Value *ThreadId = getOrCreateThreadID(Ident);
3878 Constant *NumVariables = Builder.getInt32(NumReductions);
3879 const DataLayout &DL = Module->getDataLayout();
3880 unsigned RedArrayByteSize = DL.getTypeStoreSize(RedArrayTy);
3881 Constant *RedArraySize = ConstantInt::get(IndexTy, RedArrayByteSize);
3882 Function *ReductionFunc = getFreshReductionFunc(*Module);
3883 Value *Lock = getOMPCriticalRegionLock(".reduction");
3884 Function *ReduceFunc = getOrCreateRuntimeFunctionPtr(
3885 IsNoWait ? RuntimeFunction::OMPRTL___kmpc_reduce_nowait
3886 : RuntimeFunction::OMPRTL___kmpc_reduce);
3887 CallInst *ReduceCall =
3888 Builder.CreateCall(ReduceFunc,
3889 {Ident, ThreadId, NumVariables, RedArraySize, RedArray,
3890 ReductionFunc, Lock},
3891 "reduce");
3892
3893 // Create final reduction entry blocks for the atomic and non-atomic case.
3894 // Emit IR that dispatches control flow to one of the blocks based on the
3895 // reduction supporting the atomic mode.
3896 BasicBlock *NonAtomicRedBlock =
3897 BasicBlock::Create(Module->getContext(), "reduce.switch.nonatomic", Func);
3898 BasicBlock *AtomicRedBlock =
3899 BasicBlock::Create(Module->getContext(), "reduce.switch.atomic", Func);
3900 SwitchInst *Switch =
3901 Builder.CreateSwitch(ReduceCall, ContinuationBlock, /* NumCases */ 2);
3902 Switch->addCase(Builder.getInt32(1), NonAtomicRedBlock);
3903 Switch->addCase(Builder.getInt32(2), AtomicRedBlock);
3904
3905 // Populate the non-atomic reduction using the elementwise reduction function.
3906 // This loads the elements from the global and private variables and reduces
3907 // them before storing the result back to the global variable.
3908 Builder.SetInsertPoint(NonAtomicRedBlock);
3909 for (auto En : enumerate(ReductionInfos)) {
3910 const ReductionInfo &RI = En.value();
3911 Type *ValueType = RI.ElementType;
3912 // We have one less load for the by-ref case because that load is now
3913 // inside of the reduction region.
3914 Value *RedValue = RI.Variable;
3915 if (!IsByRef[En.index()]) {
3916 RedValue = Builder.CreateLoad(ValueType, RI.Variable,
3917 "red.value." + Twine(En.index()));
3918 }
3919 Value *PrivateRedValue =
3920 Builder.CreateLoad(ValueType, RI.PrivateVariable,
3921 "red.private.value." + Twine(En.index()));
3922 Value *Reduced;
3923 InsertPointOrErrorTy AfterIP =
3924 RI.ReductionGen(Builder.saveIP(), RedValue, PrivateRedValue, Reduced);
3925 if (!AfterIP)
3926 return AfterIP.takeError();
3927 Builder.restoreIP(*AfterIP);
3928
3929 if (!Builder.GetInsertBlock())
3930 return InsertPointTy();
3931 // For the by-ref case, the load is inside the reduction region.
3932 if (!IsByRef[En.index()])
3933 Builder.CreateStore(Reduced, RI.Variable);
3934 }
3935 Function *EndReduceFunc = getOrCreateRuntimeFunctionPtr(
3936 IsNoWait ? RuntimeFunction::OMPRTL___kmpc_end_reduce_nowait
3937 : RuntimeFunction::OMPRTL___kmpc_end_reduce);
3938 Builder.CreateCall(EndReduceFunc, {Ident, ThreadId, Lock});
3939 Builder.CreateBr(ContinuationBlock);
3940
3941 // Populate the atomic reduction using the atomic elementwise reduction
3942 // function. There are no loads/stores here because they will be happening
3943 // inside the atomic elementwise reduction.
3944 Builder.SetInsertPoint(AtomicRedBlock);
3945 if (CanGenerateAtomic && llvm::none_of(IsByRef, [](bool P) { return P; })) {
3946 for (const ReductionInfo &RI : ReductionInfos) {
3947 InsertPointOrErrorTy AfterIP = RI.AtomicReductionGen(
3948 Builder.saveIP(), RI.ElementType, RI.Variable, RI.PrivateVariable);
3949 if (!AfterIP)
3950 return AfterIP.takeError();
3951 Builder.restoreIP(*AfterIP);
3952 if (!Builder.GetInsertBlock())
3953 return InsertPointTy();
3954 }
3955 Builder.CreateBr(ContinuationBlock);
3956 } else {
3957 Builder.CreateUnreachable();
3958 }
3959
3960 // Populate the outlined reduction function using the elementwise reduction
3961 // function. Partial values are extracted from the type-erased array of
3962 // pointers to private variables.
3963 Error Err = populateReductionFunction(ReductionFunc, ReductionInfos, Builder,
3964 IsByRef, /*isGPU=*/false);
3965 if (Err)
3966 return Err;
3967
3968 if (!Builder.GetInsertBlock())
3969 return InsertPointTy();
3970
3971 Builder.SetInsertPoint(ContinuationBlock);
3972 return Builder.saveIP();
3973 }
3974
3975 OpenMPIRBuilder::InsertPointOrErrorTy
3976 OpenMPIRBuilder::createMaster(const LocationDescription &Loc,
3977 BodyGenCallbackTy BodyGenCB,
3978 FinalizeCallbackTy FiniCB) {
3979 if (!updateToLocation(Loc))
3980 return Loc.IP;
3981
3982 Directive OMPD = Directive::OMPD_master;
3983 uint32_t SrcLocStrSize;
3984 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3985 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3986 Value *ThreadId = getOrCreateThreadID(Ident);
3987 Value *Args[] = {Ident, ThreadId};
3988
3989 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_master);
3990 Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
3991
3992 Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_master);
3993 Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
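// With Conditional=true, EmitOMPInlinedRegion emits, roughly:
//   if (__kmpc_master(&loc, tid)) { <body>; __kmpc_end_master(&loc, tid); }
// so only the thread for which the entry call returns nonzero runs the body.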
3994
3995 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
3996 /*Conditional*/ true, /*hasFinalize*/ true);
3997 }
3998
3999 OpenMPIRBuilder::InsertPointOrErrorTy
4000 OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
4001 BodyGenCallbackTy BodyGenCB,
4002 FinalizeCallbackTy FiniCB, Value *Filter) {
4003 if (!updateToLocation(Loc))
4004 return Loc.IP;
4005
4006 Directive OMPD = Directive::OMPD_masked;
4007 uint32_t SrcLocStrSize;
4008 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4009 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4010 Value *ThreadId = getOrCreateThreadID(Ident);
4011 Value *Args[] = {Ident, ThreadId, Filter};
4012 Value *ArgsEnd[] = {Ident, ThreadId};
4013
4014 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_masked);
4015 Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
4016
4017 Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_masked);
4018 Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, ArgsEnd);
4019
4020 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
4021 /*Conditional*/ true, /*hasFinalize*/ true);
4022 }
4023
4024 CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
4025 DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
4026 BasicBlock *PostInsertBefore, const Twine &Name) {
4027 Module *M = F->getParent();
4028 LLVMContext &Ctx = M->getContext();
4029 Type *IndVarTy = TripCount->getType();
4030
4031 // Create the basic block structure.
4032 BasicBlock *Preheader =
4033 BasicBlock::Create(Ctx, "omp_" + Name + ".preheader", F, PreInsertBefore);
4034 BasicBlock *Header =
4035 BasicBlock::Create(Ctx, "omp_" + Name + ".header", F, PreInsertBefore);
4036 BasicBlock *Cond =
4037 BasicBlock::Create(Ctx, "omp_" + Name + ".cond", F, PreInsertBefore);
4038 BasicBlock *Body =
4039 BasicBlock::Create(Ctx, "omp_" + Name + ".body", F, PreInsertBefore);
4040 BasicBlock *Latch =
4041 BasicBlock::Create(Ctx, "omp_" + Name + ".inc", F, PostInsertBefore);
4042 BasicBlock *Exit =
4043 BasicBlock::Create(Ctx, "omp_" + Name + ".exit", F, PostInsertBefore);
4044 BasicBlock *After =
4045 BasicBlock::Create(Ctx, "omp_" + Name + ".after", F, PostInsertBefore);
4046
4047 // Use specified DebugLoc for new instructions.
4048 Builder.SetCurrentDebugLocation(DL);
4049
4050 Builder.SetInsertPoint(Preheader);
4051 Builder.CreateBr(Header);
4052
4053 Builder.SetInsertPoint(Header);
4054 PHINode *IndVarPHI = Builder.CreatePHI(IndVarTy, 2, "omp_" + Name + ".iv");
4055 IndVarPHI->addIncoming(ConstantInt::get(IndVarTy, 0), Preheader);
4056 Builder.CreateBr(Cond);
4057
4058 Builder.SetInsertPoint(Cond);
4059 Value *Cmp =
4060 Builder.CreateICmpULT(IndVarPHI, TripCount, "omp_" + Name + ".cmp");
4061 Builder.CreateCondBr(Cmp, Body, Exit);
4062
4063 Builder.SetInsertPoint(Body);
4064 Builder.CreateBr(Latch);
4065
4066 Builder.SetInsertPoint(Latch);
4067 Value *Next = Builder.CreateAdd(IndVarPHI, ConstantInt::get(IndVarTy, 1),
4068 "omp_" + Name + ".next", /*HasNUW=*/true);
4069 Builder.CreateBr(Header);
4070 IndVarPHI->addIncoming(Next, Latch);
4071
4072 Builder.SetInsertPoint(Exit);
4073 Builder.CreateBr(After);
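// The skeleton CFG is now complete (sketch):
//   preheader -> header -> cond -> body -> inc -> header (backedge)
//                            \-> exit -> after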
4074
4075 // Remember and return the canonical control flow.
4076 LoopInfos.emplace_front();
4077 CanonicalLoopInfo *CL = &LoopInfos.front();
4078
4079 CL->Header = Header;
4080 CL->Cond = Cond;
4081 CL->Latch = Latch;
4082 CL->Exit = Exit;
4083
4084 #ifndef NDEBUG
4085 CL->assertOK();
4086 #endif
4087 return CL;
4088 }
4089
4090 Expected<CanonicalLoopInfo *>
4091 OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
4092 LoopBodyGenCallbackTy BodyGenCB,
4093 Value *TripCount, const Twine &Name) {
4094 BasicBlock *BB = Loc.IP.getBlock();
4095 BasicBlock *NextBB = BB->getNextNode();
4096
4097 CanonicalLoopInfo *CL = createLoopSkeleton(Loc.DL, TripCount, BB->getParent(),
4098 NextBB, NextBB, Name);
4099 BasicBlock *After = CL->getAfter();
4100
4101 // If location is not set, don't connect the loop.
4102 if (updateToLocation(Loc)) {
4103 // Split the loop at the insertion point: Branch to the preheader and move
4104 // every following instruction to after the loop (the After BB). Also, the
4105 // new successor is the loop's after block.
4106 spliceBB(Builder, After, /*CreateBranch=*/false);
4107 Builder.CreateBr(CL->getPreheader());
4108 }
4109
4110 // Emit the body content. We do it after connecting the loop to the CFG to
4111 // avoid the callback encountering degenerate BBs.
4112 if (Error Err = BodyGenCB(CL->getBodyIP(), CL->getIndVar()))
4113 return Err;
4114
4115 #ifndef NDEBUG
4116 CL->assertOK();
4117 #endif
4118 return CL;
4119 }
4120
4121 Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
4122 const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step,
4123 bool IsSigned, bool InclusiveStop, const Twine &Name) {
4124
4125 // Consider the following difficulties (assuming 8-bit signed integers):
4126 // * Adding \p Step to the loop counter which passes \p Stop may overflow:
4127 // DO I = 1, 100, 50
4128 // * A \p Step of INT_MIN cannot be normalized to a positive direction:
4129 // DO I = 100, 0, -128
4130
4131 // Start, Stop and Step must be of the same integer type.
4132 auto *IndVarTy = cast<IntegerType>(Start->getType());
4133 assert(IndVarTy == Stop->getType() && "Stop type mismatch");
4134 assert(IndVarTy == Step->getType() && "Step type mismatch");
4135
4136 updateToLocation(Loc);
4137
4138 ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
4139 ConstantInt *One = ConstantInt::get(IndVarTy, 1);
4140
4141 // Like Step, but always positive.
4142 Value *Incr = Step;
4143
4144 // Distance between Start and Stop; always positive.
4145 Value *Span;
4146
4147 // Condition under which no iterations are executed at all, e.g. because
4148 // UB < LB.
4149 Value *ZeroCmp;
4150
4151 if (IsSigned) {
4152 // Ensure that increment is positive. If not, negate and invert LB and UB.
4153 Value *IsNeg = Builder.CreateICmpSLT(Step, Zero);
4154 Incr = Builder.CreateSelect(IsNeg, Builder.CreateNeg(Step), Step);
4155 Value *LB = Builder.CreateSelect(IsNeg, Stop, Start);
4156 Value *UB = Builder.CreateSelect(IsNeg, Start, Stop);
4157 Span = Builder.CreateSub(UB, LB, "", false, true);
4158 ZeroCmp = Builder.CreateICmp(
4159 InclusiveStop ? CmpInst::ICMP_SLT : CmpInst::ICMP_SLE, UB, LB);
4160 } else {
4161 Span = Builder.CreateSub(Stop, Start, "", true);
4162 ZeroCmp = Builder.CreateICmp(
4163 InclusiveStop ? CmpInst::ICMP_ULT : CmpInst::ICMP_ULE, Stop, Start);
4164 }
4165
4166 Value *CountIfLooping;
4167 if (InclusiveStop) {
4168 CountIfLooping = Builder.CreateAdd(Builder.CreateUDiv(Span, Incr), One);
4169 } else {
4170 // Avoid incrementing past stop since it could overflow.
4171 Value *CountIfTwo = Builder.CreateAdd(
4172 Builder.CreateUDiv(Builder.CreateSub(Span, One), Incr), One);
4173 Value *OneCmp = Builder.CreateICmp(CmpInst::ICMP_ULE, Span, Incr);
4174 CountIfLooping = Builder.CreateSelect(OneCmp, One, CountIfTwo);
4175 }
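// Worked example (unsigned, InclusiveStop=false): Start=0, Stop=10, Step=3
// gives Span=10 and CountIfTwo=(10-1)/3+1=4; since Span > Incr, the trip
// count is 4 (iterations 0, 3, 6, 9).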
4176
4177 return Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
4178 "omp_" + Name + ".tripcount");
4179 }
4180
4181 Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
4182 const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
4183 Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
4184 InsertPointTy ComputeIP, const Twine &Name) {
4185 LocationDescription ComputeLoc =
4186 ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
4187
4188 Value *TripCount = calculateCanonicalLoopTripCount(
4189 ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);
4190
4191 auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
4192 Builder.restoreIP(CodeGenIP);
4193 Value *Span = Builder.CreateMul(IV, Step);
4194 Value *IndVar = Builder.CreateAdd(Span, Start);
4195 return BodyGenCB(Builder.saveIP(), IndVar);
4196 };
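// E.g. for the inclusive loop DO I = 7, 17, 5 the canonical IV takes the
// values 0, 1, 2, which BodyGen maps to the user-visible values 7, 12, 17.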
4197 LocationDescription LoopLoc =
4198 ComputeIP.isSet()
4199 ? Loc
4200 : LocationDescription(Builder.saveIP(),
4201 Builder.getCurrentDebugLocation());
4202 return createCanonicalLoop(LoopLoc, BodyGen, TripCount, Name);
4203 }
4204
4205 // Returns an LLVM function to call for initializing loop bounds using OpenMP
4206 // static scheduling for composite `distribute parallel for` depending on
4207 // `type`. Only i32 and i64 are supported by the runtime. Always interpret
4208 // integers as unsigned similarly to CanonicalLoopInfo.
4209 static FunctionCallee
4210 getKmpcDistForStaticInitForType(Type *Ty, Module &M,
4211 OpenMPIRBuilder &OMPBuilder) {
4212 unsigned Bitwidth = Ty->getIntegerBitWidth();
4213 if (Bitwidth == 32)
4214 return OMPBuilder.getOrCreateRuntimeFunction(
4215 M, omp::RuntimeFunction::OMPRTL___kmpc_dist_for_static_init_4u);
4216 if (Bitwidth == 64)
4217 return OMPBuilder.getOrCreateRuntimeFunction(
4218 M, omp::RuntimeFunction::OMPRTL___kmpc_dist_for_static_init_8u);
4219 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
4220 }
4221
4222 // Returns an LLVM function to call for initializing loop bounds using OpenMP
4223 // static scheduling depending on `type`. Only i32 and i64 are supported by the
4224 // runtime. Always interpret integers as unsigned similarly to
4225 // CanonicalLoopInfo.
4226 static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
4227 OpenMPIRBuilder &OMPBuilder) {
4228 unsigned Bitwidth = Ty->getIntegerBitWidth();
4229 if (Bitwidth == 32)
4230 return OMPBuilder.getOrCreateRuntimeFunction(
4231 M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_4u);
4232 if (Bitwidth == 64)
4233 return OMPBuilder.getOrCreateRuntimeFunction(
4234 M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_8u);
4235 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
4236 }
4237
4238 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
4239 DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
4240 WorksharingLoopType LoopType, bool NeedsBarrier) {
4241 assert(CLI->isValid() && "Requires a valid canonical loop");
4242 assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
4243 "Require dedicated allocate IP");
4244
4245 // Set up the source location value for the OpenMP runtime.
4246 Builder.restoreIP(CLI->getPreheaderIP());
4247 Builder.SetCurrentDebugLocation(DL);
4248
4249 uint32_t SrcLocStrSize;
4250 Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
4251 Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4252
4253 // Declare useful OpenMP runtime functions.
4254 Value *IV = CLI->getIndVar();
4255 Type *IVTy = IV->getType();
4256 FunctionCallee StaticInit =
4257 LoopType == WorksharingLoopType::DistributeForStaticLoop
4258 ? getKmpcDistForStaticInitForType(IVTy, M, *this)
4259 : getKmpcForStaticInitForType(IVTy, M, *this);
4260 FunctionCallee StaticFini =
4261 getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);
4262
4263 // Allocate space for computed loop bounds as expected by the "init" function.
4264 Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());
4265
4266 Type *I32Type = Type::getInt32Ty(M.getContext());
4267 Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
4268 Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
4269 Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
4270 Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
4271 CLI->setLastIter(PLastIter);
4272
4273 // At the end of the preheader, prepare for calling the "init" function by
4274 // storing the current loop bounds into the allocated space. A canonical loop
4275 // always iterates from 0 to trip-count with step 1. Note that "init" expects
4276 // and produces an inclusive upper bound.
4277 Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
4278 Constant *Zero = ConstantInt::get(IVTy, 0);
4279 Constant *One = ConstantInt::get(IVTy, 1);
4280 Builder.CreateStore(Zero, PLowerBound);
4281 Value *UpperBound = Builder.CreateSub(CLI->getTripCount(), One);
4282 Builder.CreateStore(UpperBound, PUpperBound);
4283 Builder.CreateStore(One, PStride);
4284
4285 Value *ThreadNum = getOrCreateThreadID(SrcLoc);
4286
4287 OMPScheduleType SchedType =
4288 (LoopType == WorksharingLoopType::DistributeStaticLoop)
4289 ? OMPScheduleType::OrderedDistribute
4290 : OMPScheduleType::UnorderedStatic;
4291 Constant *SchedulingType =
4292 ConstantInt::get(I32Type, static_cast<int>(SchedType));
4293
4294 // Call the "init" function and update the trip count of the loop with the
4295 // value it produced.
4296 SmallVector<Value *, 10> Args(
4297 {SrcLoc, ThreadNum, SchedulingType, PLastIter, PLowerBound, PUpperBound});
4298 if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
4299 Value *PDistUpperBound =
4300 Builder.CreateAlloca(IVTy, nullptr, "p.distupperbound");
4301 Args.push_back(PDistUpperBound);
4302 }
4303 Args.append({PStride, One, Zero});
4304 Builder.CreateCall(StaticInit, Args);
4305 Value *LowerBound = Builder.CreateLoad(IVTy, PLowerBound);
4306 Value *InclusiveUpperBound = Builder.CreateLoad(IVTy, PUpperBound);
4307 Value *TripCountMinusOne = Builder.CreateSub(InclusiveUpperBound, LowerBound);
4308 Value *TripCount = Builder.CreateAdd(TripCountMinusOne, One);
4309 CLI->setTripCount(TripCount);
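// For illustration: with a trip count of 100 and 4 threads under a static
// schedule, a thread may be assigned the inclusive bounds [50, 74], making
// its local trip count 74 - 50 + 1 = 25.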
4310
4311 // Update all uses of the induction variable except the one in the condition
4312 // block that compares it with the actual upper bound, and the increment in
4313 // the latch block.
4314
4315 CLI->mapIndVar([&](Instruction *OldIV) -> Value * {
4316 Builder.SetInsertPoint(CLI->getBody(),
4317 CLI->getBody()->getFirstInsertionPt());
4318 Builder.SetCurrentDebugLocation(DL);
4319 return Builder.CreateAdd(OldIV, LowerBound);
4320 });
4321
4322 // In the "exit" block, call the "fini" function.
4323 Builder.SetInsertPoint(CLI->getExit(),
4324 CLI->getExit()->getTerminator()->getIterator());
4325 Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});
4326
4327 // Add the barrier if requested.
4328 if (NeedsBarrier) {
4329 InsertPointOrErrorTy BarrierIP =
4330 createBarrier(LocationDescription(Builder.saveIP(), DL),
4331 omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
4332 /* CheckCancelFlag */ false);
4333 if (!BarrierIP)
4334 return BarrierIP.takeError();
4335 }
4336
4337 InsertPointTy AfterIP = CLI->getAfterIP();
4338 CLI->invalidate();
4339
4340 return AfterIP;
4341 }
4342
4343 OpenMPIRBuilder::InsertPointOrErrorTy
4344 OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
4345 CanonicalLoopInfo *CLI,
4346 InsertPointTy AllocaIP,
4347 bool NeedsBarrier,
4348 Value *ChunkSize) {
4349 assert(CLI->isValid() && "Requires a valid canonical loop");
4350 assert(ChunkSize && "Chunk size is required");
4351
4352 LLVMContext &Ctx = CLI->getFunction()->getContext();
4353 Value *IV = CLI->getIndVar();
4354 Value *OrigTripCount = CLI->getTripCount();
4355 Type *IVTy = IV->getType();
4356 assert(IVTy->getIntegerBitWidth() <= 64 &&
4357 "Max supported tripcount bitwidth is 64 bits");
4358 Type *InternalIVTy = IVTy->getIntegerBitWidth() <= 32 ? Type::getInt32Ty(Ctx)
4359 : Type::getInt64Ty(Ctx);
4360 Type *I32Type = Type::getInt32Ty(M.getContext());
4361 Constant *Zero = ConstantInt::get(InternalIVTy, 0);
4362 Constant *One = ConstantInt::get(InternalIVTy, 1);
4363
4364 // Declare useful OpenMP runtime functions.
4365 FunctionCallee StaticInit =
4366 getKmpcForStaticInitForType(InternalIVTy, M, *this);
4367 FunctionCallee StaticFini =
4368 getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);
4369
4370 // Allocate space for computed loop bounds as expected by the "init" function.
4371 Builder.restoreIP(AllocaIP);
4372 Builder.SetCurrentDebugLocation(DL);
4373 Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
4374 Value *PLowerBound =
4375 Builder.CreateAlloca(InternalIVTy, nullptr, "p.lowerbound");
4376 Value *PUpperBound =
4377 Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
4378 Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");
4379 CLI->setLastIter(PLastIter);
4380
4381 // Set up the source location value for the OpenMP runtime.
4382 Builder.restoreIP(CLI->getPreheaderIP());
4383 Builder.SetCurrentDebugLocation(DL);
4384
4385 // TODO: Detect overflow in ubsan or max-out with current tripcount.
4386 Value *CastedChunkSize =
4387 Builder.CreateZExtOrTrunc(ChunkSize, InternalIVTy, "chunksize");
4388 Value *CastedTripCount =
4389 Builder.CreateZExt(OrigTripCount, InternalIVTy, "tripcount");
4390
4391 Constant *SchedulingType = ConstantInt::get(
4392 I32Type, static_cast<int>(OMPScheduleType::UnorderedStaticChunked));
4393 Builder.CreateStore(Zero, PLowerBound);
4394 Value *OrigUpperBound = Builder.CreateSub(CastedTripCount, One);
4395 Builder.CreateStore(OrigUpperBound, PUpperBound);
4396 Builder.CreateStore(One, PStride);
4397
4398 // Call the "init" function and update the trip count of the loop with the
4399 // value it produced.
4400 uint32_t SrcLocStrSize;
4401 Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
4402 Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4403 Value *ThreadNum = getOrCreateThreadID(SrcLoc);
4404 Builder.CreateCall(StaticInit,
4405 {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
4406 /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
4407 /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
4408 /*pstride=*/PStride, /*incr=*/One,
4409 /*chunk=*/CastedChunkSize});
4410
4411 // Load values written by the "init" function.
4412 Value *FirstChunkStart =
4413 Builder.CreateLoad(InternalIVTy, PLowerBound, "omp_firstchunk.lb");
4414 Value *FirstChunkStop =
4415 Builder.CreateLoad(InternalIVTy, PUpperBound, "omp_firstchunk.ub");
4416 Value *FirstChunkEnd = Builder.CreateAdd(FirstChunkStop, One);
4417 Value *ChunkRange =
4418 Builder.CreateSub(FirstChunkEnd, FirstChunkStart, "omp_chunk.range");
4419 Value *NextChunkStride =
4420 Builder.CreateLoad(InternalIVTy, PStride, "omp_dispatch.stride");
4421
4422 // Create outer "dispatch" loop for enumerating the chunks.
4423 BasicBlock *DispatchEnter = splitBB(Builder, true);
4424 Value *DispatchCounter;
4425
4426 // It is safe to assume this didn't return an error because the callback
4427 // passed into createCanonicalLoop is the only possible error source, and it
4428 // always returns success.
4429 CanonicalLoopInfo *DispatchCLI = cantFail(createCanonicalLoop(
4430 {Builder.saveIP(), DL},
4431 [&](InsertPointTy BodyIP, Value *Counter) {
4432 DispatchCounter = Counter;
4433 return Error::success();
4434 },
4435 FirstChunkStart, CastedTripCount, NextChunkStride,
4436 /*IsSigned=*/false, /*InclusiveStop=*/false, /*ComputeIP=*/{},
4437 "dispatch"));
4438
4439 // Remember the BasicBlocks of the dispatch loop we need, then invalidate to
4440 // not have to preserve the canonical invariant.
4441 BasicBlock *DispatchBody = DispatchCLI->getBody();
4442 BasicBlock *DispatchLatch = DispatchCLI->getLatch();
4443 BasicBlock *DispatchExit = DispatchCLI->getExit();
4444 BasicBlock *DispatchAfter = DispatchCLI->getAfter();
4445 DispatchCLI->invalidate();
4446
4447 // Rewire the original loop to become the chunk loop inside the dispatch loop.
4448 redirectTo(DispatchAfter, CLI->getAfter(), DL);
4449 redirectTo(CLI->getExit(), DispatchLatch, DL);
4450 redirectTo(DispatchBody, DispatchEnter, DL);
4451
4452 // Prepare the prolog of the chunk loop.
4453 Builder.restoreIP(CLI->getPreheaderIP());
4454 Builder.SetCurrentDebugLocation(DL);
4455
4456 // Compute the number of iterations of the chunk loop.
4457 Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
4458 Value *ChunkEnd = Builder.CreateAdd(DispatchCounter, ChunkRange);
4459 Value *IsLastChunk =
4460 Builder.CreateICmpUGE(ChunkEnd, CastedTripCount, "omp_chunk.is_last");
4461 Value *CountUntilOrigTripCount =
4462 Builder.CreateSub(CastedTripCount, DispatchCounter);
4463 Value *ChunkTripCount = Builder.CreateSelect(
4464 IsLastChunk, CountUntilOrigTripCount, ChunkRange, "omp_chunk.tripcount");
4465 Value *BackcastedChunkTC =
4466 Builder.CreateTrunc(ChunkTripCount, IVTy, "omp_chunk.tripcount.trunc");
4467 CLI->setTripCount(BackcastedChunkTC);
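// For illustration, with a trip count of 10, a single thread, and a chunk
// range of 4, the dispatch counter visits 0, 4 and 8; the last chunk
// satisfies 8 + 4 >= 10, so its trip count is clamped to 10 - 8 = 2.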
4468
4469 // Update all uses of the induction variable except the one in the condition
4470 // block that compares it with the actual upper bound, and the increment in
4471 // the latch block.
4472 Value *BackcastedDispatchCounter =
4473 Builder.CreateTrunc(DispatchCounter, IVTy, "omp_dispatch.iv.trunc");
4474 CLI->mapIndVar([&](Instruction *) -> Value * {
4475 Builder.restoreIP(CLI->getBodyIP());
4476 return Builder.CreateAdd(IV, BackcastedDispatchCounter);
4477 });
4478
4479 // In the "exit" block, call the "fini" function.
4480 Builder.SetInsertPoint(DispatchExit, DispatchExit->getFirstInsertionPt());
4481 Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});
4482
4483 // Add the barrier if requested.
4484 if (NeedsBarrier) {
4485 InsertPointOrErrorTy AfterIP =
4486 createBarrier(LocationDescription(Builder.saveIP(), DL), OMPD_for,
4487 /*ForceSimpleCall=*/false, /*CheckCancelFlag=*/false);
4488 if (!AfterIP)
4489 return AfterIP.takeError();
4490 }
4491
4492 #ifndef NDEBUG
4493 // Even though we currently do not support applying additional methods to it,
4494 // the chunk loop should remain a canonical loop.
4495 CLI->assertOK();
4496 #endif
4497
4498 return InsertPointTy(DispatchAfter, DispatchAfter->getFirstInsertionPt());
4499 }
4500
4501 // Returns an LLVM function to call for executing an OpenMP static worksharing
4502 // for loop depending on `type`. Only i32 and i64 are supported by the runtime.
4503 // Always interpret integers as unsigned similarly to CanonicalLoopInfo.
4504 static FunctionCallee
4505 getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
4506 WorksharingLoopType LoopType) {
4507 unsigned Bitwidth = Ty->getIntegerBitWidth();
4508 Module &M = OMPBuilder->M;
4509 switch (LoopType) {
4510 case WorksharingLoopType::ForStaticLoop:
4511 if (Bitwidth == 32)
4512 return OMPBuilder->getOrCreateRuntimeFunction(
4513 M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
4514 if (Bitwidth == 64)
4515 return OMPBuilder->getOrCreateRuntimeFunction(
4516 M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
4517 break;
4518 case WorksharingLoopType::DistributeStaticLoop:
4519 if (Bitwidth == 32)
4520 return OMPBuilder->getOrCreateRuntimeFunction(
4521 M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
4522 if (Bitwidth == 64)
4523 return OMPBuilder->getOrCreateRuntimeFunction(
4524 M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
4525 break;
4526 case WorksharingLoopType::DistributeForStaticLoop:
4527 if (Bitwidth == 32)
4528 return OMPBuilder->getOrCreateRuntimeFunction(
4529 M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
4530 if (Bitwidth == 64)
4531 return OMPBuilder->getOrCreateRuntimeFunction(
4532 M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
4533 break;
4534 }
4535 if (Bitwidth != 32 && Bitwidth != 64) {
4536 llvm_unreachable("Unknown OpenMP loop iterator bitwidth");
4537 }
4538 llvm_unreachable("Unknown type of OpenMP worksharing loop");
4539 }
4540
4541 // Inserts a call to the proper OpenMP device RTL function which handles
4542 // loop worksharing.
4543 static void createTargetLoopWorkshareCall(OpenMPIRBuilder *OMPBuilder,
4544 WorksharingLoopType LoopType,
4545 BasicBlock *InsertBlock, Value *Ident,
4546 Value *LoopBodyArg, Value *TripCount,
4547 Function &LoopBodyFn) {
4548 Type *TripCountTy = TripCount->getType();
4549 Module &M = OMPBuilder->M;
4550 IRBuilder<> &Builder = OMPBuilder->Builder;
4551 FunctionCallee RTLFn =
4552 getKmpcForStaticLoopForType(TripCountTy, OMPBuilder, LoopType);
4553 SmallVector<Value *, 8> RealArgs;
4554 RealArgs.push_back(Ident);
4555 RealArgs.push_back(&LoopBodyFn);
4556 RealArgs.push_back(LoopBodyArg);
4557 RealArgs.push_back(TripCount);
4558 if (LoopType == WorksharingLoopType::DistributeStaticLoop) {
4559 RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
4560 Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
4561 Builder.CreateCall(RTLFn, RealArgs);
4562 return;
4563 }
4564 FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
4565 M, omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
4566 Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
4567 Value *NumThreads = Builder.CreateCall(RTLNumThreads, {});
4568
4569 RealArgs.push_back(
4570 Builder.CreateZExtOrTrunc(NumThreads, TripCountTy, "num.threads.cast"));
4571 RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
4572 if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
4573 RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
4574 }
4575
4576 Builder.CreateCall(RTLFn, RealArgs);
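// For the plain worksharing case the emitted call thus roughly reads
// (signature sketched from the argument list assembled above):
//   call void @__kmpc_for_static_loop_4u(ptr %ident, ptr @outlined.body,
//       ptr %body.arg, i32 %tripcount, i32 %num.threads.cast, i32 0)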
4577 }
4578
4579 static void workshareLoopTargetCallback(
4580 OpenMPIRBuilder *OMPIRBuilder, CanonicalLoopInfo *CLI, Value *Ident,
4581 Function &OutlinedFn, const SmallVector<Instruction *, 4> &ToBeDeleted,
4582 WorksharingLoopType LoopType) {
4583 IRBuilder<> &Builder = OMPIRBuilder->Builder;
4584 BasicBlock *Preheader = CLI->getPreheader();
4585 Value *TripCount = CLI->getTripCount();
4586
4587 // After loop body outlining, the loop body contains only the set-up of the
4588 // loop body argument structure and the call to the outlined loop body
4589 // function. First, we need to move the set-up of the loop body arguments
4590 // into the loop preheader.
4591 Preheader->splice(std::prev(Preheader->end()), CLI->getBody(),
4592 CLI->getBody()->begin(), std::prev(CLI->getBody()->end()));
4593
4594 // The next step is to remove the whole loop. We do not need it anymore.
4595 // That's why we make an unconditional branch from the loop preheader to the
4596 // loop exit block.
4597 Builder.restoreIP({Preheader, Preheader->end()});
4598 Builder.SetCurrentDebugLocation(Preheader->getTerminator()->getDebugLoc());
4599 Preheader->getTerminator()->eraseFromParent();
4600 Builder.CreateBr(CLI->getExit());
4601
4602 // Delete dead loop blocks
4603 OpenMPIRBuilder::OutlineInfo CleanUpInfo;
4604 SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
4605 SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
4606 CleanUpInfo.EntryBB = CLI->getHeader();
4607 CleanUpInfo.ExitBB = CLI->getExit();
4608 CleanUpInfo.collectBlocks(RegionBlockSet, BlocksToBeRemoved);
4609 DeleteDeadBlocks(BlocksToBeRemoved);
4610
4611 // Find the instruction which corresponds to the loop body argument structure
4612 // and remove the call to the loop body function.
4613 Value *LoopBodyArg;
4614 User *OutlinedFnUser = OutlinedFn.getUniqueUndroppableUser();
4615 assert(OutlinedFnUser &&
4616 "Expected unique undroppable user of outlined function");
4617 CallInst *OutlinedFnCallInstruction = dyn_cast<CallInst>(OutlinedFnUser);
4618 assert(OutlinedFnCallInstruction && "Expected outlined function call");
4619 assert((OutlinedFnCallInstruction->getParent() == Preheader) &&
4620 "Expected outlined function call to be located in loop preheader");
4621 // Check in case no argument structure has been passed.
4622 if (OutlinedFnCallInstruction->arg_size() > 1)
4623 LoopBodyArg = OutlinedFnCallInstruction->getArgOperand(1);
4624 else
4625 LoopBodyArg = Constant::getNullValue(Builder.getPtrTy());
4626 OutlinedFnCallInstruction->eraseFromParent();
4627
4628 createTargetLoopWorkshareCall(OMPIRBuilder, LoopType, Preheader, Ident,
4629 LoopBodyArg, TripCount, OutlinedFn);
4630
4631 for (auto &ToBeDeletedItem : ToBeDeleted)
4632 ToBeDeletedItem->eraseFromParent();
4633 CLI->invalidate();
4634 }
4635
4636 OpenMPIRBuilder::InsertPointTy
4637 OpenMPIRBuilder::applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI,
4638 InsertPointTy AllocaIP,
4639 WorksharingLoopType LoopType) {
4640 uint32_t SrcLocStrSize;
4641 Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
4642 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4643
4644 OutlineInfo OI;
4645 OI.OuterAllocaBB = CLI->getPreheader();
4646 Function *OuterFn = CLI->getPreheader()->getParent();
4647
4648 // Instructions which need to be deleted at the end of code generation
4649 SmallVector<Instruction *, 4> ToBeDeleted;
4650
4651 OI.OuterAllocaBB = AllocaIP.getBlock();
4652
4653 // Mark the loop body as the region which needs to be extracted.
4654 OI.EntryBB = CLI->getBody();
4655 OI.ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(),
4656 "omp.prelatch", true);
4657
4658 // Prepare loop body for extraction
4659 Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()});
4660
4661 // Insert a new loop counter variable which will be used only in the loop
4662 // body.
4663 AllocaInst *NewLoopCnt = Builder.CreateAlloca(CLI->getIndVarType(), 0, "");
4664 Instruction *NewLoopCntLoad =
4665 Builder.CreateLoad(CLI->getIndVarType(), NewLoopCnt);
4666 // The new loop counter instructions are redundant in the loop preheader once
4667 // code generation for the workshare loop is finished. That's why we mark them
4668 // as ready for deletion.
4669 ToBeDeleted.push_back(NewLoopCntLoad);
4670 ToBeDeleted.push_back(NewLoopCnt);
4671
4672 // Analyse the loop body region and find all input variables which are used
4673 // inside the loop body region.
4674 SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
4675 SmallVector<BasicBlock *, 32> Blocks;
4676 OI.collectBlocks(ParallelRegionBlockSet, Blocks);
4677
4678 CodeExtractorAnalysisCache CEAC(*OuterFn);
4679 CodeExtractor Extractor(Blocks,
4680 /* DominatorTree */ nullptr,
4681 /* AggregateArgs */ true,
4682 /* BlockFrequencyInfo */ nullptr,
4683 /* BranchProbabilityInfo */ nullptr,
4684 /* AssumptionCache */ nullptr,
4685 /* AllowVarArgs */ true,
4686 /* AllowAlloca */ true,
4687 /* AllocationBlock */ CLI->getPreheader(),
4688 /* Suffix */ ".omp_wsloop",
4689 /* AggrArgsIn0AddrSpace */ true);
4690
4691 BasicBlock *CommonExit = nullptr;
4692 SetVector<Value *> SinkingCands, HoistingCands;
4693
4694 // Find allocas outside the loop body region which are used inside the loop
4695 // body.
4696 Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
4697
4698 // We need to model the loop body region as the function f(cnt, loop_arg).
4699 // That's why we replace the loop induction variable with the new counter,
4700 // which will be one of the loop body function's arguments.
4701 SmallVector<User *> Users(CLI->getIndVar()->user_begin(),
4702 CLI->getIndVar()->user_end());
4703 for (auto Use : Users) {
4704 if (Instruction *Inst = dyn_cast<Instruction>(Use)) {
4705 if (ParallelRegionBlockSet.count(Inst->getParent())) {
4706 Inst->replaceUsesOfWith(CLI->getIndVar(), NewLoopCntLoad);
4707 }
4708 }
4709 }
4710 // Make sure that the loop counter variable is not merged into the loop body
4711 // function's argument structure and is passed as a separate variable.
4712 OI.ExcludeArgsFromAggregate.push_back(NewLoopCntLoad);
4713
4714 // The PostOutline CB is invoked once the loop body function is outlined and
4715 // the loop body is replaced by a call to the outlined function. We need to
4716 // add a call to the OpenMP device RTL inside the loop preheader. The OpenMP
4717 // device RTL function will handle the loop control logic.
4718 //
4719 OI.PostOutlineCB = [=, ToBeDeletedVec =
4720 std::move(ToBeDeleted)](Function &OutlinedFn) {
4721 workshareLoopTargetCallback(this, CLI, Ident, OutlinedFn, ToBeDeletedVec,
4722 LoopType);
4723 };
4724 addOutlineInfo(std::move(OI));
4725 return CLI->getAfterIP();
4726 }
4727
4728 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyWorkshareLoop(
4729 DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
4730 bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
4731 bool HasSimdModifier, bool HasMonotonicModifier,
4732 bool HasNonmonotonicModifier, bool HasOrderedClause,
4733 WorksharingLoopType LoopType) {
4734 if (Config.isTargetDevice())
4735 return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType);
4736 OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
4737 SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
4738 HasNonmonotonicModifier, HasOrderedClause);
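// For example, schedule(static) maps to BaseStatic, schedule(static, N) to
// BaseStaticChunked and schedule(dynamic[, N]) to BaseDynamicChunked, each
// possibly combined with ordered/simd modifier bits.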
4739
4740 bool IsOrdered = (EffectiveScheduleType & OMPScheduleType::ModifierOrdered) ==
4741 OMPScheduleType::ModifierOrdered;
4742 switch (EffectiveScheduleType & ~OMPScheduleType::ModifierMask) {
4743 case OMPScheduleType::BaseStatic:
4744 assert(!ChunkSize && "No chunk size with static-chunked schedule");
4745 if (IsOrdered)
4746 return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
4747 NeedsBarrier, ChunkSize);
4748 // FIXME: Monotonicity ignored?
4749 return applyStaticWorkshareLoop(DL, CLI, AllocaIP, LoopType, NeedsBarrier);
4750
4751 case OMPScheduleType::BaseStaticChunked:
4752 if (IsOrdered)
4753 return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
4754 NeedsBarrier, ChunkSize);
4755 // FIXME: Monotonicity ignored?
4756 return applyStaticChunkedWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier,
4757 ChunkSize);
4758
4759 case OMPScheduleType::BaseRuntime:
4760 case OMPScheduleType::BaseAuto:
4761 case OMPScheduleType::BaseGreedy:
4762 case OMPScheduleType::BaseBalanced:
4763 case OMPScheduleType::BaseSteal:
4764 case OMPScheduleType::BaseGuidedSimd:
4765 case OMPScheduleType::BaseRuntimeSimd:
4766 assert(!ChunkSize &&
4767 "schedule type does not support user-defined chunk sizes");
4768 [[fallthrough]];
4769 case OMPScheduleType::BaseDynamicChunked:
4770 case OMPScheduleType::BaseGuidedChunked:
4771 case OMPScheduleType::BaseGuidedIterativeChunked:
4772 case OMPScheduleType::BaseGuidedAnalyticalChunked:
4773 case OMPScheduleType::BaseStaticBalancedChunked:
4774 return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
4775 NeedsBarrier, ChunkSize);
4776
4777 default:
4778 llvm_unreachable("Unknown/unimplemented schedule kind");
4779 }
4780 }
4781
4782 /// Returns an LLVM function to call for initializing loop bounds using OpenMP
4783 /// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
4784 /// the runtime. Always interpret integers as unsigned similarly to
4785 /// CanonicalLoopInfo.
4786 static FunctionCallee
4787 getKmpcForDynamicInitForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
4788 unsigned Bitwidth = Ty->getIntegerBitWidth();
4789 if (Bitwidth == 32)
4790 return OMPBuilder.getOrCreateRuntimeFunction(
4791 M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_4u);
4792 if (Bitwidth == 64)
4793 return OMPBuilder.getOrCreateRuntimeFunction(
4794 M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_8u);
4795 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
4796 }
4797
4798 /// Returns an LLVM function to call for fetching the next loop chunk using OpenMP
4799 /// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
4800 /// the runtime. Always interpret integers as unsigned similarly to
4801 /// CanonicalLoopInfo.
4802 static FunctionCallee
4803 getKmpcForDynamicNextForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
4804 unsigned Bitwidth = Ty->getIntegerBitWidth();
4805 if (Bitwidth == 32)
4806 return OMPBuilder.getOrCreateRuntimeFunction(
4807 M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_4u);
4808 if (Bitwidth == 64)
4809 return OMPBuilder.getOrCreateRuntimeFunction(
4810 M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_8u);
4811 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
4812 }
4813
4814 /// Returns an LLVM function to call for finalizing the dynamic loop,
4815 /// depending on `type`. Only i32 and i64 are supported by the runtime. Always
4816 /// interpret integers as unsigned similarly to CanonicalLoopInfo.
4817 static FunctionCallee
4818 getKmpcForDynamicFiniForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
4819 unsigned Bitwidth = Ty->getIntegerBitWidth();
4820 if (Bitwidth == 32)
4821 return OMPBuilder.getOrCreateRuntimeFunction(
4822 M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_4u);
4823 if (Bitwidth == 64)
4824 return OMPBuilder.getOrCreateRuntimeFunction(
4825 M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_8u);
4826 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
4827 }
4828
4829 OpenMPIRBuilder::InsertPointOrErrorTy
4830 OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
4831 InsertPointTy AllocaIP,
4832 OMPScheduleType SchedType,
4833 bool NeedsBarrier, Value *Chunk) {
4834 assert(CLI->isValid() && "Requires a valid canonical loop");
4835 assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
4836 "Require dedicated allocate IP");
4837 assert(isValidWorkshareLoopScheduleType(SchedType) &&
4838 "Require valid schedule type");
4839
4840 bool Ordered = (SchedType & OMPScheduleType::ModifierOrdered) ==
4841 OMPScheduleType::ModifierOrdered;
4842
4843 // Set up the source location value for OpenMP runtime.
4844 Builder.SetCurrentDebugLocation(DL);
4845
4846 uint32_t SrcLocStrSize;
4847 Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
4848 Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4849
4850 // Declare useful OpenMP runtime functions.
4851 Value *IV = CLI->getIndVar();
4852 Type *IVTy = IV->getType();
4853 FunctionCallee DynamicInit = getKmpcForDynamicInitForType(IVTy, M, *this);
4854 FunctionCallee DynamicNext = getKmpcForDynamicNextForType(IVTy, M, *this);
4855
4856 // Allocate space for computed loop bounds as expected by the "init" function.
4857 Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());
4858 Type *I32Type = Type::getInt32Ty(M.getContext());
4859 Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
4860 Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
4861 Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
4862 Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
4863 CLI->setLastIter(PLastIter);
4864
4865 // At the end of the preheader, prepare for calling the "init" function by
4866 // storing the current loop bounds into the allocated space. A canonical loop
4867 // always iterates from 0 to trip-count with step 1. Note that "init" expects
4868 // and produces an inclusive upper bound.
4869 BasicBlock *PreHeader = CLI->getPreheader();
4870 Builder.SetInsertPoint(PreHeader->getTerminator());
4871 Constant *One = ConstantInt::get(IVTy, 1);
4872 Builder.CreateStore(One, PLowerBound);
4873 Value *UpperBound = CLI->getTripCount();
4874 Builder.CreateStore(UpperBound, PUpperBound);
4875 Builder.CreateStore(One, PStride);
4876
4877 BasicBlock *Header = CLI->getHeader();
4878 BasicBlock *Exit = CLI->getExit();
4879 BasicBlock *Cond = CLI->getCond();
4880 BasicBlock *Latch = CLI->getLatch();
4881 InsertPointTy AfterIP = CLI->getAfterIP();
4882
4883 // The CLI will be "broken" in the code below, as the loop is no longer
4884 // a valid canonical loop.
4885
4886 if (!Chunk)
4887 Chunk = One;
4888
4889 Value *ThreadNum = getOrCreateThreadID(SrcLoc);
4890
4891 Constant *SchedulingType =
4892 ConstantInt::get(I32Type, static_cast<int>(SchedType));
4893
4894 // Call the "init" function.
4895 Builder.CreateCall(DynamicInit,
4896 {SrcLoc, ThreadNum, SchedulingType, /* LowerBound */ One,
4897 UpperBound, /* step */ One, Chunk});
4898
4899 // An outer loop around the existing one.
4900 BasicBlock *OuterCond = BasicBlock::Create(
4901 PreHeader->getContext(), Twine(PreHeader->getName()) + ".outer.cond",
4902 PreHeader->getParent());
4903 // This needs to be 32-bit always, so we can't reuse the IVTy constants above.
4904 Builder.SetInsertPoint(OuterCond, OuterCond->getFirstInsertionPt());
4905 Value *Res =
4906 Builder.CreateCall(DynamicNext, {SrcLoc, ThreadNum, PLastIter,
4907 PLowerBound, PUpperBound, PStride});
4908 Constant *Zero32 = ConstantInt::get(I32Type, 0);
4909 Value *MoreWork = Builder.CreateCmp(CmpInst::ICMP_NE, Res, Zero32);
4910 Value *LowerBound =
4911 Builder.CreateSub(Builder.CreateLoad(IVTy, PLowerBound), One, "lb");
4912 Builder.CreateCondBr(MoreWork, Header, Exit);
4913
4914 // Change PHI-node in loop header to use outer cond rather than preheader,
4915 // and set IV to the LowerBound.
4916 Instruction *Phi = &Header->front();
4917 auto *PI = cast<PHINode>(Phi);
4918 PI->setIncomingBlock(0, OuterCond);
4919 PI->setIncomingValue(0, LowerBound);
4920
4921 // Then set the pre-header to jump to the OuterCond
4922 Instruction *Term = PreHeader->getTerminator();
4923 auto *Br = cast<BranchInst>(Term);
4924 Br->setSuccessor(0, OuterCond);
4925
4926 // Modify the inner condition:
4927 // * Use the UpperBound returned from the DynamicNext call.
4928 // * Jump to the outer loop when done with one of the inner loops.
4929 Builder.SetInsertPoint(Cond, Cond->getFirstInsertionPt());
4930 UpperBound = Builder.CreateLoad(IVTy, PUpperBound, "ub");
4931 Instruction *Comp = &*Builder.GetInsertPoint();
4932 auto *CI = cast<CmpInst>(Comp);
4933 CI->setOperand(1, UpperBound);
4934 // Redirect the inner exit to branch to outer condition.
4935 Instruction *Branch = &Cond->back();
4936 auto *BI = cast<BranchInst>(Branch);
4937 assert(BI->getSuccessor(1) == Exit);
4938 BI->setSuccessor(1, OuterCond);
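// The rewritten control flow is, roughly:
//   preheader  -> outer.cond
//   outer.cond -> header (IV = lb) if __kmpc_dispatch_next(...) != 0, else exit
//   cond       -> body if IV is within the chunk's bounds, else outer.cond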
4939
4940 // Call the "fini" function if "ordered" is present in the wsloop directive.
4941 if (Ordered) {
4942 Builder.SetInsertPoint(&Latch->back());
4943 FunctionCallee DynamicFini = getKmpcForDynamicFiniForType(IVTy, M, *this);
4944 Builder.CreateCall(DynamicFini, {SrcLoc, ThreadNum});
4945 }
4946
4947 // Add the barrier if requested.
4948 if (NeedsBarrier) {
4949 Builder.SetInsertPoint(&Exit->back());
4950 InsertPointOrErrorTy BarrierIP =
4951 createBarrier(LocationDescription(Builder.saveIP(), DL),
4952 omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
4953 /* CheckCancelFlag */ false);
4954 if (!BarrierIP)
4955 return BarrierIP.takeError();
4956 }
4957
4958 CLI->invalidate();
4959 return AfterIP;
4960 }
4961
4962 /// Redirect all edges that branch to \p OldTarget to \p NewTarget. That is,
4963 /// after this \p OldTarget will be orphaned.
4964 static void redirectAllPredecessorsTo(BasicBlock *OldTarget,
4965 BasicBlock *NewTarget, DebugLoc DL) {
4966 for (BasicBlock *Pred : make_early_inc_range(predecessors(OldTarget)))
4967 redirectTo(Pred, NewTarget, DL);
4968 }
4969
4970 /// Determine which blocks in \p BBs are reachable from outside and remove the
4971 /// ones that are not reachable from the function.
4972 static void removeUnusedBlocksFromParent(ArrayRef<BasicBlock *> BBs) {
4973 SmallPtrSet<BasicBlock *, 6> BBsToErase(llvm::from_range, BBs);
4974 auto HasRemainingUses = [&BBsToErase](BasicBlock *BB) {
4975 for (Use &U : BB->uses()) {
4976 auto *UseInst = dyn_cast<Instruction>(U.getUser());
4977 if (!UseInst)
4978 continue;
4979 if (BBsToErase.count(UseInst->getParent()))
4980 continue;
4981 return true;
4982 }
4983 return false;
4984 };
4985
4986 while (BBsToErase.remove_if(HasRemainingUses)) {
4987 // Try again if anything was removed.
4988 }
4989
4990 SmallVector<BasicBlock *, 7> BBVec(BBsToErase.begin(), BBsToErase.end());
4991 DeleteDeadBlocks(BBVec);
4992 }
4993
4994 CanonicalLoopInfo *
4995 OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
4996 InsertPointTy ComputeIP) {
4997 assert(Loops.size() >= 1 && "At least one loop required");
4998 size_t NumLoops = Loops.size();
4999
5000 // Nothing to do if there is already just one loop.
5001 if (NumLoops == 1)
5002 return Loops.front();
5003
5004 CanonicalLoopInfo *Outermost = Loops.front();
5005 CanonicalLoopInfo *Innermost = Loops.back();
5006 BasicBlock *OrigPreheader = Outermost->getPreheader();
5007 BasicBlock *OrigAfter = Outermost->getAfter();
5008 Function *F = OrigPreheader->getParent();
5009
5010 // Loop control blocks that may become orphaned later.
5011 SmallVector<BasicBlock *, 12> OldControlBBs;
5012 OldControlBBs.reserve(6 * Loops.size());
5013 for (CanonicalLoopInfo *Loop : Loops)
5014 Loop->collectControlBlocks(OldControlBBs);
5015
5016 // Setup the IRBuilder for inserting the trip count computation.
5017 Builder.SetCurrentDebugLocation(DL);
5018 if (ComputeIP.isSet())
5019 Builder.restoreIP(ComputeIP);
5020 else
5021 Builder.restoreIP(Outermost->getPreheaderIP());
5022
5023 // Derive the collapsed loop's trip count.
5024 // TODO: Find common/largest indvar type.
5025 Value *CollapsedTripCount = nullptr;
5026 for (CanonicalLoopInfo *L : Loops) {
5027 assert(L->isValid() &&
5028 "All loops to collapse must be valid canonical loops");
5029 Value *OrigTripCount = L->getTripCount();
5030 if (!CollapsedTripCount) {
5031 CollapsedTripCount = OrigTripCount;
5032 continue;
5033 }
5034
5035 // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
5036 CollapsedTripCount = Builder.CreateMul(CollapsedTripCount, OrigTripCount,
5037 {}, /*HasNUW=*/true);
5038 }
5039
5040 // Create the collapsed loop control flow.
5041 CanonicalLoopInfo *Result =
5042 createLoopSkeleton(DL, CollapsedTripCount, F,
5043 OrigPreheader->getNextNode(), OrigAfter, "collapsed");
5044
5045 // Build the collapsed loop body code.
5046 // Start with deriving the input loop induction variables from the collapsed
5047 // one, using a divmod scheme. To preserve the original loops' order, the
5048 // innermost loop uses the least significant bits.
5049 Builder.restoreIP(Result->getBodyIP());
5050
5051 Value *Leftover = Result->getIndVar();
5052 SmallVector<Value *> NewIndVars;
5053 NewIndVars.resize(NumLoops);
5054 for (int i = NumLoops - 1; i >= 1; --i) {
5055 Value *OrigTripCount = Loops[i]->getTripCount();
5056
5057 Value *NewIndVar = Builder.CreateURem(Leftover, OrigTripCount);
5058 NewIndVars[i] = NewIndVar;
5059
5060 Leftover = Builder.CreateUDiv(Leftover, OrigTripCount);
5061 }
5062 // Outermost loop gets all the remaining bits.
5063 NewIndVars[0] = Leftover;
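// For example, collapsing two loops with trip counts 4 (outer) and 3 (inner)
// yields a collapsed trip count of 12; collapsed IV 7 then maps to inner IV
// 7 % 3 == 1 and outer IV 7 / 3 == 2.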
5064
5065 // Construct the loop body control flow.
5066 // We progressively construct the branch structure following the direction of
5067 // the control flow, from the leading in-between code, the loop nest body, the
5068 // trailing in-between code, and rejoining the collapsed loop's latch.
5069 // ContinueBlock and ContinuePred keep track of the source(s) of the next
5070 // edge. If ContinueBlock is set, continue with that block. If ContinuePred, use
5071 // its predecessors as sources.
5072 BasicBlock *ContinueBlock = Result->getBody();
5073 BasicBlock *ContinuePred = nullptr;
5074 auto ContinueWith = [&ContinueBlock, &ContinuePred, DL](BasicBlock *Dest,
5075 BasicBlock *NextSrc) {
5076 if (ContinueBlock)
5077 redirectTo(ContinueBlock, Dest, DL);
5078 else
5079 redirectAllPredecessorsTo(ContinuePred, Dest, DL);
5080
5081 ContinueBlock = nullptr;
5082 ContinuePred = NextSrc;
5083 };
5084
5085 // The code before the nested loop of each level.
5086 // Because we are sinking it into the nest, it will be executed more often
5087 // than in the original loop. More sophisticated schemes could keep track of
5088 // what the in-between code is and instantiate it only once per thread.
5089 for (size_t i = 0; i < NumLoops - 1; ++i)
5090 ContinueWith(Loops[i]->getBody(), Loops[i + 1]->getHeader());
5091
5092 // Connect the loop nest body.
5093 ContinueWith(Innermost->getBody(), Innermost->getLatch());
5094
5095 // The code after the nested loop at each level.
5096 for (size_t i = NumLoops - 1; i > 0; --i)
5097 ContinueWith(Loops[i]->getAfter(), Loops[i - 1]->getLatch());
5098
5099 // Connect the finished loop to the collapsed loop latch.
5100 ContinueWith(Result->getLatch(), nullptr);
5101
5102 // Replace the input loops with the new collapsed loop.
5103 redirectTo(Outermost->getPreheader(), Result->getPreheader(), DL);
5104 redirectTo(Result->getAfter(), Outermost->getAfter(), DL);
5105
5106 // Replace the input loop indvars with the derived ones.
5107 for (size_t i = 0; i < NumLoops; ++i)
5108 Loops[i]->getIndVar()->replaceAllUsesWith(NewIndVars[i]);
5109
5110 // Remove unused parts of the input loops.
5111 removeUnusedBlocksFromParent(OldControlBBs);
5112
5113 for (CanonicalLoopInfo *L : Loops)
5114 L->invalidate();
5115
5116 #ifndef NDEBUG
5117 Result->assertOK();
5118 #endif
5119 return Result;
5120 }
5121
5122 std::vector<CanonicalLoopInfo *>
5123 OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
5124 ArrayRef<Value *> TileSizes) {
5125 assert(TileSizes.size() == Loops.size() &&
5126 "Must pass as many tile sizes as there are loops");
5127 int NumLoops = Loops.size();
5128 assert(NumLoops >= 1 && "At least one loop to tile required");
5129
5130 CanonicalLoopInfo *OutermostLoop = Loops.front();
5131 CanonicalLoopInfo *InnermostLoop = Loops.back();
5132 Function *F = OutermostLoop->getBody()->getParent();
5133 BasicBlock *InnerEnter = InnermostLoop->getBody();
5134 BasicBlock *InnerLatch = InnermostLoop->getLatch();
5135
5136 // Loop control blocks that may become orphaned later.
5137 SmallVector<BasicBlock *, 12> OldControlBBs;
5138 OldControlBBs.reserve(6 * Loops.size());
5139 for (CanonicalLoopInfo *Loop : Loops)
5140 Loop->collectControlBlocks(OldControlBBs);
5141
5142 // Collect original trip counts and induction variables to be accessible by
5143 // index. Also, the structure of the original loops is not preserved during
5144 // the construction of the tiled loops, so do it before we scavenge the BBs of
5145 // any original CanonicalLoopInfo.
5146 SmallVector<Value *, 4> OrigTripCounts, OrigIndVars;
5147 for (CanonicalLoopInfo *L : Loops) {
5148 assert(L->isValid() && "All input loops must be valid canonical loops");
5149 OrigTripCounts.push_back(L->getTripCount());
5150 OrigIndVars.push_back(L->getIndVar());
5151 }
5152
5153 // Collect the code between loop headers. These may contain SSA definitions
5154 // that are used in the loop nest body. To be usable within the innermost
5155 // body, these BasicBlocks will be sunk into the loop nest body. That is,
5156 // these instructions may be executed more often than before the tiling.
5157 // TODO: It would be sufficient to only sink them into body of the
5158 // corresponding tile loop.
5159 SmallVector<std::pair<BasicBlock *, BasicBlock *>, 4> InbetweenCode;
5160 for (int i = 0; i < NumLoops - 1; ++i) {
5161 CanonicalLoopInfo *Surrounding = Loops[i];
5162 CanonicalLoopInfo *Nested = Loops[i + 1];
5163
5164 BasicBlock *EnterBB = Surrounding->getBody();
5165 BasicBlock *ExitBB = Nested->getHeader();
5166 InbetweenCode.emplace_back(EnterBB, ExitBB);
5167 }
5168
5169 // Compute the trip counts of the floor loops.
5170 Builder.SetCurrentDebugLocation(DL);
5171 Builder.restoreIP(OutermostLoop->getPreheaderIP());
5172 SmallVector<Value *, 4> FloorCount, FloorRems;
5173 for (int i = 0; i < NumLoops; ++i) {
5174 Value *TileSize = TileSizes[i];
5175 Value *OrigTripCount = OrigTripCounts[i];
5176 Type *IVType = OrigTripCount->getType();
5177
5178 Value *FloorTripCount = Builder.CreateUDiv(OrigTripCount, TileSize);
5179 Value *FloorTripRem = Builder.CreateURem(OrigTripCount, TileSize);
5180
5181 // 0 if the tilesize divides the tripcount, 1 otherwise.
5182 // 1 means we need an additional iteration for a partial tile.
5183 //
5184 // Unfortunately we cannot just use the roundup-formula
5185 // (tripcount + tilesize - 1)/tilesize
5186 // because the summation might overflow. We do not want to introduce undefined
5187 // behavior when the untiled loop nest did not.
5188 Value *FloorTripOverflow =
5189 Builder.CreateICmpNE(FloorTripRem, ConstantInt::get(IVType, 0));
5190
5191 FloorTripOverflow = Builder.CreateZExt(FloorTripOverflow, IVType);
5192 FloorTripCount =
5193 Builder.CreateAdd(FloorTripCount, FloorTripOverflow,
5194 "omp_floor" + Twine(i) + ".tripcount", true);
5195
5196 // Remember some values for later use.
5197 FloorCount.push_back(FloorTripCount);
5198 FloorRems.push_back(FloorTripRem);
5199 }
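  // Worked example of the overflow-safe round-up above (an illustration, not
  // taken from a test): for OrigTripCount = 10 and TileSize = 4, the UDiv
  // yields 2 and the URem yields 2; the remainder is non-zero, so the overflow
  // bit is 1 and the floor loop runs 2 + 1 = 3 times, the last time covering
  // only the 2 leftover iterations.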

  // Generate the new loop nest, from the outermost to the innermost.
  std::vector<CanonicalLoopInfo *> Result;
  Result.reserve(NumLoops * 2);

  // The basic block of the surrounding loop that enters the newly generated
  // loop nest.
  BasicBlock *Enter = OutermostLoop->getPreheader();

  // The basic block of the surrounding loop where the inner code should
  // continue.
  BasicBlock *Continue = OutermostLoop->getAfter();

  // Where the next loop basic block should be inserted.
  BasicBlock *OutroInsertBefore = InnermostLoop->getExit();

  auto EmbeddNewLoop =
      [this, DL, F, InnerEnter, &Enter, &Continue, &OutroInsertBefore](
          Value *TripCount, const Twine &Name) -> CanonicalLoopInfo * {
    CanonicalLoopInfo *EmbeddedLoop = createLoopSkeleton(
        DL, TripCount, F, InnerEnter, OutroInsertBefore, Name);
    redirectTo(Enter, EmbeddedLoop->getPreheader(), DL);
    redirectTo(EmbeddedLoop->getAfter(), Continue, DL);

    // Set up the position where the next embedded loop connects to this loop.
    Enter = EmbeddedLoop->getBody();
    Continue = EmbeddedLoop->getLatch();
    OutroInsertBefore = EmbeddedLoop->getLatch();
    return EmbeddedLoop;
  };

  auto EmbeddNewLoops = [&Result, &EmbeddNewLoop](ArrayRef<Value *> TripCounts,
                                                  const Twine &NameBase) {
    for (auto P : enumerate(TripCounts)) {
      CanonicalLoopInfo *EmbeddedLoop =
          EmbeddNewLoop(P.value(), NameBase + Twine(P.index()));
      Result.push_back(EmbeddedLoop);
    }
  };

  EmbeddNewLoops(FloorCount, "floor");

  // Within the innermost floor loop, emit the code that computes the tile
  // sizes.
  Builder.SetInsertPoint(Enter->getTerminator());
  SmallVector<Value *, 4> TileCounts;
  for (int i = 0; i < NumLoops; ++i) {
    CanonicalLoopInfo *FloorLoop = Result[i];
    Value *TileSize = TileSizes[i];

    Value *FloorIsEpilogue =
        Builder.CreateICmpEQ(FloorLoop->getIndVar(), FloorCount[i]);
    Value *TileTripCount =
        Builder.CreateSelect(FloorIsEpilogue, FloorRems[i], TileSize);

    TileCounts.push_back(TileTripCount);
  }
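  // Continuing the example above (trip count 10, tile size 4): the floor
  // iteration flagged as the epilogue selects the remainder (2) as its tile
  // trip count, while every other floor iteration executes the full tile of
  // 4 iterations.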

  // Create the tile loops.
  EmbeddNewLoops(TileCounts, "tile");

  // Insert the in-between code into the body.
  BasicBlock *BodyEnter = Enter;
  BasicBlock *BodyEntered = nullptr;
  for (std::pair<BasicBlock *, BasicBlock *> P : InbetweenCode) {
    BasicBlock *EnterBB = P.first;
    BasicBlock *ExitBB = P.second;

    if (BodyEnter)
      redirectTo(BodyEnter, EnterBB, DL);
    else
      redirectAllPredecessorsTo(BodyEntered, EnterBB, DL);

    BodyEnter = nullptr;
    BodyEntered = ExitBB;
  }

  // Append the original loop nest body into the generated loop nest body.
  if (BodyEnter)
    redirectTo(BodyEnter, InnerEnter, DL);
  else
    redirectAllPredecessorsTo(BodyEntered, InnerEnter, DL);
  redirectAllPredecessorsTo(InnerLatch, Continue, DL);

  // Replace the original induction variable with an induction variable
  // computed from the tile and floor induction variables.
  Builder.restoreIP(Result.back()->getBodyIP());
  for (int i = 0; i < NumLoops; ++i) {
    CanonicalLoopInfo *FloorLoop = Result[i];
    CanonicalLoopInfo *TileLoop = Result[NumLoops + i];
    Value *OrigIndVar = OrigIndVars[i];
    Value *Size = TileSizes[i];

    Value *Scale =
        Builder.CreateMul(Size, FloorLoop->getIndVar(), {}, /*HasNUW=*/true);
    Value *Shift =
        Builder.CreateAdd(Scale, TileLoop->getIndVar(), {}, /*HasNUW=*/true);
    OrigIndVar->replaceAllUsesWith(Shift);
  }

  // Remove unused parts of the original loops.
  removeUnusedBlocksFromParent(OldControlBBs);

  for (CanonicalLoopInfo *L : Loops)
    L->invalidate();

#ifndef NDEBUG
  for (CanonicalLoopInfo *GenL : Result)
    GenL->assertOK();
#endif
  return Result;
}

/// Attach metadata \p Properties to the basic block described by \p BB. If the
/// basic block already has metadata, the basic block properties are appended.
static void addBasicBlockMetadata(BasicBlock *BB,
                                  ArrayRef<Metadata *> Properties) {
  // Nothing to do if there is no property to attach.
  if (Properties.empty())
    return;

  LLVMContext &Ctx = BB->getContext();
  SmallVector<Metadata *> NewProperties;
  NewProperties.push_back(nullptr);

  // If the basic block already has metadata, prepend it to the new metadata.
  MDNode *Existing = BB->getTerminator()->getMetadata(LLVMContext::MD_loop);
  if (Existing)
    append_range(NewProperties, drop_begin(Existing->operands(), 1));

  append_range(NewProperties, Properties);
  MDNode *BasicBlockID = MDNode::getDistinct(Ctx, NewProperties);
  BasicBlockID->replaceOperandWith(0, BasicBlockID);

  BB->getTerminator()->setMetadata(LLVMContext::MD_loop, BasicBlockID);
}
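// For illustration (a sketch, not output from a test), the terminator ends up
// with self-referential loop metadata such as:
//   br label %header, !llvm.loop !0
//   !0 = distinct !{!0, !1, !2}  ; operand 0 refers to the node itself
// where !1, !2, ... are the attached properties.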

/// Attach loop metadata \p Properties to the loop described by \p Loop. If the
/// loop already has metadata, the loop properties are appended.
static void addLoopMetadata(CanonicalLoopInfo *Loop,
                            ArrayRef<Metadata *> Properties) {
  assert(Loop->isValid() && "Expecting a valid CanonicalLoopInfo");

  // Attach metadata to the loop's latch.
  BasicBlock *Latch = Loop->getLatch();
  assert(Latch && "A valid CanonicalLoopInfo must have a unique latch");
  addBasicBlockMetadata(Latch, Properties);
}

/// Attach llvm.access.group metadata to the memref instructions of \p Block.
static void addSimdMetadata(BasicBlock *Block, MDNode *AccessGroup,
                            LoopInfo &LI) {
  for (Instruction &I : *Block) {
    if (I.mayReadOrWriteMemory()) {
      // TODO: This instruction may already have an access group from other
      // pragmas, e.g. #pragma clang loop vectorize. Append so that the
      // existing metadata is not overwritten.
      I.setMetadata(LLVMContext::MD_access_group, AccessGroup);
    }
  }
}

void OpenMPIRBuilder::unrollLoopFull(DebugLoc, CanonicalLoopInfo *Loop) {
  LLVMContext &Ctx = Builder.getContext();
  addLoopMetadata(
      Loop, {MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
             MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.full"))});
}
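// E.g. after unrollLoopFull, the latch terminator carries (a sketch):
//   !llvm.loop !0
//   !0 = distinct !{!0, !{!"llvm.loop.unroll.enable"},
//                   !{!"llvm.loop.unroll.full"}}
// which asks the LoopUnrollPass to fully unroll the loop.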

void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) {
  LLVMContext &Ctx = Builder.getContext();
  addLoopMetadata(
      Loop, {
                MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
            });
}

void OpenMPIRBuilder::createIfVersion(CanonicalLoopInfo *CanonicalLoop,
                                      Value *IfCond, ValueToValueMapTy &VMap,
                                      LoopAnalysis &LIA, LoopInfo &LI, Loop *L,
                                      const Twine &NamePrefix) {
  Function *F = CanonicalLoop->getFunction();

  // We can't do
  //   if (cond) {
  //     simd_loop;
  //   } else {
  //     non_simd_loop;
  //   }
  // because then the CanonicalLoopInfo would only point to one of the loops,
  // leading to other constructs operating on the same loop to malfunction.
  // Instead generate
  //   while (...) {
  //     if (cond) {
  //       simd_body;
  //     } else {
  //       not_simd_body;
  //     }
  //   }
  // At least for simple loops, LLVM seems able to hoist the if out of the loop
  // body at -O3.

  // Define where the if branch should be inserted.
  auto SplitBeforeIt = CanonicalLoop->getBody()->getFirstNonPHIIt();

  // Create additional blocks for the if statement.
  BasicBlock *Cond = SplitBeforeIt->getParent();
  llvm::LLVMContext &C = Cond->getContext();
  llvm::BasicBlock *ThenBlock = llvm::BasicBlock::Create(
      C, NamePrefix + ".if.then", Cond->getParent(), Cond->getNextNode());
  llvm::BasicBlock *ElseBlock = llvm::BasicBlock::Create(
      C, NamePrefix + ".if.else", Cond->getParent(), CanonicalLoop->getExit());

  // Create the if condition branch.
  Builder.SetInsertPoint(SplitBeforeIt);
  Instruction *BrInstr =
      Builder.CreateCondBr(IfCond, ThenBlock, /*ifFalse*/ ElseBlock);
  InsertPointTy IP{BrInstr->getParent(), ++BrInstr->getIterator()};
  // The then block contains the branch to the omp loop body, which needs to be
  // vectorized.
  spliceBB(IP, ThenBlock, false, Builder.getCurrentDebugLocation());
  ThenBlock->replaceSuccessorsPhiUsesWith(Cond, ThenBlock);

  Builder.SetInsertPoint(ElseBlock);

  // Clone the loop for the else branch.
  SmallVector<BasicBlock *, 8> NewBlocks;

  SmallVector<BasicBlock *, 8> ExistingBlocks;
  ExistingBlocks.reserve(L->getNumBlocks() + 1);
  ExistingBlocks.push_back(ThenBlock);
  ExistingBlocks.append(L->block_begin(), L->block_end());
  // Cond is the block that has the if clause condition.
  // LoopCond is omp_loop.cond.
  // LoopHeader is omp_loop.header.
  BasicBlock *LoopCond = Cond->getUniquePredecessor();
  assert(LoopCond && "Invalid loop structure");
  BasicBlock *LoopHeader = LoopCond->getUniquePredecessor();
  assert(LoopHeader && "Invalid loop structure");
  for (BasicBlock *Block : ExistingBlocks) {
    if (Block == L->getLoopPreheader() || Block == L->getLoopLatch() ||
        Block == LoopHeader || Block == LoopCond || Block == Cond) {
      continue;
    }
    BasicBlock *NewBB = CloneBasicBlock(Block, VMap, "", F);

    // Fix the name not to be omp.if.then.
    if (Block == ThenBlock)
      NewBB->setName(NamePrefix + ".if.else");

    NewBB->moveBefore(CanonicalLoop->getExit());
    VMap[Block] = NewBB;
    NewBlocks.push_back(NewBB);
  }
  remapInstructionsInBlocks(NewBlocks, VMap);
  Builder.CreateBr(NewBlocks.front());

  // The loop latch must have only one predecessor. Currently it is branched to
  // from both the 'then' and 'else' branches.
  L->getLoopLatch()->splitBasicBlock(
      L->getLoopLatch()->begin(), NamePrefix + ".pre_latch", /*Before=*/true);

  // Ensure that the then block is added to the loop so we add the attributes
  // in the next step.
  L->addBasicBlockToLoop(ThenBlock, LI);
}

unsigned
OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
                                           const StringMap<bool> &Features) {
  if (TargetTriple.isX86()) {
    if (Features.lookup("avx512f"))
      return 512;
    else if (Features.lookup("avx"))
      return 256;
    return 128;
  }
  if (TargetTriple.isPPC())
    return 128;
  if (TargetTriple.isWasm())
    return 128;
  return 0;
}
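// For example, an x86 target whose feature map contains "avx" but not
// "avx512f" yields a default simd alignment of 256 bits; unknown targets
// report 0 and leave the choice to the caller.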

void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
                                MapVector<Value *, Value *> AlignedVars,
                                Value *IfCond, OrderKind Order,
                                ConstantInt *Simdlen, ConstantInt *Safelen) {
  LLVMContext &Ctx = Builder.getContext();

  Function *F = CanonicalLoop->getFunction();

  // TODO: We should not rely on the pass manager. Currently we use the pass
  // manager only for getting the llvm::Loop which corresponds to the given
  // CanonicalLoopInfo object. We should have a method which returns all blocks
  // between CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter().
  FunctionAnalysisManager FAM;
  FAM.registerPass([]() { return DominatorTreeAnalysis(); });
  FAM.registerPass([]() { return LoopAnalysis(); });
  FAM.registerPass([]() { return PassInstrumentationAnalysis(); });

  LoopAnalysis LIA;
  LoopInfo &&LI = LIA.run(*F, FAM);

  Loop *L = LI.getLoopFor(CanonicalLoop->getHeader());
  if (AlignedVars.size()) {
    InsertPointTy IP = Builder.saveIP();
    for (auto &AlignedItem : AlignedVars) {
      Value *AlignedPtr = AlignedItem.first;
      Value *Alignment = AlignedItem.second;
      Instruction *loadInst = dyn_cast<Instruction>(AlignedPtr);
      Builder.SetInsertPoint(loadInst->getNextNode());
      Builder.CreateAlignmentAssumption(F->getDataLayout(), AlignedPtr,
                                        Alignment);
    }
    Builder.restoreIP(IP);
  }

  if (IfCond) {
    ValueToValueMapTy VMap;
    createIfVersion(CanonicalLoop, IfCond, VMap, LIA, LI, L, "simd");
  }

  SmallSet<BasicBlock *, 8> Reachable;

  // Get the basic blocks from the loop in which memref instructions
  // can be found.
  // TODO: Generalize getting all blocks inside a CanonicalLoopInfo,
  // preferably without running any passes.
  for (BasicBlock *Block : L->getBlocks()) {
    if (Block == CanonicalLoop->getCond() ||
        Block == CanonicalLoop->getHeader())
      continue;
    Reachable.insert(Block);
  }

  SmallVector<Metadata *> LoopMDList;

  // In the presence of a finite 'safelen', it may be unsafe to mark all
  // the memory instructions parallel, because loop-carried
  // dependences of 'safelen' iterations are possible.
  // If the order(concurrent) clause is specified, the memory instructions
  // are marked parallel even if 'safelen' is finite.
  if ((Safelen == nullptr) || (Order == OrderKind::OMP_ORDER_concurrent)) {
    // Add access group metadata to memory-access instructions.
    MDNode *AccessGroup = MDNode::getDistinct(Ctx, {});
    for (BasicBlock *BB : Reachable)
      addSimdMetadata(BB, AccessGroup, LI);
    // TODO: If the loop has existing parallel access metadata, we have
    // to combine the two lists.
    LoopMDList.push_back(MDNode::get(
        Ctx, {MDString::get(Ctx, "llvm.loop.parallel_accesses"), AccessGroup}));
  }

  // FIXME: The IF clause shares a loop backedge for the SIMD and non-SIMD
  // versions, so we can't add the loop attributes in that case.
  if (IfCond) {
    // We can still add llvm.loop.parallel_accesses.
    addLoopMetadata(CanonicalLoop, LoopMDList);
    return;
  }

  // Use the above access group metadata to create loop level
  // metadata, which should be distinct for each loop.
  ConstantAsMetadata *BoolConst =
      ConstantAsMetadata::get(ConstantInt::getTrue(Type::getInt1Ty(Ctx)));
  LoopMDList.push_back(MDNode::get(
      Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.enable"), BoolConst}));

  if (Simdlen || Safelen) {
    // If both simdlen and safelen clauses are specified, the value of the
    // simdlen parameter must be less than or equal to the value of the safelen
    // parameter. Therefore, use safelen only in the absence of simdlen.
    ConstantInt *VectorizeWidth = Simdlen == nullptr ? Safelen : Simdlen;
    LoopMDList.push_back(
        MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.width"),
                          ConstantAsMetadata::get(VectorizeWidth)}));
  }

  addLoopMetadata(CanonicalLoop, LoopMDList);
}
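// As an illustration (a sketch; the exact frontend mapping is assumed), a
// simd loop with simdlen(8), no if clause and no finite safelen ends up with
// latch metadata along the lines of:
//   !{!"llvm.loop.parallel_accesses", !<access group>}
//   !{!"llvm.loop.vectorize.enable", i1 true}
//   !{!"llvm.loop.vectorize.width", i32 8}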

/// Create the TargetMachine object to query the backend for optimization
/// preferences.
///
/// Ideally, this would be passed from the front-end to the OpenMPBuilder, but
/// e.g. Clang does not pass it to its CodeGen layer and creates it only when
/// needed for the LLVM pass pipeline. We use some default options to avoid
/// having to pass too many settings from the frontend that probably do not
/// matter.
///
/// Currently, TargetMachine is only used sometimes by the unrollLoopPartial
/// method. If we are going to use TargetMachine for more purposes, especially
/// those that are sensitive to TargetOptions, RelocModel and CodeModel, it
/// might be worth requiring front-ends to pass on their TargetMachine, or at
/// least cache it between methods. Note that while front-ends such as Clang
/// have just a single main TargetMachine per translation unit, "target-cpu"
/// and "target-features" that determine the TargetMachine are per-function and
/// can be overridden using __attribute__((target("OPTIONS"))).
static std::unique_ptr<TargetMachine>
createTargetMachine(Function *F, CodeGenOptLevel OptLevel) {
  Module *M = F->getParent();

  StringRef CPU = F->getFnAttribute("target-cpu").getValueAsString();
  StringRef Features = F->getFnAttribute("target-features").getValueAsString();
  const llvm::Triple &Triple = M->getTargetTriple();

  std::string Error;
  const llvm::Target *TheTarget = TargetRegistry::lookupTarget(Triple, Error);
  if (!TheTarget)
    return {};

  llvm::TargetOptions Options;
  return std::unique_ptr<TargetMachine>(TheTarget->createTargetMachine(
      Triple, CPU, Features, Options, /*RelocModel=*/std::nullopt,
      /*CodeModel=*/std::nullopt, OptLevel));
}

/// Heuristically determine the best-performing unroll factor for \p CLI. This
/// depends on the target processor. We are re-using the same heuristics as the
/// LoopUnrollPass.
static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
  Function *F = CLI->getFunction();

  // Assume the user requests the most aggressive unrolling, even if the rest
  // of the code is optimized using a lower setting.
  CodeGenOptLevel OptLevel = CodeGenOptLevel::Aggressive;
  std::unique_ptr<TargetMachine> TM = createTargetMachine(F, OptLevel);

  FunctionAnalysisManager FAM;
  FAM.registerPass([]() { return TargetLibraryAnalysis(); });
  FAM.registerPass([]() { return AssumptionAnalysis(); });
  FAM.registerPass([]() { return DominatorTreeAnalysis(); });
  FAM.registerPass([]() { return LoopAnalysis(); });
  FAM.registerPass([]() { return ScalarEvolutionAnalysis(); });
  FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
  TargetIRAnalysis TIRA;
  if (TM)
    TIRA = TargetIRAnalysis(
        [&](const Function &F) { return TM->getTargetTransformInfo(F); });
  FAM.registerPass([&]() { return TIRA; });

  TargetIRAnalysis::Result &&TTI = TIRA.run(*F, FAM);
  ScalarEvolutionAnalysis SEA;
  ScalarEvolution &&SE = SEA.run(*F, FAM);
  DominatorTreeAnalysis DTA;
  DominatorTree &&DT = DTA.run(*F, FAM);
  LoopAnalysis LIA;
  LoopInfo &&LI = LIA.run(*F, FAM);
  AssumptionAnalysis ACT;
  AssumptionCache &&AC = ACT.run(*F, FAM);
  OptimizationRemarkEmitter ORE{F};

  Loop *L = LI.getLoopFor(CLI->getHeader());
  assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop");

  TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
      L, SE, TTI,
      /*BlockFrequencyInfo=*/nullptr,
      /*ProfileSummaryInfo=*/nullptr, ORE, static_cast<int>(OptLevel),
      /*UserThreshold=*/std::nullopt,
      /*UserCount=*/std::nullopt,
      /*UserAllowPartial=*/true,
      /*UserAllowRuntime=*/true,
      /*UserUpperBound=*/std::nullopt,
      /*UserFullUnrollMaxCount=*/std::nullopt);

  UP.Force = true;

  // Account for additional optimizations taking place before the
  // LoopUnrollPass would unroll the loop.
  UP.Threshold *= UnrollThresholdFactor;
  UP.PartialThreshold *= UnrollThresholdFactor;

  // Use normal unroll factors even if the rest of the code is optimized for
  // size.
  UP.OptSizeThreshold = UP.Threshold;
  UP.PartialOptSizeThreshold = UP.PartialThreshold;

  LLVM_DEBUG(dbgs() << "Unroll heuristic thresholds:\n"
                    << "  Threshold=" << UP.Threshold << "\n"
                    << "  PartialThreshold=" << UP.PartialThreshold << "\n"
                    << "  OptSizeThreshold=" << UP.OptSizeThreshold << "\n"
                    << "  PartialOptSizeThreshold="
                    << UP.PartialOptSizeThreshold << "\n");

  // Disable peeling.
  TargetTransformInfo::PeelingPreferences PP =
      gatherPeelingPreferences(L, SE, TTI,
                               /*UserAllowPeeling=*/false,
                               /*UserAllowProfileBasedPeeling=*/false,
                               /*UnrollingSpecficValues=*/false);

  SmallPtrSet<const Value *, 32> EphValues;
  CodeMetrics::collectEphemeralValues(L, &AC, EphValues);

  // Assume that reads and writes to stack variables can be eliminated by
  // Mem2Reg, SROA or LICM. That is, don't count them towards the loop body's
  // size.
  for (BasicBlock *BB : L->blocks()) {
    for (Instruction &I : *BB) {
      Value *Ptr;
      if (auto *Load = dyn_cast<LoadInst>(&I)) {
        Ptr = Load->getPointerOperand();
      } else if (auto *Store = dyn_cast<StoreInst>(&I)) {
        Ptr = Store->getPointerOperand();
      } else
        continue;

      Ptr = Ptr->stripPointerCasts();

      if (auto *Alloca = dyn_cast<AllocaInst>(Ptr)) {
        if (Alloca->getParent() == &F->getEntryBlock())
          EphValues.insert(&I);
      }
    }
  }

  UnrollCostEstimator UCE(L, TTI, EphValues, UP.BEInsns);

  // The loop is not unrollable if it contains certain instructions.
  if (!UCE.canUnroll()) {
    LLVM_DEBUG(dbgs() << "Loop not considered unrollable\n");
    return 1;
  }

  LLVM_DEBUG(dbgs() << "Estimated loop size is " << UCE.getRolledLoopSize()
                    << "\n");

  // TODO: Determine the trip count of \p CLI if constant; computeUnrollCount
  // might be able to use it.
  int TripCount = 0;
  int MaxTripCount = 0;
  bool MaxOrZero = false;
  unsigned TripMultiple = 0;

  bool UseUpperBound = false;
  computeUnrollCount(L, TTI, DT, &LI, &AC, SE, EphValues, &ORE, TripCount,
                     MaxTripCount, MaxOrZero, TripMultiple, UCE, UP, PP,
                     UseUpperBound);
  unsigned Factor = UP.Count;
  LLVM_DEBUG(dbgs() << "Suggesting unroll factor of " << Factor << "\n");

  // This function returns 1 to signal that the loop should not be unrolled.
  if (Factor == 0)
    return 1;
  return Factor;
}

void OpenMPIRBuilder::unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop,
                                        int32_t Factor,
                                        CanonicalLoopInfo **UnrolledCLI) {
  assert(Factor >= 0 && "Unroll factor must not be negative");

  Function *F = Loop->getFunction();
  LLVMContext &Ctx = F->getContext();

  // If the unrolled loop is not used for another loop-associated directive, it
  // is sufficient to add metadata for the LoopUnrollPass.
  if (!UnrolledCLI) {
    SmallVector<Metadata *, 2> LoopMetadata;
    LoopMetadata.push_back(
        MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")));

    if (Factor >= 1) {
      ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
          ConstantInt::get(Type::getInt32Ty(Ctx), APInt(32, Factor)));
      LoopMetadata.push_back(MDNode::get(
          Ctx, {MDString::get(Ctx, "llvm.loop.unroll.count"), FactorConst}));
    }

    addLoopMetadata(Loop, LoopMetadata);
    return;
  }

  // Heuristically determine the unroll factor.
  if (Factor == 0)
    Factor = computeHeuristicUnrollFactor(Loop);

  // No change required with unroll factor 1.
  if (Factor == 1) {
    *UnrolledCLI = Loop;
    return;
  }

  assert(Factor >= 2 &&
         "unrolling only makes sense with a factor of 2 or larger");

  Type *IndVarTy = Loop->getIndVarType();

  // Apply partial unrolling by tiling the loop by the unroll-factor, then
  // fully unroll the inner loop.
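  // Conceptually (a sketch, not the emitted IR), for trip count N and unroll
  // factor F this produces
  //   for (i_floor = 0; i_floor < ceil(N/F); ++i_floor)      // *UnrolledCLI
  //     for (i_tile = 0; i_tile < min(F, N - i_floor*F); ++i_tile)
  //       Body(i_floor * F + i_tile);
  // where the inner tile loop is then fully unrolled by the LoopUnrollPass.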
  Value *FactorVal =
      ConstantInt::get(IndVarTy, APInt(IndVarTy->getIntegerBitWidth(), Factor,
                                       /*isSigned=*/false));
  std::vector<CanonicalLoopInfo *> LoopNest =
      tileLoops(DL, {Loop}, {FactorVal});
  assert(LoopNest.size() == 2 && "Expect 2 loops after tiling");
  *UnrolledCLI = LoopNest[0];
  CanonicalLoopInfo *InnerLoop = LoopNest[1];

  // LoopUnrollPass can only fully unroll loops with a constant trip count.
  // Unroll by the unroll factor with a fallback epilog for the remainder
  // iterations if necessary.
  ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
      ConstantInt::get(Type::getInt32Ty(Ctx), APInt(32, Factor)));
  addLoopMetadata(
      InnerLoop,
      {MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
       MDNode::get(
           Ctx, {MDString::get(Ctx, "llvm.loop.unroll.count"), FactorConst})});

#ifndef NDEBUG
  (*UnrolledCLI)->assertOK();
#endif
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createCopyPrivate(const LocationDescription &Loc,
                                   llvm::Value *BufSize, llvm::Value *CpyBuf,
                                   llvm::Value *CpyFn, llvm::Value *DidIt) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);

  llvm::Value *DidItLD = Builder.CreateLoad(Builder.getInt32Ty(), DidIt);

  Value *Args[] = {Ident, ThreadId, BufSize, CpyBuf, CpyFn, DidItLD};

  Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_copyprivate);
  Builder.CreateCall(Fn, Args);

  return Builder.saveIP();
}
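// Note: createSingle below is the typical caller; it invokes createCopyPrivate
// once per copyprivate variable after the single region, passing the DidIt
// flag that the region's finalization sets to 1 on the executing thread.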

OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createSingle(
    const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
    FinalizeCallbackTy FiniCB, bool IsNowait, ArrayRef<llvm::Value *> CPVars,
    ArrayRef<llvm::Function *> CPFuncs) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  // If needed, allocate and initialize `DidIt` with 0.
  // DidIt: flag variable: 1=single thread; 0=not single thread.
  llvm::Value *DidIt = nullptr;
  if (!CPVars.empty()) {
    DidIt = Builder.CreateAlloca(llvm::Type::getInt32Ty(Builder.getContext()));
    Builder.CreateStore(Builder.getInt32(0), DidIt);
  }

  Directive OMPD = Directive::OMPD_single;
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {Ident, ThreadId};

  Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_single);
  Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);

  Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_single);
  Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);

  auto FiniCBWrapper = [&](InsertPointTy IP) -> Error {
    if (Error Err = FiniCB(IP))
      return Err;

    // The thread that executes the single region must set `DidIt` to 1.
    // This is used by __kmpc_copyprivate, to know if the caller is the
    // single thread or not.
    if (DidIt)
      Builder.CreateStore(Builder.getInt32(1), DidIt);

    return Error::success();
  };

  // Generates the following:
  // if (__kmpc_single()) {
  //   .... single region ...
  //   __kmpc_end_single
  // }
  // __kmpc_copyprivate
  // __kmpc_barrier

  InsertPointOrErrorTy AfterIP =
      EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCBWrapper,
                           /*Conditional*/ true,
                           /*hasFinalize*/ true);
  if (!AfterIP)
    return AfterIP.takeError();

  if (DidIt) {
    for (size_t I = 0, E = CPVars.size(); I < E; ++I)
      // NOTE BufSize is currently unused, so just pass 0.
      createCopyPrivate(LocationDescription(Builder.saveIP(), Loc.DL),
                        /*BufSize=*/ConstantInt::get(Int64, 0), CPVars[I],
                        CPFuncs[I], DidIt);
    // NOTE __kmpc_copyprivate already inserts a barrier.
  } else if (!IsNowait) {
    InsertPointOrErrorTy AfterIP =
        createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                      omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
                      /* CheckCancelFlag */ false);
    if (!AfterIP)
      return AfterIP.takeError();
  }
  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createCritical(
    const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
    FinalizeCallbackTy FiniCB, StringRef CriticalName, Value *HintInst) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  Directive OMPD = Directive::OMPD_critical;
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *LockVar = getOMPCriticalRegionLock(CriticalName);
  Value *Args[] = {Ident, ThreadId, LockVar};

  SmallVector<llvm::Value *, 4> EnterArgs(std::begin(Args), std::end(Args));
  Function *RTFn = nullptr;
  if (HintInst) {
    // Add Hint to entry Args and create call.
    EnterArgs.push_back(HintInst);
    RTFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_critical_with_hint);
  } else {
    RTFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_critical);
  }
  Instruction *EntryCall = Builder.CreateCall(RTFn, EnterArgs);

  Function *ExitRTLFn =
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_critical);
  Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);

  return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
                              /*Conditional*/ false, /*hasFinalize*/ true);
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createOrderedDepend(const LocationDescription &Loc,
                                     InsertPointTy AllocaIP, unsigned NumLoops,
                                     ArrayRef<llvm::Value *> StoreValues,
                                     const Twine &Name, bool IsDependSource) {
  assert(
      llvm::all_of(StoreValues,
                   [](Value *SV) { return SV->getType()->isIntegerTy(64); }) &&
      "OpenMP runtime requires depend vec with i64 type");

  if (!updateToLocation(Loc))
    return Loc.IP;

  // Allocate space for the vector and generate the alloc instruction.
  auto *ArrI64Ty = ArrayType::get(Int64, NumLoops);
  Builder.restoreIP(AllocaIP);
  AllocaInst *ArgsBase = Builder.CreateAlloca(ArrI64Ty, nullptr, Name);
  ArgsBase->setAlignment(Align(8));
  Builder.restoreIP(Loc.IP);

  // Store the index value with offset in the depend vector.
  for (unsigned I = 0; I < NumLoops; ++I) {
    Value *DependAddrGEPIter = Builder.CreateInBoundsGEP(
        ArrI64Ty, ArgsBase, {Builder.getInt64(0), Builder.getInt64(I)});
    StoreInst *STInst = Builder.CreateStore(StoreValues[I], DependAddrGEPIter);
    STInst->setAlignment(Align(8));
  }

  Value *DependBaseAddrGEP = Builder.CreateInBoundsGEP(
      ArrI64Ty, ArgsBase, {Builder.getInt64(0), Builder.getInt64(0)});

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {Ident, ThreadId, DependBaseAddrGEP};

  Function *RTLFn = nullptr;
  if (IsDependSource)
    RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_doacross_post);
  else
    RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_doacross_wait);
  Builder.CreateCall(RTLFn, Args);

  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createOrderedThreadsSimd(
    const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
    FinalizeCallbackTy FiniCB, bool IsThreads) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  Directive OMPD = Directive::OMPD_ordered;
  Instruction *EntryCall = nullptr;
  Instruction *ExitCall = nullptr;

  if (IsThreads) {
    uint32_t SrcLocStrSize;
    Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
    Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
    Value *ThreadId = getOrCreateThreadID(Ident);
    Value *Args[] = {Ident, ThreadId};

    Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_ordered);
    EntryCall = Builder.CreateCall(EntryRTLFn, Args);

    Function *ExitRTLFn =
        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_ordered);
    ExitCall = Builder.CreateCall(ExitRTLFn, Args);
  }

  return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
                              /*Conditional*/ false, /*hasFinalize*/ true);
}

OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::EmitOMPInlinedRegion(
    Directive OMPD, Instruction *EntryCall, Instruction *ExitCall,
    BodyGenCallbackTy BodyGenCB, FinalizeCallbackTy FiniCB, bool Conditional,
    bool HasFinalize, bool IsCancellable) {

  if (HasFinalize)
    FinalizationStack.push_back({FiniCB, OMPD, IsCancellable});

  // Create the inlined region's entry and body blocks, in preparation
  // for conditional creation.
  BasicBlock *EntryBB = Builder.GetInsertBlock();
  Instruction *SplitPos = EntryBB->getTerminator();
  if (!isa_and_nonnull<BranchInst>(SplitPos))
    SplitPos = new UnreachableInst(Builder.getContext(), EntryBB);
  BasicBlock *ExitBB = EntryBB->splitBasicBlock(SplitPos, "omp_region.end");
  BasicBlock *FiniBB =
      EntryBB->splitBasicBlock(EntryBB->getTerminator(), "omp_region.finalize");

  Builder.SetInsertPoint(EntryBB->getTerminator());
  emitCommonDirectiveEntry(OMPD, EntryCall, ExitBB, Conditional);

  // Generate the body.
  if (Error Err = BodyGenCB(/* AllocaIP */ InsertPointTy(),
                            /* CodeGenIP */ Builder.saveIP()))
    return Err;

  // Emit the exit call and do any needed finalization.
  auto FinIP = InsertPointTy(FiniBB, FiniBB->getFirstInsertionPt());
  assert(FiniBB->getTerminator()->getNumSuccessors() == 1 &&
         FiniBB->getTerminator()->getSuccessor(0) == ExitBB &&
         "Unexpected control flow graph state!!");
  InsertPointOrErrorTy AfterIP =
      emitCommonDirectiveExit(OMPD, FinIP, ExitCall, HasFinalize);
  if (!AfterIP)
    return AfterIP.takeError();
  assert(FiniBB->getUniquePredecessor()->getUniqueSuccessor() == FiniBB &&
         "Unexpected Control Flow State!");
  MergeBlockIntoPredecessor(FiniBB);

  // If we are skipping the region of a non-conditional, remove the exit
  // block, and clear the builder's insertion point.
  assert(SplitPos->getParent() == ExitBB &&
         "Unexpected Insertion point location!");
  auto merged = MergeBlockIntoPredecessor(ExitBB);
  BasicBlock *ExitPredBB = SplitPos->getParent();
  auto InsertBB = merged ? ExitPredBB : ExitBB;
  if (!isa_and_nonnull<BranchInst>(SplitPos))
    SplitPos->eraseFromParent();
  Builder.SetInsertPoint(InsertBB);

  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveEntry(
    Directive OMPD, Value *EntryCall, BasicBlock *ExitBB, bool Conditional) {
  // If there is nothing to do, return the current insertion point.
  if (!Conditional || !EntryCall)
    return Builder.saveIP();

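  // For a conditional entry, the code below builds (a sketch):
  //   EntryBB:         %c = icmp ne <EntryCall result>, 0
  //                    br i1 %c, label %omp_region.body, label %ExitBB
  //   omp_region.body: ; body generation continues here, ending with the
  //                    ; original terminator of EntryBB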
  BasicBlock *EntryBB = Builder.GetInsertBlock();
  Value *CallBool = Builder.CreateIsNotNull(EntryCall);
  auto *ThenBB = BasicBlock::Create(M.getContext(), "omp_region.body");
  auto *UI = new UnreachableInst(Builder.getContext(), ThenBB);

  // Emit ThenBB and set the Builder's insertion point there for
  // body generation next. Place the block after the current block.
  Function *CurFn = EntryBB->getParent();
  CurFn->insert(std::next(EntryBB->getIterator()), ThenBB);

  // Move the entry branch to the end of ThenBB, and replace it with a
  // conditional branch (if-stmt).
  Instruction *EntryBBTI = EntryBB->getTerminator();
  Builder.CreateCondBr(CallBool, ThenBB, ExitBB);
  EntryBBTI->removeFromParent();
  Builder.SetInsertPoint(UI);
  Builder.Insert(EntryBBTI);
  UI->eraseFromParent();
  Builder.SetInsertPoint(ThenBB->getTerminator());

  // Return an insertion point to ExitBB.
  return IRBuilder<>::InsertPoint(ExitBB, ExitBB->getFirstInsertionPt());
}

OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitCommonDirectiveExit(
    omp::Directive OMPD, InsertPointTy FinIP, Instruction *ExitCall,
    bool HasFinalize) {

  Builder.restoreIP(FinIP);

  // If there is finalization to do, emit it before the exit call.
  if (HasFinalize) {
    assert(!FinalizationStack.empty() &&
           "Unexpected finalization stack state!");

    FinalizationInfo Fi = FinalizationStack.pop_back_val();
    assert(Fi.DK == OMPD && "Unexpected Directive for Finalization call!");

    if (Error Err = Fi.FiniCB(FinIP))
      return Err;

    BasicBlock *FiniBB = FinIP.getBlock();
    Instruction *FiniBBTI = FiniBB->getTerminator();

    // Set the Builder's insertion point for call creation.
    Builder.SetInsertPoint(FiniBBTI);
  }

  if (!ExitCall)
    return Builder.saveIP();

  // Place the exit call as the last instruction before the finalization
  // block's terminator.
  ExitCall->removeFromParent();
  Builder.Insert(ExitCall);

  return IRBuilder<>::InsertPoint(ExitCall->getParent(),
                                  ExitCall->getIterator());
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCopyinClauseBlocks(
    InsertPointTy IP, Value *MasterAddr, Value *PrivateAddr,
    llvm::IntegerType *IntPtrTy, bool BranchtoEnd) {
  if (!IP.isSet())
    return IP;

  IRBuilder<>::InsertPointGuard IPG(Builder);

  // Creates the following CFG structure:
  //    OMP_Entry : (MasterAddr != PrivateAddr)?
  //         F     T
  //         |      \
  //         |   copyin.not.master
  //         |      /
  //         v     /
  //   copyin.not.master.end
  //         |
  //         v
  //   OMP.Entry.Next

  BasicBlock *OMP_Entry = IP.getBlock();
  Function *CurFn = OMP_Entry->getParent();
  BasicBlock *CopyBegin =
      BasicBlock::Create(M.getContext(), "copyin.not.master", CurFn);
  BasicBlock *CopyEnd = nullptr;

  // If the entry block is terminated, split it to preserve the branch to the
  // following basic block (i.e. OMP.Entry.Next); otherwise, leave everything
  // as is.
  if (isa_and_nonnull<BranchInst>(OMP_Entry->getTerminator())) {
    CopyEnd = OMP_Entry->splitBasicBlock(OMP_Entry->getTerminator(),
                                         "copyin.not.master.end");
    OMP_Entry->getTerminator()->eraseFromParent();
  } else {
    CopyEnd =
        BasicBlock::Create(M.getContext(), "copyin.not.master.end", CurFn);
  }

  Builder.SetInsertPoint(OMP_Entry);
  Value *MasterPtr = Builder.CreatePtrToInt(MasterAddr, IntPtrTy);
  Value *PrivatePtr = Builder.CreatePtrToInt(PrivateAddr, IntPtrTy);
  Value *cmp = Builder.CreateICmpNE(MasterPtr, PrivatePtr);
  Builder.CreateCondBr(cmp, CopyBegin, CopyEnd);

  Builder.SetInsertPoint(CopyBegin);
  if (BranchtoEnd)
    Builder.SetInsertPoint(Builder.CreateBr(CopyEnd));

  return Builder.saveIP();
}

CallInst *OpenMPIRBuilder::createOMPAlloc(const LocationDescription &Loc,
                                          Value *Size, Value *Allocator,
                                          std::string Name) {
  IRBuilder<>::InsertPointGuard IPG(Builder);
  updateToLocation(Loc);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {ThreadId, Size, Allocator};
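  // The emitted call is roughly (a sketch; value names depend on the caller):
  //   %Name = call ptr @__kmpc_alloc(i32 %ThreadId, i64 %Size, ptr %Allocator)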

  Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_alloc);

  return Builder.CreateCall(Fn, Args, Name);
}

CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
                                         Value *Addr, Value *Allocator,
                                         std::string Name) {
  IRBuilder<>::InsertPointGuard IPG(Builder);
  updateToLocation(Loc);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {ThreadId, Addr, Allocator};
  Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_free);
  return Builder.CreateCall(Fn, Args, Name);
}

CallInst *OpenMPIRBuilder::createOMPInteropInit(
    const LocationDescription &Loc, Value *InteropVar,
    omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
    Value *DependenceAddress, bool HaveNowaitClause) {
  IRBuilder<>::InsertPointGuard IPG(Builder);
  updateToLocation(Loc);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  if (Device == nullptr)
    Device = Constant::getAllOnesValue(Int32);
  Constant *InteropTypeVal = ConstantInt::get(Int32, (int)InteropType);
  if (NumDependences == nullptr) {
    NumDependences = ConstantInt::get(Int32, 0);
    PointerType *PointerTypeVar = PointerType::getUnqual(M.getContext());
    DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
  }
  Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
  Value *Args[] = {
      Ident,  ThreadId,       InteropVar,        InteropTypeVal,
      Device, NumDependences, DependenceAddress, HaveNowaitClauseVal};

  Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_init);

  return Builder.CreateCall(Fn, Args);
}

CallInst *OpenMPIRBuilder::createOMPInteropDestroy(
    const LocationDescription &Loc, Value *InteropVar, Value *Device,
    Value *NumDependences, Value *DependenceAddress, bool HaveNowaitClause) {
  IRBuilder<>::InsertPointGuard IPG(Builder);
  updateToLocation(Loc);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  if (Device == nullptr)
    Device = Constant::getAllOnesValue(Int32);
  if (NumDependences == nullptr) {
    NumDependences = ConstantInt::get(Int32, 0);
    PointerType *PointerTypeVar = PointerType::getUnqual(M.getContext());
    DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
  }
  Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
  Value *Args[] = {
      Ident,          ThreadId,          InteropVar, Device,
      NumDependences, DependenceAddress, HaveNowaitClauseVal};

  Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_destroy);

  return Builder.CreateCall(Fn, Args);
}

CallInst *OpenMPIRBuilder::createOMPInteropUse(const LocationDescription &Loc,
                                               Value *InteropVar, Value *Device,
                                               Value *NumDependences,
                                               Value *DependenceAddress,
                                               bool HaveNowaitClause) {
  IRBuilder<>::InsertPointGuard IPG(Builder);
  updateToLocation(Loc);
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  if (Device == nullptr)
    Device = Constant::getAllOnesValue(Int32);
  if (NumDependences == nullptr) {
    NumDependences = ConstantInt::get(Int32, 0);
    PointerType *PointerTypeVar = PointerType::getUnqual(M.getContext());
    DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
  }
  Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
  Value *Args[] = {
      Ident,          ThreadId,          InteropVar, Device,
      NumDependences, DependenceAddress, HaveNowaitClauseVal};

  Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_use);

  return Builder.CreateCall(Fn, Args);
}

CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
    const LocationDescription &Loc, llvm::Value *Pointer,
    llvm::ConstantInt *Size, const llvm::Twine &Name) {
  IRBuilder<>::InsertPointGuard IPG(Builder);
  updateToLocation(Loc);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Constant *ThreadPrivateCache =
      getOrCreateInternalVariable(Int8PtrPtr, Name.str());
  llvm::Value *Args[] = {Ident, ThreadId, Pointer, Size, ThreadPrivateCache};

  Function *Fn =
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_threadprivate_cached);

  return Builder.CreateCall(Fn, Args);
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
    const LocationDescription &Loc,
    const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
  assert(!Attrs.MaxThreads.empty() && !Attrs.MaxTeams.empty() &&
         "expected num_threads and num_teams to be specified");

  if (!updateToLocation(Loc))
    return Loc.IP;

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Constant *IsSPMDVal = ConstantInt::getSigned(Int8, Attrs.ExecFlags);
  Constant *UseGenericStateMachineVal = ConstantInt::getSigned(
      Int8, Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD);
  Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Int8, true);
  Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Int16, 0);

  Function *DebugKernelWrapper = Builder.GetInsertBlock()->getParent();
  Function *Kernel = DebugKernelWrapper;

  // We need to strip the debug prefix to get the correct kernel name.
  StringRef KernelName = Kernel->getName();
  const std::string DebugPrefix = "_debug__";
  if (KernelName.ends_with(DebugPrefix)) {
    KernelName = KernelName.drop_back(DebugPrefix.length());
    Kernel = M.getFunction(KernelName);
    assert(Kernel && "Expected the real kernel to exist");
  }

  // Manifest the launch configuration in the metadata matching the kernel
  // environment.
  if (Attrs.MinTeams > 1 || Attrs.MaxTeams.front() > 0)
    writeTeamsForKernel(T, *Kernel, Attrs.MinTeams, Attrs.MaxTeams.front());

  // If MaxThreads is not set, select the maximum between the default workgroup
  // size and the MinThreads value.
  int32_t MaxThreadsVal = Attrs.MaxThreads.front();
  if (MaxThreadsVal < 0)
    MaxThreadsVal = std::max(
        int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), Attrs.MinThreads);

  if (MaxThreadsVal > 0)
    writeThreadBoundsForKernel(T, *Kernel, Attrs.MinThreads, MaxThreadsVal);

  Constant *MinThreads = ConstantInt::getSigned(Int32, Attrs.MinThreads);
  Constant *MaxThreads = ConstantInt::getSigned(Int32, MaxThreadsVal);
  Constant *MinTeams = ConstantInt::getSigned(Int32, Attrs.MinTeams);
  Constant *MaxTeams = ConstantInt::getSigned(Int32, Attrs.MaxTeams.front());
  Constant *ReductionDataSize =
      ConstantInt::getSigned(Int32, Attrs.ReductionDataSize);
  Constant *ReductionBufferLength =
      ConstantInt::getSigned(Int32, Attrs.ReductionBufferLength);

  Function *Fn = getOrCreateRuntimeFunctionPtr(
      omp::RuntimeFunction::OMPRTL___kmpc_target_init);
  const DataLayout &DL = Fn->getDataLayout();

  Twine DynamicEnvironmentName = KernelName + "_dynamic_environment";
  Constant *DynamicEnvironmentInitializer =
      ConstantStruct::get(DynamicEnvironment, {DebugIndentionLevelVal});
  GlobalVariable *DynamicEnvironmentGV = new GlobalVariable(
      M, DynamicEnvironment, /*IsConstant=*/false, GlobalValue::WeakODRLinkage,
      DynamicEnvironmentInitializer, DynamicEnvironmentName,
      /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
      DL.getDefaultGlobalsAddressSpace());
  DynamicEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);

  Constant *DynamicEnvironment =
      DynamicEnvironmentGV->getType() == DynamicEnvironmentPtr
          ? DynamicEnvironmentGV
          : ConstantExpr::getAddrSpaceCast(DynamicEnvironmentGV,
                                           DynamicEnvironmentPtr);

  Constant *ConfigurationEnvironmentInitializer = ConstantStruct::get(
      ConfigurationEnvironment, {
                                    UseGenericStateMachineVal,
                                    MayUseNestedParallelismVal,
                                    IsSPMDVal,
                                    MinThreads,
                                    MaxThreads,
                                    MinTeams,
                                    MaxTeams,
                                    ReductionDataSize,
                                    ReductionBufferLength,
                                });
  Constant *KernelEnvironmentInitializer = ConstantStruct::get(
      KernelEnvironment, {
                             ConfigurationEnvironmentInitializer,
                             Ident,
                             DynamicEnvironment,
                         });
  std::string KernelEnvironmentName =
      (KernelName + "_kernel_environment").str();
  GlobalVariable *KernelEnvironmentGV = new GlobalVariable(
      M, KernelEnvironment, /*IsConstant=*/true, GlobalValue::WeakODRLinkage,
      KernelEnvironmentInitializer, KernelEnvironmentName,
      /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
      DL.getDefaultGlobalsAddressSpace());
  KernelEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);

  Constant *KernelEnvironment =
      KernelEnvironmentGV->getType() == KernelEnvironmentPtr
          ? KernelEnvironmentGV
          : ConstantExpr::getAddrSpaceCast(KernelEnvironmentGV,
                                           KernelEnvironmentPtr);
  Value *KernelLaunchEnvironment = DebugKernelWrapper->getArg(0);
  Type *KernelLaunchEnvParamTy = Fn->getFunctionType()->getParamType(1);
  KernelLaunchEnvironment =
      KernelLaunchEnvironment->getType() == KernelLaunchEnvParamTy
          ? KernelLaunchEnvironment
          : Builder.CreateAddrSpaceCast(KernelLaunchEnvironment,
                                        KernelLaunchEnvParamTy);
  CallInst *ThreadKind =
      Builder.CreateCall(Fn, {KernelEnvironment, KernelLaunchEnvironment});

  Value *ExecUserCode = Builder.CreateICmpEQ(
      ThreadKind, Constant::getAllOnesValue(ThreadKind->getType()),
      "exec_user_code");

  // ThreadKind = __kmpc_target_init(...)
  // if (ThreadKind == -1)
  //   user_code
  // else
  //   return;

  auto *UI = Builder.CreateUnreachable();
  BasicBlock *CheckBB = UI->getParent();
  BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(UI, "user_code.entry");

  BasicBlock *WorkerExitBB = BasicBlock::Create(
      CheckBB->getContext(), "worker.exit", CheckBB->getParent());
  Builder.SetInsertPoint(WorkerExitBB);
  Builder.CreateRetVoid();

  auto *CheckBBTI = CheckBB->getTerminator();
  Builder.SetInsertPoint(CheckBBTI);
  Builder.CreateCondBr(ExecUserCode, UI->getParent(), WorkerExitBB);

  CheckBBTI->eraseFromParent();
  UI->eraseFromParent();

  // Continue in the "user_code" block, see diagram above and in
  // openmp/libomptarget/deviceRTLs/common/include/target.h .
  return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt());
}

void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc,
                                         int32_t TeamsReductionDataSize,
                                         int32_t TeamsReductionBufferLength) {
  if (!updateToLocation(Loc))
    return;

  Function *Fn = getOrCreateRuntimeFunctionPtr(
      omp::RuntimeFunction::OMPRTL___kmpc_target_deinit);

  Builder.CreateCall(Fn, {});

  if (!TeamsReductionBufferLength || !TeamsReductionDataSize)
    return;

  Function *Kernel = Builder.GetInsertBlock()->getParent();
  // We need to strip the debug prefix to get the correct kernel name.
  StringRef KernelName = Kernel->getName();
  const std::string DebugPrefix = "_debug__";
  if (KernelName.ends_with(DebugPrefix))
    KernelName = KernelName.drop_back(DebugPrefix.length());
  auto *KernelEnvironmentGV =
      M.getNamedGlobal((KernelName + "_kernel_environment").str());
  assert(KernelEnvironmentGV && "Expected kernel environment global\n");
  auto *KernelEnvironmentInitializer = KernelEnvironmentGV->getInitializer();
  auto *NewInitializer = ConstantFoldInsertValueInstruction(
      KernelEnvironmentInitializer,
      ConstantInt::get(Int32, TeamsReductionDataSize), {0, 7});
  NewInitializer = ConstantFoldInsertValueInstruction(
      NewInitializer, ConstantInt::get(Int32, TeamsReductionBufferLength),
      {0, 8});
  KernelEnvironmentGV->setInitializer(NewInitializer);
}

static void updateNVPTXAttr(Function &Kernel, StringRef Name, int32_t Value,
                            bool Min) {
  if (Kernel.hasFnAttribute(Name)) {
    int32_t OldLimit = Kernel.getFnAttributeAsParsedInteger(Name);
    Value = Min ? std::min(OldLimit, Value) : std::max(OldLimit, Value);
  }
  Kernel.addFnAttr(Name, llvm::utostr(Value));
}
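// E.g. updateNVPTXAttr(K, "nvvm.maxntid", 128, /*Min=*/true) tightens an
// existing "nvvm.maxntid"="256" to 128, while a pre-existing value of 64 is
// kept since it is already the smaller bound.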

std::pair<int32_t, int32_t>
OpenMPIRBuilder::readThreadBoundsForKernel(const Triple &T, Function &Kernel) {
  int32_t ThreadLimit =
      Kernel.getFnAttributeAsParsedInteger("omp_target_thread_limit");

  if (T.isAMDGPU()) {
    const auto &Attr = Kernel.getFnAttribute("amdgpu-flat-work-group-size");
    if (!Attr.isValid() || !Attr.isStringAttribute())
      return {0, ThreadLimit};
    auto [LBStr, UBStr] = Attr.getValueAsString().split(',');
    int32_t LB, UB;
    if (!llvm::to_integer(UBStr, UB, 10))
      return {0, ThreadLimit};
    UB = ThreadLimit ? std::min(ThreadLimit, UB) : UB;
    if (!llvm::to_integer(LBStr, LB, 10))
      return {0, UB};
    return {LB, UB};
  }

  if (Kernel.hasFnAttribute("nvvm.maxntid")) {
    int32_t UB = Kernel.getFnAttributeAsParsedInteger("nvvm.maxntid");
    return {0, ThreadLimit ? std::min(ThreadLimit, UB) : UB};
  }
  return {0, ThreadLimit};
}
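// For example, an AMDGPU kernel carrying "amdgpu-flat-work-group-size"="1,256"
// together with "omp_target_thread_limit"="128" yields bounds {1, 128}: the
// upper bound is clamped by the OpenMP thread limit.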

void OpenMPIRBuilder::writeThreadBoundsForKernel(const Triple &T,
                                                 Function &Kernel, int32_t LB,
                                                 int32_t UB) {
  Kernel.addFnAttr("omp_target_thread_limit", std::to_string(UB));

  if (T.isAMDGPU()) {
    Kernel.addFnAttr("amdgpu-flat-work-group-size",
                     llvm::utostr(LB) + "," + llvm::utostr(UB));
    return;
  }

  updateNVPTXAttr(Kernel, "nvvm.maxntid", UB, true);
}
6554
6555 std::pair<int32_t, int32_t>
6556 OpenMPIRBuilder::readTeamBoundsForKernel(const Triple &, Function &Kernel) {
6557 // TODO: Read from backend annotations if available.
6558 return {0, Kernel.getFnAttributeAsParsedInteger("omp_target_num_teams")};
6559 }
6560
6561 void OpenMPIRBuilder::writeTeamsForKernel(const Triple &T, Function &Kernel,
6562 int32_t LB, int32_t UB) {
6563 if (T.isNVPTX())
6564 if (UB > 0)
6565 Kernel.addFnAttr("nvvm.maxclusterrank", llvm::utostr(UB));
6566 if (T.isAMDGPU())
6567 Kernel.addFnAttr("amdgpu-max-num-workgroups", llvm::utostr(LB) + ",1,1");
6568
6569 Kernel.addFnAttr("omp_target_num_teams", std::to_string(LB));
6570 }
6571
6572 void OpenMPIRBuilder::setOutlinedTargetRegionFunctionAttributes(
6573 Function *OutlinedFn) {
6574 if (Config.isTargetDevice()) {
6575 OutlinedFn->setLinkage(GlobalValue::WeakODRLinkage);
6576 // TODO: Determine if DSO local can be set to true.
6577 OutlinedFn->setDSOLocal(false);
6578 OutlinedFn->setVisibility(GlobalValue::ProtectedVisibility);
6579 if (T.isAMDGCN())
6580 OutlinedFn->setCallingConv(CallingConv::AMDGPU_KERNEL);
6581 else if (T.isNVPTX())
6582 OutlinedFn->setCallingConv(CallingConv::PTX_Kernel);
6583 else if (T.isSPIRV())
6584 OutlinedFn->setCallingConv(CallingConv::SPIR_KERNEL);
6585 }
6586 }
6587
6588 Constant *OpenMPIRBuilder::createOutlinedFunctionID(Function *OutlinedFn,
6589 StringRef EntryFnIDName) {
6590 if (Config.isTargetDevice()) {
6591 assert(OutlinedFn && "The outlined function must exist if embedded");
6592 return OutlinedFn;
6593 }
6594
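// On the host, the region ID is the address of a uniquely named dummy
// global; the runtime only uses this address as a key to look up the
// corresponding device entry in the offload entries table.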
6595 return new GlobalVariable(
6596 M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
6597 Constant::getNullValue(Builder.getInt8Ty()), EntryFnIDName);
6598 }
6599
6600 Constant *OpenMPIRBuilder::createTargetRegionEntryAddr(Function *OutlinedFn,
6601 StringRef EntryFnName) {
6602 if (OutlinedFn)
6603 return OutlinedFn;
6604
6605 assert(!M.getGlobalVariable(EntryFnName, true) &&
6606 "Named kernel already exists?");
6607 return new GlobalVariable(
6608 M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::InternalLinkage,
6609 Constant::getNullValue(Builder.getInt8Ty()), EntryFnName);
6610 }
6611
6612 Error OpenMPIRBuilder::emitTargetRegionFunction(
6613 TargetRegionEntryInfo &EntryInfo,
6614 FunctionGenCallback &GenerateFunctionCallback, bool IsOffloadEntry,
6615 Function *&OutlinedFn, Constant *&OutlinedFnID) {
6616
6617 SmallString<64> EntryFnName;
6618 OffloadInfoManager.getTargetRegionEntryFnName(EntryFnName, EntryInfo);
6619
6620 if (Config.isTargetDevice() || !Config.openMPOffloadMandatory()) {
6621 Expected<Function *> CBResult = GenerateFunctionCallback(EntryFnName);
6622 if (!CBResult)
6623 return CBResult.takeError();
6624 OutlinedFn = *CBResult;
6625 } else {
6626 OutlinedFn = nullptr;
6627 }
6628
6629 // If this target outlined function is not an offload entry, we don't need
6630 // to register it. This may be the case for a false 'if' clause, or if
6631 // there are no OpenMP targets.
6632 if (!IsOffloadEntry)
6633 return Error::success();
6634
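// On the device the kernel itself is the entry, so its name doubles as the
// ID; on the host a separate region-id global (named after the entry
// function) is created to serve as the ID.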
6635 std::string EntryFnIDName =
6636 Config.isTargetDevice()
6637 ? std::string(EntryFnName)
6638 : createPlatformSpecificName({EntryFnName, "region_id"});
6639
6640 OutlinedFnID = registerTargetRegionFunction(EntryInfo, OutlinedFn,
6641 EntryFnName, EntryFnIDName);
6642 return Error::success();
6643 }
6644
6645 Constant *OpenMPIRBuilder::registerTargetRegionFunction(
6646 TargetRegionEntryInfo &EntryInfo, Function *OutlinedFn,
6647 StringRef EntryFnName, StringRef EntryFnIDName) {
6648 if (OutlinedFn)
6649 setOutlinedTargetRegionFunctionAttributes(OutlinedFn);
6650 auto OutlinedFnID = createOutlinedFunctionID(OutlinedFn, EntryFnIDName);
6651 auto EntryAddr = createTargetRegionEntryAddr(OutlinedFn, EntryFnName);
6652 OffloadInfoManager.registerTargetRegionEntryInfo(
6653 EntryInfo, EntryAddr, OutlinedFnID,
6654 OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion);
6655 return OutlinedFnID;
6656 }
6657
6658 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData(
6659 const LocationDescription &Loc, InsertPointTy AllocaIP,
6660 InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
6661 TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
6662 CustomMapperCallbackTy CustomMapperCB, omp::RuntimeFunction *MapperFunc,
6663 function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
6664 BodyGenTy BodyGenType)>
6665 BodyGenCB,
6666 function_ref<void(unsigned int, Value *)> DeviceAddrCB, Value *SrcLocInfo) {
6667 if (!updateToLocation(Loc))
6668 return InsertPointTy();
6669
6670 Builder.restoreIP(CodeGenIP);
6671 // Disable TargetData CodeGen on Device pass.
6672 if (Config.IsTargetDevice.value_or(false)) {
6673 if (BodyGenCB) {
6674 InsertPointOrErrorTy AfterIP =
6675 BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv);
6676 if (!AfterIP)
6677 return AfterIP.takeError();
6678 Builder.restoreIP(*AfterIP);
6679 }
6680 return Builder.saveIP();
6681 }
6682
6683 bool IsStandAlone = !BodyGenCB;
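// Without a body generation callback this is a standalone construct (e.g.
// 'target enter data', 'target exit data' or 'target update'), which lowers
// to a single mapper runtime call rather than a begin/end pair.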
6684 MapInfosTy *MapInfo;
6685 // Generate the code for the opening of the data environment. Capture all the
6686 // arguments of the runtime call by reference because they are used in the
6687 // closing of the region.
6688 auto BeginThenGen = [&](InsertPointTy AllocaIP,
6689 InsertPointTy CodeGenIP) -> Error {
6690 MapInfo = &GenMapInfoCB(Builder.saveIP());
6691 if (Error Err = emitOffloadingArrays(
6692 AllocaIP, Builder.saveIP(), *MapInfo, Info, CustomMapperCB,
6693 /*IsNonContiguous=*/true, DeviceAddrCB))
6694 return Err;
6695
6696 TargetDataRTArgs RTArgs;
6697 emitOffloadingArraysArgument(Builder, RTArgs, Info);
6698
6699 // Emit the number of elements in the offloading arrays.
6700 Value *PointerNum = Builder.getInt32(Info.NumberOfPtrs);
6701
6702 // Source location for the ident struct
6703 if (!SrcLocInfo) {
6704 uint32_t SrcLocStrSize;
6705 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6706 SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6707 }
6708
6709 SmallVector<llvm::Value *, 13> OffloadingArgs = {
6710 SrcLocInfo, DeviceID,
6711 PointerNum, RTArgs.BasePointersArray,
6712 RTArgs.PointersArray, RTArgs.SizesArray,
6713 RTArgs.MapTypesArray, RTArgs.MapNamesArray,
6714 RTArgs.MappersArray};
6715
6716 if (IsStandAlone) {
6717 assert(MapperFunc && "MapperFunc missing for standalone target data");
6718
6719 auto TaskBodyCB = [&](Value *, Value *,
6720 IRBuilderBase::InsertPoint) -> Error {
6721 if (Info.HasNoWait) {
6722 OffloadingArgs.append({llvm::Constant::getNullValue(Int32),
6723 llvm::Constant::getNullValue(VoidPtr),
6724 llvm::Constant::getNullValue(Int32),
6725 llvm::Constant::getNullValue(VoidPtr)});
6726 }
6727
6728 Builder.CreateCall(getOrCreateRuntimeFunctionPtr(*MapperFunc),
6729 OffloadingArgs);
6730
6731 if (Info.HasNoWait) {
6732 BasicBlock *OffloadContBlock =
6733 BasicBlock::Create(Builder.getContext(), "omp_offload.cont");
6734 Function *CurFn = Builder.GetInsertBlock()->getParent();
6735 emitBlock(OffloadContBlock, CurFn, /*IsFinished=*/true);
6736 Builder.restoreIP(Builder.saveIP());
6737 }
6738 return Error::success();
6739 };
6740
6741 bool RequiresOuterTargetTask = Info.HasNoWait;
6742 if (!RequiresOuterTargetTask)
6743 cantFail(TaskBodyCB(/*DeviceID=*/nullptr, /*RTLoc=*/nullptr,
6744 /*TargetTaskAllocaIP=*/{}));
6745 else
6746 cantFail(emitTargetTask(TaskBodyCB, DeviceID, SrcLocInfo, AllocaIP,
6747 /*Dependencies=*/{}, RTArgs, Info.HasNoWait));
6748 } else {
6749 Function *BeginMapperFunc = getOrCreateRuntimeFunctionPtr(
6750 omp::OMPRTL___tgt_target_data_begin_mapper);
6751
6752 Builder.CreateCall(BeginMapperFunc, OffloadingArgs);
6753
6754 for (auto DeviceMap : Info.DevicePtrInfoMap) {
6755 if (isa<AllocaInst>(DeviceMap.second.second)) {
6756 auto *LI =
6757 Builder.CreateLoad(Builder.getPtrTy(), DeviceMap.second.first);
6758 Builder.CreateStore(LI, DeviceMap.second.second);
6759 }
6760 }
6761
6762 // If device pointer privatization is required, emit the body of the
6763 // region here. It will have to be duplicated: with and without
6764 // privatization.
6765 InsertPointOrErrorTy AfterIP =
6766 BodyGenCB(Builder.saveIP(), BodyGenTy::Priv);
6767 if (!AfterIP)
6768 return AfterIP.takeError();
6769 Builder.restoreIP(*AfterIP);
6770 }
6771 return Error::success();
6772 };
6773
6774 // If we need device pointer privatization, we need to emit the body of the
6775 // region with no privatization in the 'else' branch of the conditional.
6776 // Otherwise, we don't have to do anything.
6777 auto BeginElseGen = [&](InsertPointTy AllocaIP,
6778 InsertPointTy CodeGenIP) -> Error {
6779 InsertPointOrErrorTy AfterIP =
6780 BodyGenCB(Builder.saveIP(), BodyGenTy::DupNoPriv);
6781 if (!AfterIP)
6782 return AfterIP.takeError();
6783 Builder.restoreIP(*AfterIP);
6784 return Error::success();
6785 };
6786
6787 // Generate code for the closing of the data region.
6788 auto EndThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6789 TargetDataRTArgs RTArgs;
6790 Info.EmitDebug = !MapInfo->Names.empty();
6791 emitOffloadingArraysArgument(Builder, RTArgs, Info, /*ForEndCall=*/true);
6792
6793 // Emit the number of elements in the offloading arrays.
6794 Value *PointerNum = Builder.getInt32(Info.NumberOfPtrs);
6795
6796 // Source location for the ident struct
6797 if (!SrcLocInfo) {
6798 uint32_t SrcLocStrSize;
6799 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
6800 SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
6801 }
6802
6803 Value *OffloadingArgs[] = {SrcLocInfo, DeviceID,
6804 PointerNum, RTArgs.BasePointersArray,
6805 RTArgs.PointersArray, RTArgs.SizesArray,
6806 RTArgs.MapTypesArray, RTArgs.MapNamesArray,
6807 RTArgs.MappersArray};
6808 Function *EndMapperFunc =
6809 getOrCreateRuntimeFunctionPtr(omp::OMPRTL___tgt_target_data_end_mapper);
6810
6811 Builder.CreateCall(EndMapperFunc, OffloadingArgs);
6812 return Error::success();
6813 };
6814
6815 // We don't have to do anything to close the region if the if clause evaluates
6816 // to false.
6817 auto EndElseGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
6818 return Error::success();
6819 };
6820
6821 Error Err = [&]() -> Error {
6822 if (BodyGenCB) {
6823 Error Err = [&]() {
6824 if (IfCond)
6825 return emitIfClause(IfCond, BeginThenGen, BeginElseGen, AllocaIP);
6826 return BeginThenGen(AllocaIP, Builder.saveIP());
6827 }();
6828
6829 if (Err)
6830 return Err;
6831
6832 // If we don't require privatization of device pointers, we emit the body
6833 // in between the runtime calls. This avoids duplicating the body code.
6834 InsertPointOrErrorTy AfterIP =
6835 BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv);
6836 if (!AfterIP)
6837 return AfterIP.takeError();
6838 Builder.restoreIP(*AfterIP);
6839
6840 if (IfCond)
6841 return emitIfClause(IfCond, EndThenGen, EndElseGen, AllocaIP);
6842 return EndThenGen(AllocaIP, Builder.saveIP());
6843 }
6844 if (IfCond)
6845 return emitIfClause(IfCond, BeginThenGen, EndElseGen, AllocaIP);
6846 return BeginThenGen(AllocaIP, Builder.saveIP());
6847 }();
6848
6849 if (Err)
6850 return Err;
6851
6852 return Builder.saveIP();
6853 }
6854
6855 FunctionCallee
6856 OpenMPIRBuilder::createForStaticInitFunction(unsigned IVSize, bool IVSigned,
6857 bool IsGPUDistribute) {
6858 assert((IVSize == 32 || IVSize == 64) &&
6859 "IV size is not compatible with the omp runtime");
6860 RuntimeFunction Name;
6861 if (IsGPUDistribute)
6862 Name = IVSize == 32
6863 ? (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_4
6864 : omp::OMPRTL___kmpc_distribute_static_init_4u)
6865 : (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_8
6866 : omp::OMPRTL___kmpc_distribute_static_init_8u);
6867 else
6868 Name = IVSize == 32 ? (IVSigned ? omp::OMPRTL___kmpc_for_static_init_4
6869 : omp::OMPRTL___kmpc_for_static_init_4u)
6870 : (IVSigned ? omp::OMPRTL___kmpc_for_static_init_8
6871 : omp::OMPRTL___kmpc_for_static_init_8u);
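// For example, a signed 32-bit IV selects __kmpc_for_static_init_4, while
// the same IV in a GPU 'distribute' region selects
// __kmpc_distribute_static_init_4.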
6872
6873 return getOrCreateRuntimeFunction(M, Name);
6874 }
6875
6876 FunctionCallee OpenMPIRBuilder::createDispatchInitFunction(unsigned IVSize,
6877 bool IVSigned) {
6878 assert((IVSize == 32 || IVSize == 64) &&
6879 "IV size is not compatible with the omp runtime");
6880 RuntimeFunction Name = IVSize == 32
6881 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_4
6882 : omp::OMPRTL___kmpc_dispatch_init_4u)
6883 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_8
6884 : omp::OMPRTL___kmpc_dispatch_init_8u);
6885
6886 return getOrCreateRuntimeFunction(M, Name);
6887 }
6888
6889 FunctionCallee OpenMPIRBuilder::createDispatchNextFunction(unsigned IVSize,
6890 bool IVSigned) {
6891 assert((IVSize == 32 || IVSize == 64) &&
6892 "IV size is not compatible with the omp runtime");
6893 RuntimeFunction Name = IVSize == 32
6894 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_4
6895 : omp::OMPRTL___kmpc_dispatch_next_4u)
6896 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_8
6897 : omp::OMPRTL___kmpc_dispatch_next_8u);
6898
6899 return getOrCreateRuntimeFunction(M, Name);
6900 }
6901
6902 FunctionCallee OpenMPIRBuilder::createDispatchFiniFunction(unsigned IVSize,
6903 bool IVSigned) {
6904 assert((IVSize == 32 || IVSize == 64) &&
6905 "IV size is not compatible with the omp runtime");
6906 RuntimeFunction Name = IVSize == 32
6907 ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_4
6908 : omp::OMPRTL___kmpc_dispatch_fini_4u)
6909 : (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_8
6910 : omp::OMPRTL___kmpc_dispatch_fini_8u);
6911
6912 return getOrCreateRuntimeFunction(M, Name);
6913 }
6914
6915 FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
6916 return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit);
6917 }
6918
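/// Fix up the debug info of an outlined target region function: debug
/// variable records that still reference values of the parent function are
/// remapped to the outlined function's arguments, and on the device a
/// parameter variable is synthesized for the implicit "dyn_ptr" argument.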
6919 static void FixupDebugInfoForOutlinedFunction(
6920 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, Function *Func,
6921 DenseMap<Value *, std::tuple<Value *, unsigned>> &ValueReplacementMap) {
6922
6923 DISubprogram *NewSP = Func->getSubprogram();
6924 if (!NewSP)
6925 return;
6926
6927 SmallDenseMap<DILocalVariable *, DILocalVariable *> RemappedVariables;
6928
6929 auto GetUpdatedDIVariable = [&](DILocalVariable *OldVar, unsigned arg) {
6930 DILocalVariable *&NewVar = RemappedVariables[OldVar];
6931 // Only use cached variable if the arg number matches. This is important
6932 // so that DIVariables created for privatized variables are not discarded.
6933 if (NewVar && (arg == NewVar->getArg()))
6934 return NewVar;
6935
6936 NewVar = llvm::DILocalVariable::get(
6937 Builder.getContext(), OldVar->getScope(), OldVar->getName(),
6938 OldVar->getFile(), OldVar->getLine(), OldVar->getType(), arg,
6939 OldVar->getFlags(), OldVar->getAlignInBits(), OldVar->getAnnotations());
6940 return NewVar;
6941 };
6942
6943 auto UpdateDebugRecord = [&](auto *DR) {
6944 DILocalVariable *OldVar = DR->getVariable();
6945 unsigned ArgNo = 0;
6946 for (auto Loc : DR->location_ops()) {
6947 auto Iter = ValueReplacementMap.find(Loc);
6948 if (Iter != ValueReplacementMap.end()) {
6949 DR->replaceVariableLocationOp(Loc, std::get<0>(Iter->second));
6950 ArgNo = std::get<1>(Iter->second) + 1;
6951 }
6952 }
6953 if (ArgNo != 0)
6954 DR->setVariable(GetUpdatedDIVariable(OldVar, ArgNo));
6955 };
6956
6957 // The location and scope of variable intrinsics and records still point to
6958 // the parent function of the target region. Update them.
6959 for (Instruction &I : instructions(Func)) {
6960 if (auto *DDI = dyn_cast<llvm::DbgVariableIntrinsic>(&I))
6961 UpdateDebugRecord(DDI);
6962
6963 for (DbgVariableRecord &DVR : filterDbgVars(I.getDbgRecordRange()))
6964 UpdateDebugRecord(&DVR);
6965 }
6966 // An extra argument is passed to the device. Create the debug data for it.
6967 if (OMPBuilder.Config.isTargetDevice()) {
6968 DICompileUnit *CU = NewSP->getUnit();
6969 Module *M = Func->getParent();
6970 DIBuilder DB(*M, true, CU);
6971 DIType *VoidPtrTy =
6972 DB.createQualifiedType(dwarf::DW_TAG_pointer_type, nullptr);
6973 DILocalVariable *Var = DB.createParameterVariable(
6974 NewSP, "dyn_ptr", /*ArgNo*/ 1, NewSP->getFile(), /*LineNo=*/0,
6975 VoidPtrTy, /*AlwaysPreserve=*/false, DINode::DIFlags::FlagArtificial);
6976 auto Loc = DILocation::get(Func->getContext(), 0, 0, NewSP, 0);
6977 DB.insertDeclare(&(*Func->arg_begin()), Var, DB.createExpression(), Loc,
6978 &(*Func->begin()));
6979 }
6980 }
6981
6982 static Expected<Function *> createOutlinedFunction(
6983 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6984 const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
6985 StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
6986 OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
6987 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
6988 SmallVector<Type *> ParameterTypes;
6989 if (OMPBuilder.Config.isTargetDevice()) {
6990 // Add the "implicit" runtime argument we use to provide launch specific
6991 // information for target devices.
6992 auto *Int8PtrTy = PointerType::getUnqual(Builder.getContext());
6993 ParameterTypes.push_back(Int8PtrTy);
6994
6995 // All parameters to target devices are passed as pointers
6996 // or i64. This assumes 64-bit address spaces/pointers.
6997 for (auto &Arg : Inputs)
6998 ParameterTypes.push_back(Arg->getType()->isPointerTy()
6999 ? Arg->getType()
7000 : Type::getInt64Ty(Builder.getContext()));
7001 } else {
7002 for (auto &Arg : Inputs)
7003 ParameterTypes.push_back(Arg->getType());
7004 }
7005
7006 auto BB = Builder.GetInsertBlock();
7007 auto M = BB->getModule();
7008 auto FuncType = FunctionType::get(Builder.getVoidTy(), ParameterTypes,
7009 /*isVarArg*/ false);
7010 auto Func =
7011 Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
7012
7013 // Forward target-cpu and target-features function attributes from the
7014 // original function to the new outlined function.
7015 Function *ParentFn = Builder.GetInsertBlock()->getParent();
7016
7017 auto TargetCpuAttr = ParentFn->getFnAttribute("target-cpu");
7018 if (TargetCpuAttr.isStringAttribute())
7019 Func->addFnAttr(TargetCpuAttr);
7020
7021 auto TargetFeaturesAttr = ParentFn->getFnAttribute("target-features");
7022 if (TargetFeaturesAttr.isStringAttribute())
7023 Func->addFnAttr(TargetFeaturesAttr);
7024
7025 if (OMPBuilder.Config.isTargetDevice()) {
7026 Value *ExecMode =
7027 OMPBuilder.emitKernelExecutionMode(FuncName, DefaultAttrs.ExecFlags);
7028 OMPBuilder.emitUsed("llvm.compiler.used", {ExecMode});
7029 }
7030
7031 // Save insert point.
7032 IRBuilder<>::InsertPointGuard IPG(Builder);
7033 // We will generate the entries in the outlined function but the debug
7034 // location may still be pointing to the parent function. Reset it now.
7035 Builder.SetCurrentDebugLocation(llvm::DebugLoc());
7036
7037 // Generate the region into the function.
7038 BasicBlock *EntryBB = BasicBlock::Create(Builder.getContext(), "entry", Func);
7039 Builder.SetInsertPoint(EntryBB);
7040
7041 // Insert target init call in the device compilation pass.
7042 if (OMPBuilder.Config.isTargetDevice())
7043 Builder.restoreIP(OMPBuilder.createTargetInit(Builder, DefaultAttrs));
7044
7045 BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
7046
7047 // As we embed the user code in the middle of our target region after we
7048 // generate entry code, we must move what allocas we can into the entry
7049 // block to avoid possibly breaking optimisations for the device.
7050 if (OMPBuilder.Config.isTargetDevice())
7051 OMPBuilder.ConstantAllocaRaiseCandidates.emplace_back(Func);
7052
7053 // Insert target deinit call in the device compilation pass.
7054 BasicBlock *OutlinedBodyBB =
7055 splitBB(Builder, /*CreateBranch=*/true, "outlined.body");
7056 llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = CBFunc(
7057 Builder.saveIP(),
7058 OpenMPIRBuilder::InsertPointTy(OutlinedBodyBB, OutlinedBodyBB->begin()));
7059 if (!AfterIP)
7060 return AfterIP.takeError();
7061 Builder.restoreIP(*AfterIP);
7062 if (OMPBuilder.Config.isTargetDevice())
7063 OMPBuilder.createTargetDeinit(Builder);
7064
7065 // Insert return instruction.
7066 Builder.CreateRetVoid();
7067
7068 // New Alloca IP at entry point of created device function.
7069 Builder.SetInsertPoint(EntryBB->getFirstNonPHIIt());
7070 auto AllocaIP = Builder.saveIP();
7071
7072 Builder.SetInsertPoint(UserCodeEntryBB->getFirstNonPHIOrDbg());
7073
7074 // Skip the artificial dyn_ptr on the device.
7075 const auto &ArgRange =
7076 OMPBuilder.Config.isTargetDevice()
7077 ? make_range(Func->arg_begin() + 1, Func->arg_end())
7078 : Func->args();
7079
7080 DenseMap<Value *, std::tuple<Value *, unsigned>> ValueReplacementMap;
7081
7082 auto ReplaceValue = [](Value *Input, Value *InputCopy, Function *Func) {
7083 // Things like GEP's can come in the form of Constants. Constants and
7084 // ConstantExpr's do not have access to the knowledge of what they're
7085 // contained in, so we must dig a little to find an instruction so we
7086 // can tell if they're used inside of the function we're outlining. We
7087 // also replace the original constant expression with an equivalent new
7088 // instruction, since an instruction allows easy modification in the
7089 // following loop: we then know the constant (now an instruction) is
7090 // owned by our target function and replaceUsesOfWith can be invoked on
7091 // it (this does not appear to be possible with constants). A brand new
7092 // instruction also lets us be cautious, as the old expression may have
7093 // been used inside the function while also existing and being used
7094 // externally (unlikely by the nature of a Constant, but still possible).
7095 // NOTE: We cannot remove dead constants that have been rewritten to
7096 // instructions at this stage, we run the risk of breaking later lowering
7097 // by doing so as we could still be in the process of lowering the module
7098 // from MLIR to LLVM-IR and the MLIR lowering may still require the original
7099 // constants we have created rewritten versions of.
7100 if (auto *Const = dyn_cast<Constant>(Input))
7101 convertUsersOfConstantsToInstructions(Const, Func, false);
7102
7103 // Collect users before iterating over them to avoid invalidating the
7104 // iteration in case a user uses Input more than once (e.g. a call
7105 // instruction).
7106 SetVector<User *> Users(Input->users().begin(), Input->users().end());
7107 // Then rewrite the uses that occur in instructions owned by Func.
7108 for (User *User : make_early_inc_range(Users))
7109 if (auto *Instr = dyn_cast<Instruction>(User))
7110 if (Instr->getFunction() == Func)
7111 Instr->replaceUsesOfWith(Input, InputCopy);
7112 };
7113
7114 SmallVector<std::pair<Value *, Value *>> DeferredReplacement;
7115
7116 // Rewrite uses of input values to parameters.
7117 for (auto InArg : zip(Inputs, ArgRange)) {
7118 Value *Input = std::get<0>(InArg);
7119 Argument &Arg = std::get<1>(InArg);
7120 Value *InputCopy = nullptr;
7121
7122 llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
7123 ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP());
7124 if (!AfterIP)
7125 return AfterIP.takeError();
7126 Builder.restoreIP(*AfterIP);
7127 ValueReplacementMap[Input] = std::make_tuple(InputCopy, Arg.getArgNo());
7128
7129 // In certain cases a Global may be set up for replacement; however, this
7130 // Global may be used in multiple arguments to the kernel, just segmented
7131 // apart. For example, if we have a global array that is sectioned into
7132 // multiple mappings (technically not legal in OpenMP, but there is a case
7133 // in Fortran for Common Blocks where this is necessary), we will end up
7134 // with GEPs into this array inside the kernel that refer to the Global
7135 // but are technically separate arguments to the kernel for all intents and
7136 // purposes. If we have mapped a segment that requires a GEP into the 0-th
7137 // index, it will fold into a reference to the Global. If we then encounter
7138 // this folded GEP during replacement, all of the references to the
7139 // Global in the kernel will be replaced with the argument we have generated
7140 // that corresponds to it, including any other GEPs that refer to the
7141 // Global and may be other arguments. This would invalidate all of the
7142 // preceding mapped arguments that refer to the same global but are
7143 // separate segments. To prevent this, we defer global processing until all
7144 // other processing has been performed.
7145 if (isa<GlobalValue>(Input)) {
7146 DeferredReplacement.push_back(std::make_pair(Input, InputCopy));
7147 continue;
7148 }
7149
7150 if (isa<ConstantData>(Input))
7151 continue;
7152
7153 ReplaceValue(Input, InputCopy, Func);
7154 }
7155
7156 // Replace all of our deferred Input values, currently just Globals.
7157 for (auto Deferred : DeferredReplacement)
7158 ReplaceValue(std::get<0>(Deferred), std::get<1>(Deferred), Func);
7159
7160 FixupDebugInfoForOutlinedFunction(OMPBuilder, Builder, Func,
7161 ValueReplacementMap);
7162 return Func;
7163 }
7164 /// Given a task descriptor, TaskWithPrivates, return the pointer to the block
7165 /// of pointers containing shared data between the parent task and the created
7166 /// task.
loadSharedDataFromTaskDescriptor(OpenMPIRBuilder & OMPIRBuilder,IRBuilderBase & Builder,Value * TaskWithPrivates,Type * TaskWithPrivatesTy)7167 static LoadInst *loadSharedDataFromTaskDescriptor(OpenMPIRBuilder &OMPIRBuilder,
7168 IRBuilderBase &Builder,
7169 Value *TaskWithPrivates,
7170 Type *TaskWithPrivatesTy) {
7171
7172 Type *TaskTy = OMPIRBuilder.Task;
7173 LLVMContext &Ctx = Builder.getContext();
7174 Value *TaskT =
7175 Builder.CreateStructGEP(TaskWithPrivatesTy, TaskWithPrivates, 0);
7176 Value *Shareds = TaskT;
7177 // TaskWithPrivatesTy can be one of the following
7178 // 1. %struct.task_with_privates = type { %struct.kmp_task_ompbuilder_t,
7179 // %struct.privates }
7180 // 2. %struct.kmp_task_ompbuilder_t ;; This is simply TaskTy
7181 //
7182 // In the former case, that is when TaskWithPrivatesTy != TaskTy,
7183 // its first member has to be the task descriptor. TaskTy is the type of the
7184 // task descriptor. TaskT is the pointer to the task descriptor. Loading the
7185 // first member of TaskT gives us the pointer to shared data.
7186 if (TaskWithPrivatesTy != TaskTy)
7187 Shareds = Builder.CreateStructGEP(TaskTy, TaskT, 0);
7188 return Builder.CreateLoad(PointerType::getUnqual(Ctx), Shareds);
7189 }
7190 /// Create an entry point for a target task. It'll have the following
7191 /// signature:
7192 ///   void @.omp_target_task_proxy_func(i32 %thread.id, ptr %task)
7193 /// This function is called from emitTargetTask once the
7194 /// code to launch the target kernel has been outlined already.
7195 /// NumOffloadingArrays is the number of offloading arrays that we need to copy
7196 /// into the task structure so that the deferred target task can access this
7197 /// data even after the stack frame of the generating task has been rolled
7198 /// back. Offloading arrays contain base pointers, pointers, sizes etc
7199 /// of the data that the target kernel will access. These in effect are the
7200 /// non-empty arrays of pointers held by OpenMPIRBuilder::TargetDataRTArgs.
7201 static Function *emitTargetTaskProxyFunction(
7202 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, CallInst *StaleCI,
7203 StructType *PrivatesTy, StructType *TaskWithPrivatesTy,
7204 const size_t NumOffloadingArrays, const int SharedArgsOperandNo) {
7205
7206 // If NumOffloadingArrays is non-zero, PrivatesTy better not be nullptr.
7207 // This is because PrivatesTy is the type of the structure in which
7208 // we pass the offloading arrays to the deferred target task.
7209 assert((!NumOffloadingArrays || PrivatesTy) &&
7210 "PrivatesTy cannot be nullptr when there are offloadingArrays"
7211 "to privatize");
7212
7213 Module &M = OMPBuilder.M;
7214 // KernelLaunchFunction is the target launch function, i.e.
7215 // the function that sets up kernel arguments and calls
7216 // __tgt_target_kernel to launch the kernel on the device.
7217 //
7218 Function *KernelLaunchFunction = StaleCI->getCalledFunction();
7219
7220 // StaleCI is the CallInst which is the call to the outlined
7221 // target kernel launch function. If there are local live-in values
7222 // that the outlined function uses then these are aggregated into a structure
7223 // which is passed as the second argument. If there are no local live-in
7224 // values or if all values used by the outlined kernel are global variables,
7225 // then there's only one argument, the threadID. So, StaleCI can be
7226 //
7227 // %structArg = alloca { ptr, ptr }, align 8
7228 // %gep_ = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 0
7229 // store ptr %20, ptr %gep_, align 8
7230 // %gep_8 = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 1
7231 // store ptr %21, ptr %gep_8, align 8
7232 // call void @_QQmain..omp_par.1(i32 %global.tid.val6, ptr %structArg)
7233 //
7234 // OR
7235 //
7236 // call void @_QQmain..omp_par.1(i32 %global.tid.val6)
7237 OpenMPIRBuilder::InsertPointTy IP(StaleCI->getParent(),
7238 StaleCI->getIterator());
7239
7240 LLVMContext &Ctx = StaleCI->getParent()->getContext();
7241
7242 Type *ThreadIDTy = Type::getInt32Ty(Ctx);
7243 Type *TaskPtrTy = OMPBuilder.TaskPtr;
7244 [[maybe_unused]] Type *TaskTy = OMPBuilder.Task;
7245
7246 auto ProxyFnTy =
7247 FunctionType::get(Builder.getVoidTy(), {ThreadIDTy, TaskPtrTy},
7248 /* isVarArg */ false);
7249 auto ProxyFn = Function::Create(ProxyFnTy, GlobalValue::InternalLinkage,
7250 ".omp_target_task_proxy_func",
7251 Builder.GetInsertBlock()->getModule());
7252 Value *ThreadId = ProxyFn->getArg(0);
7253 Value *TaskWithPrivates = ProxyFn->getArg(1);
7254 ThreadId->setName("thread.id");
7255 TaskWithPrivates->setName("task");
7256
7257 bool HasShareds = SharedArgsOperandNo > 0;
7258 bool HasOffloadingArrays = NumOffloadingArrays > 0;
7259 BasicBlock *EntryBB =
7260 BasicBlock::Create(Builder.getContext(), "entry", ProxyFn);
7261 Builder.SetInsertPoint(EntryBB);
7262
7263 SmallVector<Value *> KernelLaunchArgs;
7264 KernelLaunchArgs.reserve(StaleCI->arg_size());
7265 KernelLaunchArgs.push_back(ThreadId);
7266
7267 if (HasOffloadingArrays) {
7268 assert(TaskTy != TaskWithPrivatesTy &&
7269 "If there are offloading arrays to pass to the target"
7270 "TaskTy cannot be the same as TaskWithPrivatesTy");
7271 (void)TaskTy;
7272 Value *Privates =
7273 Builder.CreateStructGEP(TaskWithPrivatesTy, TaskWithPrivates, 1);
7274 for (unsigned int i = 0; i < NumOffloadingArrays; ++i)
7275 KernelLaunchArgs.push_back(
7276 Builder.CreateStructGEP(PrivatesTy, Privates, i));
7277 }
7278
7279 if (HasShareds) {
7280 auto *ArgStructAlloca =
7281 dyn_cast<AllocaInst>(StaleCI->getArgOperand(SharedArgsOperandNo));
7282 assert(ArgStructAlloca &&
7283 "Unable to find the alloca instruction corresponding to arguments "
7284 "for extracted function");
7285 auto *ArgStructType = cast<StructType>(ArgStructAlloca->getAllocatedType());
7286
7287 AllocaInst *NewArgStructAlloca =
7288 Builder.CreateAlloca(ArgStructType, nullptr, "structArg");
7289
7290 Value *SharedsSize =
7291 Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
7292
7293 LoadInst *LoadShared = loadSharedDataFromTaskDescriptor(
7294 OMPBuilder, Builder, TaskWithPrivates, TaskWithPrivatesTy);
7295
7296 Builder.CreateMemCpy(
7297 NewArgStructAlloca, NewArgStructAlloca->getAlign(), LoadShared,
7298 LoadShared->getPointerAlignment(M.getDataLayout()), SharedsSize);
7299 KernelLaunchArgs.push_back(NewArgStructAlloca);
7300 }
7301 Builder.CreateCall(KernelLaunchFunction, KernelLaunchArgs);
7302 Builder.CreateRetVoid();
7303 return ProxyFn;
7304 }
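// Return the array type behind a pointer to an offloading array, which is
// expected to be either a GEP into, or an alloca of, that array.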
7305 static Type *getOffloadingArrayType(Value *V) {
7306
7307 if (auto *GEP = dyn_cast<GetElementPtrInst>(V))
7308 return GEP->getSourceElementType();
7309 if (auto *Alloca = dyn_cast<AllocaInst>(V))
7310 return Alloca->getAllocatedType();
7311
7312 llvm_unreachable("Unhandled Instruction type");
7313 return nullptr;
7314 }
7315 // This function returns a struct that has at most two members.
7316 // The first member is always %struct.kmp_task_ompbuilder_t, that is the task
7317 // descriptor. The second member, if needed, is a struct containing arrays
7318 // that need to be passed to the offloaded target kernel. For example,
7319 // if .offload_baseptrs, .offload_ptrs and .offload_sizes have to be passed to
7320 // the target kernel and their types are [3 x ptr], [3 x ptr] and [3 x i64]
7321 // respectively, then the types created by this function are
7322 //
7323 // %struct.privates = type { [3 x ptr], [3 x ptr], [3 x i64] }
7324 // %struct.task_with_privates = type { %struct.kmp_task_ompbuilder_t,
7325 // %struct.privates }
7326 // %struct.task_with_privates is returned by this function.
7327 // If there aren't any offloading arrays to pass to the target kernel,
7328 // %struct.kmp_task_ompbuilder_t is returned.
7329 static StructType *
7330 createTaskWithPrivatesTy(OpenMPIRBuilder &OMPIRBuilder,
7331 ArrayRef<Value *> OffloadingArraysToPrivatize) {
7332
7333 if (OffloadingArraysToPrivatize.empty())
7334 return OMPIRBuilder.Task;
7335
7336 SmallVector<Type *, 4> StructFieldTypes;
7337 for (Value *V : OffloadingArraysToPrivatize) {
7338 assert(V->getType()->isPointerTy() &&
7339 "Expected pointer to array to privatize. Got a non-pointer value "
7340 "instead");
7341 Type *ArrayTy = getOffloadingArrayType(V);
7342 assert(ArrayTy && "ArrayType cannot be nullptr");
7343 StructFieldTypes.push_back(ArrayTy);
7344 }
7345 StructType *PrivatesStructTy =
7346 StructType::create(StructFieldTypes, "struct.privates");
7347 return StructType::create({OMPIRBuilder.Task, PrivatesStructTy},
7348 "struct.task_with_privates");
7349 }
7350 static Error emitTargetOutlinedFunction(
7351 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
7352 TargetRegionEntryInfo &EntryInfo,
7353 const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7354 Function *&OutlinedFn, Constant *&OutlinedFnID,
7355 SmallVectorImpl<Value *> &Inputs,
7356 OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
7357 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
7358
7359 OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
7360 [&](StringRef EntryFnName) {
7361 return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
7362 EntryFnName, Inputs, CBFunc,
7363 ArgAccessorFuncCB);
7364 };
7365
7366 return OMPBuilder.emitTargetRegionFunction(
7367 EntryInfo, GenerateOutlinedFunction, IsOffloadEntry, OutlinedFn,
7368 OutlinedFnID);
7369 }
7370
7371 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
7372 TargetTaskBodyCallbackTy TaskBodyCB, Value *DeviceID, Value *RTLoc,
7373 OpenMPIRBuilder::InsertPointTy AllocaIP,
7374 const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
7375 const TargetDataRTArgs &RTArgs, bool HasNoWait) {
7376
7377 // The following explains the code-gen scenario for the `target` directive. A
7378 // similar scenario is followed for other device-related directives (e.g.
7379 // `target enter data`), but in a simpler fashion, since we only need to emit
7380 // a task that encapsulates the proper runtime call.
7381 //
7382 // When we arrive at this function, the target region itself has been
7383 // outlined into the function OutlinedFn.
7384 // So at this point, for
7385 // --------------------------------------------------------------
7386 // void user_code_that_offloads(...) {
7387 // omp target depend(..) map(from:a) map(to:b) private(i)
7388 // do i = 1, 10
7389 // a(i) = b(i) + n
7390 // }
7391 //
7392 // --------------------------------------------------------------
7393 //
7394 // we have
7395 //
7396 // --------------------------------------------------------------
7397 //
7398 // void user_code_that_offloads(...) {
7399 // %.offload_baseptrs = alloca [2 x ptr], align 8
7400 // %.offload_ptrs = alloca [2 x ptr], align 8
7401 // %.offload_mappers = alloca [2 x ptr], align 8
7402 // ;; target region has been outlined and now we need to
7403 // ;; offload to it via a target task.
7404 // }
7405 // void outlined_device_function(ptr a, ptr b, ptr n_ptr) {
7406 // n = *n_ptr;
7407 // do i = 1, 10
7408 // a(i) = b(i) + n
7409 // }
7410 //
7411 // We have to now do the following
7412 // (i) Make an offloading call to outlined_device_function using the OpenMP
7413 // RTL. See 'kernel_launch_function' in the pseudo code below. This is
7414 // emitted by emitKernelLaunch
7415 // (ii) Create a task entry point function that calls kernel_launch_function
7416 // and is the entry point for the target task. See
7417 // '@.omp_target_task_proxy_func in the pseudocode below.
7418 // (iii) Create a task with the task entry point created in (ii)
7419 //
7420 // That is we create the following
7421 // struct task_with_privates {
7422 // struct kmp_task_ompbuilder_t task_struct;
7423 // struct privates {
7424 // [2 x ptr] ; baseptrs
7425 // [2 x ptr] ; ptrs
7426 // [2 x i64] ; sizes
7427 // }
7428 // }
7429 // void user_code_that_offloads(...) {
7430 // %.offload_baseptrs = alloca [2 x ptr], align 8
7431 // %.offload_ptrs = alloca [2 x ptr], align 8
7432 // %.offload_sizes = alloca [2 x i64], align 8
7433 //
7434 // %structArg = alloca { ptr, ptr, ptr }, align 8
7435 // %structArg[0] = a
7436 // %structArg[1] = b
7437 // %structArg[2] = &n
7438 //
7439 // target_task_with_privates = @__kmpc_omp_target_task_alloc(...,
7440 // sizeof(kmp_task_ompbuilder_t),
7441 // sizeof(structArg),
7442 // @.omp_target_task_proxy_func,
7443 // ...)
7444 // memcpy(target_task_with_privates->task_struct->shareds, %structArg,
7445 // sizeof(structArg))
7446 // memcpy(target_task_with_privates->privates->baseptrs,
7447 // offload_baseptrs, sizeof(offload_baseptrs))
7448 // memcpy(target_task_with_privates->privates->ptrs,
7449 // offload_ptrs, sizeof(offload_ptrs))
7450 // memcpy(target_task_with_privates->privates->sizes,
7451 // offload_sizes, sizeof(offload_sizes))
7452 // dependencies_array = ...
7453 // ;; if nowait not present
7454 // call @__kmpc_omp_wait_deps(..., dependencies_array)
7455 // call @__kmpc_omp_task_begin_if0(...)
7456 // call @.omp_target_task_proxy_func(i32 thread_id, ptr
7457 // %target_task_with_privates)
7458 // call @__kmpc_omp_task_complete_if0(...)
7459 // }
7460 //
7461 // define internal void @.omp_target_task_proxy_func(i32 %thread.id,
7462 // ptr %task) {
7463 // %structArg = alloca {ptr, ptr, ptr}
7464 // %task_ptr = getelementptr(%task, 0, 0)
7465 // %shared_data = load (getelementptr %task_ptr, 0, 0)
7466 // memcpy(%structArg, %shared_data, sizeof(%structArg))
7467 //
7468 // %offloading_arrays = getelementptr(%task, 0, 1)
7469 // %offload_baseptrs = getelementptr(%offloading_arrays, 0, 0)
7470 // %offload_ptrs = getelementptr(%offloading_arrays, 0, 1)
7471 // %offload_sizes = getelementptr(%offloading_arrays, 0, 2)
7472 // kernel_launch_function(%thread.id, %offload_baseptrs, %offload_ptrs,
7473 // %offload_sizes, %structArg)
7474 // }
7475 //
7476 // We need the proxy function because the signature of the task entry point
7477 // expected by kmpc_omp_task is always the same and will be different from
7478 // that of the kernel_launch function.
7479 //
7480 // kernel_launch_function is generated by emitKernelLaunch and has the
7481 // always_inline attribute. For this example, it'll look like so:
7482 // void kernel_launch_function(%thread_id, %offload_baseptrs, %offload_ptrs,
7483 // %offload_sizes, %structArg) alwaysinline {
7484 // %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
7485 // ; load aggregated data from %structArg
7486 // ; setup kernel_args using offload_baseptrs, offload_ptrs and
7487 // ; offload_sizes
7488 // call i32 @__tgt_target_kernel(...,
7489 // outlined_device_function,
7490 // ptr %kernel_args)
7491 // }
7492 // void outlined_device_function(ptr a, ptr b, ptr n_ptr) {
7493 // n = *n_ptr;
7494 // do i = 1, 10
7495 // a(i) = b(i) + n
7496 // }
7497 //
7498 BasicBlock *TargetTaskBodyBB =
7499 splitBB(Builder, /*CreateBranch=*/true, "target.task.body");
7500 BasicBlock *TargetTaskAllocaBB =
7501 splitBB(Builder, /*CreateBranch=*/true, "target.task.alloca");
7502
7503 InsertPointTy TargetTaskAllocaIP(TargetTaskAllocaBB,
7504 TargetTaskAllocaBB->begin());
7505 InsertPointTy TargetTaskBodyIP(TargetTaskBodyBB, TargetTaskBodyBB->begin());
7506
7507 OutlineInfo OI;
7508 OI.EntryBB = TargetTaskAllocaBB;
7509 OI.OuterAllocaBB = AllocaIP.getBlock();
7510
7511 // Add the thread ID argument.
7512 SmallVector<Instruction *, 4> ToBeDeleted;
7513 OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
7514 Builder, AllocaIP, ToBeDeleted, TargetTaskAllocaIP, "global.tid", false));
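// The fake value stands in for the thread id so that the outliner passes it
// as a separate argument instead of packing it into the shared-argument
// struct; its placeholder instructions are erased later via ToBeDeleted.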
7515
7516 // Generate the task body which will subsequently be outlined.
7517 Builder.restoreIP(TargetTaskBodyIP);
7518 if (Error Err = TaskBodyCB(DeviceID, RTLoc, TargetTaskAllocaIP))
7519 return Err;
7520
7521 // The outliner (CodeExtractor) extracts a sequence or vector of blocks that
7522 // it is given. These blocks are enumerated by
7523 // OpenMPIRBuilder::OutlineInfo::collectBlocks which expects the OI.ExitBlock
7524 // to be outside the region. In other words, OI.ExitBlock is expected to be
7525 // the start of the region after the outlining. We used to set OI.ExitBlock
7526 // to the InsertBlock after TaskBodyCB is done. This is fine in most cases
7527 // except when the task body is a single basic block. In that case,
7528 // OI.ExitBlock is set to the single task body block and will get left out of
7529 // the outlining process. So, simply create a new empty block to which we
7530 // unconditionally branch from where TaskBodyCB left off.
7531 OI.ExitBB = BasicBlock::Create(Builder.getContext(), "target.task.cont");
7532 emitBlock(OI.ExitBB, Builder.GetInsertBlock()->getParent(),
7533 /*IsFinished=*/true);
7534
7535 SmallVector<Value *, 2> OffloadingArraysToPrivatize;
7536 bool NeedsTargetTask = HasNoWait && DeviceID;
7537 if (NeedsTargetTask) {
7538 for (auto *V :
7539 {RTArgs.BasePointersArray, RTArgs.PointersArray, RTArgs.MappersArray,
7540 RTArgs.MapNamesArray, RTArgs.MapTypesArray, RTArgs.MapTypesArrayEnd,
7541 RTArgs.SizesArray}) {
7542 if (V && !isa<ConstantPointerNull, GlobalVariable>(V)) {
7543 OffloadingArraysToPrivatize.push_back(V);
7544 OI.ExcludeArgsFromAggregate.push_back(V);
7545 }
7546 }
7547 }
7548 OI.PostOutlineCB = [this, ToBeDeleted, Dependencies, NeedsTargetTask,
7549 DeviceID, OffloadingArraysToPrivatize](
7550 Function &OutlinedFn) mutable {
7551 assert(OutlinedFn.hasOneUse() &&
7552 "there must be a single user for the outlined function");
7553
7554 CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
7555
7556 // The first argument of StaleCI is always the thread id.
7557 // The next few arguments are the pointers to offloading arrays
7558 // if any. (see OffloadingArraysToPrivatize)
7559 // Finally, all other local values that are live-in into the outlined region
7560 // end up in a structure whose pointer is passed as the last argument. This
7561 // piece of data is passed in the "shared" field of the task structure. So,
7562 // we know we have to pass shareds to the task if the number of arguments is
7563 // greater than OffloadingArraysToPrivatize.size() + 1 The 1 is for the
7564 // thread id. Further, for safety, we assert that the number of arguments of
7565 // StaleCI is exactly OffloadingArraysToPrivatize.size() + 2
7566 const unsigned int NumStaleCIArgs = StaleCI->arg_size();
7567 bool HasShareds = NumStaleCIArgs > OffloadingArraysToPrivatize.size() + 1;
7568 assert((!HasShareds ||
7569 NumStaleCIArgs == (OffloadingArraysToPrivatize.size() + 2)) &&
7570 "Wrong number of arguments for StaleCI when shareds are present");
7571 int SharedArgOperandNo =
7572 HasShareds ? OffloadingArraysToPrivatize.size() + 1 : 0;
7573
7574 StructType *TaskWithPrivatesTy =
7575 createTaskWithPrivatesTy(*this, OffloadingArraysToPrivatize);
7576 StructType *PrivatesTy = nullptr;
7577
7578 if (!OffloadingArraysToPrivatize.empty())
7579 PrivatesTy =
7580 static_cast<StructType *>(TaskWithPrivatesTy->getElementType(1));
7581
7582 Function *ProxyFn = emitTargetTaskProxyFunction(
7583 *this, Builder, StaleCI, PrivatesTy, TaskWithPrivatesTy,
7584 OffloadingArraysToPrivatize.size(), SharedArgOperandNo);
7585
7586 LLVM_DEBUG(dbgs() << "Proxy task entry function created: " << *ProxyFn
7587 << "\n");
7588
7589 Builder.SetInsertPoint(StaleCI);
7590
7591 // Gather the arguments for emitting the runtime call.
7592 uint32_t SrcLocStrSize;
7593 Constant *SrcLocStr =
7594 getOrCreateSrcLocStr(LocationDescription(Builder), SrcLocStrSize);
7595 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
7596
7597 // @__kmpc_omp_task_alloc or @__kmpc_omp_target_task_alloc
7598 //
7599 // If `HasNoWait == true`, we call @__kmpc_omp_target_task_alloc to provide
7600 // the DeviceID to the deferred task, and also because
7601 // @__kmpc_omp_target_task_alloc creates an untied/async task.
7602 Function *TaskAllocFn =
7603 !NeedsTargetTask
7604 ? getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
7605 : getOrCreateRuntimeFunctionPtr(
7606 OMPRTL___kmpc_omp_target_task_alloc);
7607
7608 // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID) for the task alloc
7609 // call.
7610 Value *ThreadID = getOrCreateThreadID(Ident);
7611
7612 // Argument - `sizeof_kmp_task_t` (TaskSize)
7613 // TaskSize refers to the size in bytes of the kmp_task_t data structure
7614 // plus any other data to be passed to the target task, if any, which
7615 // is packed into a struct. kmp_task_t and the struct so created are
7616 // packed into a wrapper struct whose type is TaskWithPrivatesTy.
7617 Value *TaskSize = Builder.getInt64(
7618 M.getDataLayout().getTypeStoreSize(TaskWithPrivatesTy));
7619
7620 // Argument - `sizeof_shareds` (SharedsSize)
7621 // SharedsSize refers to the shareds array size in the kmp_task_t data
7622 // structure.
7623 Value *SharedsSize = Builder.getInt64(0);
7624 if (HasShareds) {
7625 auto *ArgStructAlloca =
7626 dyn_cast<AllocaInst>(StaleCI->getArgOperand(SharedArgOperandNo));
7627 assert(ArgStructAlloca &&
7628 "Unable to find the alloca instruction corresponding to arguments "
7629 "for extracted function");
7630 auto *ArgStructType =
7631 dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
7632 assert(ArgStructType && "Unable to find struct type corresponding to "
7633 "arguments for extracted function");
7634 SharedsSize =
7635 Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
7636 }
7637
7638 // Argument - `flags`
7639 // Task is tied iff (Flags & 1) == 1.
7640 // Task is untied iff (Flags & 1) == 0.
7641 // Task is final iff (Flags & 2) == 2.
7642 // Task is not final iff (Flags & 2) == 0.
7643 // A target task is not final and is untied.
7644 Value *Flags = Builder.getInt32(0);
7645
7646 // Emit the @__kmpc_omp_task_alloc runtime call
7647 // The runtime call returns a pointer to an area where the task captured
7648 // variables must be copied before the task is run (TaskData)
7649 CallInst *TaskData = nullptr;
7650
7651 SmallVector<llvm::Value *> TaskAllocArgs = {
7652 /*loc_ref=*/Ident, /*gtid=*/ThreadID,
7653 /*flags=*/Flags,
7654 /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
7655 /*task_func=*/ProxyFn};
7656
7657 if (NeedsTargetTask) {
7658 assert(DeviceID && "Expected non-empty device ID.");
7659 TaskAllocArgs.push_back(DeviceID);
7660 }
7661
7662 TaskData = Builder.CreateCall(TaskAllocFn, TaskAllocArgs);
7663
7664 Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
7665 if (HasShareds) {
7666 Value *Shareds = StaleCI->getArgOperand(SharedArgOperandNo);
7667 Value *TaskShareds = loadSharedDataFromTaskDescriptor(
7668 *this, Builder, TaskData, TaskWithPrivatesTy);
7669 Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
7670 SharedsSize);
7671 }
7672 if (!OffloadingArraysToPrivatize.empty()) {
7673 Value *Privates =
7674 Builder.CreateStructGEP(TaskWithPrivatesTy, TaskData, 1);
7675 for (unsigned int i = 0; i < OffloadingArraysToPrivatize.size(); ++i) {
7676 Value *PtrToPrivatize = OffloadingArraysToPrivatize[i];
7677 [[maybe_unused]] Type *ArrayType =
7678 getOffloadingArrayType(PtrToPrivatize);
7679 assert(ArrayType && "ArrayType cannot be nullptr");
7680
7681 Type *ElementType = PrivatesTy->getElementType(i);
7682 assert(ElementType == ArrayType &&
7683 "ElementType should match ArrayType");
7684 (void)ArrayType;
7685
7686 Value *Dst = Builder.CreateStructGEP(PrivatesTy, Privates, i);
7687 Builder.CreateMemCpy(
7688 Dst, Alignment, PtrToPrivatize, Alignment,
7689 Builder.getInt64(M.getDataLayout().getTypeStoreSize(ElementType)));
7690 }
7691 }
7692
7693 Value *DepArray = emitTaskDependencies(*this, Dependencies);
7694
7695 // ---------------------------------------------------------------
7696 // V5.2 13.8 target construct
7697 // If the nowait clause is present, execution of the target task
7698 // may be deferred. If the nowait clause is not present, the target task is
7699 // an included task.
7700 // ---------------------------------------------------------------
7701 // The above means that the lack of a nowait on the target construct
7702 // translates to '#pragma omp task if(0)'
7703 if (!NeedsTargetTask) {
7704 if (DepArray) {
7705 Function *TaskWaitFn =
7706 getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_wait_deps);
7707 Builder.CreateCall(
7708 TaskWaitFn,
7709 {/*loc_ref=*/Ident, /*gtid=*/ThreadID,
7710 /*ndeps=*/Builder.getInt32(Dependencies.size()),
7711 /*dep_list=*/DepArray,
7712 /*ndeps_noalias=*/ConstantInt::get(Builder.getInt32Ty(), 0),
7713 /*noalias_dep_list=*/
7714 ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
7715 }
7716 // Included task.
7717 Function *TaskBeginFn =
7718 getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0);
7719 Function *TaskCompleteFn =
7720 getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
7721 Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
7722 CallInst *CI = Builder.CreateCall(ProxyFn, {ThreadID, TaskData});
7723 CI->setDebugLoc(StaleCI->getDebugLoc());
7724 Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
7725 } else if (DepArray) {
7726 // HasNoWait - meaning the task may be deferred. Call
7727 // __kmpc_omp_task_with_deps if there are dependencies,
7728 // else call __kmpc_omp_task
7729 Function *TaskFn =
7730 getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
7731 Builder.CreateCall(
7732 TaskFn,
7733 {Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()),
7734 DepArray, ConstantInt::get(Builder.getInt32Ty(), 0),
7735 ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
7736 } else {
7737 // Emit the @__kmpc_omp_task runtime call to spawn the task
7738 Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
7739 Builder.CreateCall(TaskFn, {Ident, ThreadID, TaskData});
7740 }
7741
7742 StaleCI->eraseFromParent();
7743 for (Instruction *I : llvm::reverse(ToBeDeleted))
7744 I->eraseFromParent();
7745 };
7746 addOutlineInfo(std::move(OI));
7747
7748 LLVM_DEBUG(dbgs() << "Insert block after emitKernelLaunch = \n"
7749 << *(Builder.GetInsertBlock()) << "\n");
7750 LLVM_DEBUG(dbgs() << "Module after emitKernelLaunch = \n"
7751 << *(Builder.GetInsertBlock()->getParent()->getParent())
7752 << "\n");
7753 return Builder.saveIP();
7754 }
7755
7756 Error OpenMPIRBuilder::emitOffloadingArraysAndArgs(
7757 InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
7758 TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
7759 CustomMapperCallbackTy CustomMapperCB, bool IsNonContiguous,
7760 bool ForEndCall, function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
7761 if (Error Err =
7762 emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info,
7763 CustomMapperCB, IsNonContiguous, DeviceAddrCB))
7764 return Err;
7765 emitOffloadingArraysArgument(Builder, RTArgs, Info, ForEndCall);
7766 return Error::success();
7767 }
7768
7769 static void emitTargetCall(
7770 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7771 OpenMPIRBuilder::InsertPointTy AllocaIP,
7772 OpenMPIRBuilder::TargetDataInfo &Info,
7773 const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7774 const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
7775 Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
7776 SmallVectorImpl<Value *> &Args,
7777 OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7778 OpenMPIRBuilder::CustomMapperCallbackTy CustomMapperCB,
7779 const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
7780 bool HasNoWait) {
7781 // Generate a function call to the host fallback implementation of the target
7782 // region. This is called by the host when no offload entry was generated for
7783 // the target region and when the offloading call fails at runtime.
7784 auto &&EmitTargetCallFallbackCB = [&](OpenMPIRBuilder::InsertPointTy IP)
7785 -> OpenMPIRBuilder::InsertPointOrErrorTy {
7786 Builder.restoreIP(IP);
7787 Builder.CreateCall(OutlinedFn, Args);
7788 return Builder.saveIP();
7789 };
7790
7791 bool HasDependencies = Dependencies.size() > 0;
7792 bool RequiresOuterTargetTask = HasNoWait || HasDependencies;
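// A target task is needed whenever execution may be deferred (nowait) or
// must first wait on dependencies; otherwise the kernel launch is emitted
// inline.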
7793
7794 OpenMPIRBuilder::TargetKernelArgs KArgs;
7795
7796 auto TaskBodyCB =
7797 [&](Value *DeviceID, Value *RTLoc,
7798 IRBuilderBase::InsertPoint TargetTaskAllocaIP) -> Error {
7799 // Assume no error was returned because EmitTargetCallFallbackCB doesn't
7800 // produce any.
7801 llvm::OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
7802 // emitKernelLaunch makes the necessary runtime call to offload the
7803 // kernel. We then outline all that code into a separate function
7804 // ('kernel_launch_function' in the pseudo code above). This function is
7805 // then called by the target task proxy function (see
7806 // '@.omp_target_task_proxy_func' in the pseudo code above)
7807       // '@.omp_target_task_proxy_func' is generated by
7808 // emitTargetTaskProxyFunction.
7809 if (OutlinedFnID && DeviceID)
7810 return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
7811 EmitTargetCallFallbackCB, KArgs,
7812 DeviceID, RTLoc, TargetTaskAllocaIP);
7813
7814       // We only need to do the outlining if `DeviceID` is set; this avoids
7815       // calling `emitKernelLaunch` when we want to generate code for the
7816       // host, e.g. when generating the `else` branch of an `if` clause.
7817 //
7818 // When OutlinedFnID is set to nullptr, then it's not an offloading call.
7819 // In this case, we execute the host implementation directly.
7820 return EmitTargetCallFallbackCB(OMPBuilder.Builder.saveIP());
7821 }());
7822
7823 OMPBuilder.Builder.restoreIP(AfterIP);
7824 return Error::success();
7825 };
7826
7827 auto &&EmitTargetCallElse =
7828 [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
7829 OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
7830 // Assume no error was returned because EmitTargetCallFallbackCB doesn't
7831 // produce any.
7832 OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
7833 if (RequiresOuterTargetTask) {
7834 // Arguments that are intended to be directly forwarded to an
7835         // emitKernelLaunch call are passed as nullptr, since
7836 // OutlinedFnID=nullptr results in that call not being done.
7837 OpenMPIRBuilder::TargetDataRTArgs EmptyRTArgs;
7838 return OMPBuilder.emitTargetTask(TaskBodyCB, /*DeviceID=*/nullptr,
7839 /*RTLoc=*/nullptr, AllocaIP,
7840 Dependencies, EmptyRTArgs, HasNoWait);
7841 }
7842 return EmitTargetCallFallbackCB(Builder.saveIP());
7843 }());
7844
7845 Builder.restoreIP(AfterIP);
7846 return Error::success();
7847 };
7848
7849 auto &&EmitTargetCallThen =
7850 [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
7851 OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
7852 Info.HasNoWait = HasNoWait;
7853 OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
7854 OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7855 if (Error Err = OMPBuilder.emitOffloadingArraysAndArgs(
7856 AllocaIP, Builder.saveIP(), Info, RTArgs, MapInfo, CustomMapperCB,
7857 /*IsNonContiguous=*/true,
7858 /*ForEndCall=*/false))
7859 return Err;
7860
7861 SmallVector<Value *, 3> NumTeamsC;
7862 for (auto [DefaultVal, RuntimeVal] :
7863 zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
7864 NumTeamsC.push_back(RuntimeVal ? RuntimeVal
7865 : Builder.getInt32(DefaultVal));
7866
7867 // Calculate number of threads: 0 if no clauses specified, otherwise it is
7868 // the minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
7869 auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7870 if (Clause)
7871 Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
7872 /*isSigned=*/false);
7873 return Clause;
7874 };
7875 auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7876 if (Clause)
7877 Result =
7878 Result ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
7879 Result, Clause)
7880 : Clause;
7881 };
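    // As an illustration: with both a `thread_limit(%tl)` and a
    // `num_threads(%nt)` clause present, the combination below emits roughly
    //   %cmp = icmp ult i32 %nt, %tl
    //   %min = select i1 %cmp, i32 %nt, i32 %tl
    // i.e. the minimum of the two clause values.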
7882
7883 // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
7884     // the NUM_THREADS clause is overridden by THREAD_LIMIT.
7885 SmallVector<Value *, 3> NumThreadsC;
7886 Value *MaxThreadsClause =
7887 RuntimeAttrs.TeamsThreadLimit.size() == 1
7888 ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
7889 : nullptr;
7890
7891 for (auto [TeamsVal, TargetVal] : zip_equal(
7892 RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) {
7893 Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
7894 Value *NumThreads = InitMaxThreadsClause(TargetVal);
7895
7896 CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
7897 CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
7898
7899 NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
7900 }
7901
7902 unsigned NumTargetItems = Info.NumberOfPtrs;
7903 // TODO: Use correct device ID
7904 Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
7905 uint32_t SrcLocStrSize;
7906 Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
7907 Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
7908 llvm::omp::IdentFlag(0), 0);
7909
7910 Value *TripCount = RuntimeAttrs.LoopTripCount
7911 ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
7912 Builder.getInt64Ty(),
7913 /*isSigned=*/false)
7914 : Builder.getInt64(0);
7915
7916 // TODO: Use correct DynCGGroupMem
7917 Value *DynCGGroupMem = Builder.getInt32(0);
7918
7919 KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
7920 NumTeamsC, NumThreadsC,
7921 DynCGGroupMem, HasNoWait);
7922
7923 // Assume no error was returned because TaskBodyCB and
7924 // EmitTargetCallFallbackCB don't produce any.
7925 OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
7926 // The presence of certain clauses on the target directive require the
7927 // explicit generation of the target task.
7928 if (RequiresOuterTargetTask)
7929 return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
7930 Dependencies, KArgs.RTArgs,
7931 Info.HasNoWait);
7932
7933 return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
7934 EmitTargetCallFallbackCB, KArgs,
7935 DeviceID, RTLoc, AllocaIP);
7936 }());
7937
7938 Builder.restoreIP(AfterIP);
7939 return Error::success();
7940 };
7941
7942 // If we don't have an ID for the target region, it means an offload entry
7943 // wasn't created. In this case we just run the host fallback directly and
7944 // ignore any potential 'if' clauses.
7945 if (!OutlinedFnID) {
7946 cantFail(EmitTargetCallElse(AllocaIP, Builder.saveIP()));
7947 return;
7948 }
7949
7950 // If there's no 'if' clause, only generate the kernel launch code path.
7951 if (!IfCond) {
7952 cantFail(EmitTargetCallThen(AllocaIP, Builder.saveIP()));
7953 return;
7954 }
7955
7956 cantFail(OMPBuilder.emitIfClause(IfCond, EmitTargetCallThen,
7957 EmitTargetCallElse, AllocaIP));
7958 }
7959
7960 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
7961 const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
7962 InsertPointTy CodeGenIP, TargetDataInfo &Info,
7963 TargetRegionEntryInfo &EntryInfo,
7964 const TargetKernelDefaultAttrs &DefaultAttrs,
7965 const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
7966 SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
7967 OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
7968 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
7969 CustomMapperCallbackTy CustomMapperCB,
7970 const SmallVector<DependData> &Dependencies, bool HasNowait) {
7971
7972 if (!updateToLocation(Loc))
7973 return InsertPointTy();
7974
7975 Builder.restoreIP(CodeGenIP);
7976
7977 Function *OutlinedFn;
7978 Constant *OutlinedFnID = nullptr;
7979 // The target region is outlined into its own function. The LLVM IR for
7980 // the target region itself is generated using the callbacks CBFunc
7981 // and ArgAccessorFuncCB
7982 if (Error Err = emitTargetOutlinedFunction(
7983 *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
7984 OutlinedFnID, Inputs, CBFunc, ArgAccessorFuncCB))
7985 return Err;
7986
7987 // If we are not on the target device, then we need to generate code
7988 // to make a remote call (offload) to the previously outlined function
7989 // that represents the target region. Do that now.
7990 if (!Config.isTargetDevice())
7991 emitTargetCall(*this, Builder, AllocaIP, Info, DefaultAttrs, RuntimeAttrs,
7992 IfCond, OutlinedFn, OutlinedFnID, Inputs, GenMapInfoCB,
7993 CustomMapperCB, Dependencies, HasNowait);
7994 return Builder.saveIP();
7995 }
7996
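// For illustration: getNameWithSeparators({"a", "b"}, "_", "$") returns
// "_a$b"; the first separator precedes the first part and the regular
// separator joins the remaining parts.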
7997 std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
7998 StringRef FirstSeparator,
7999 StringRef Separator) {
8000 SmallString<128> Buffer;
8001 llvm::raw_svector_ostream OS(Buffer);
8002 StringRef Sep = FirstSeparator;
8003 for (StringRef Part : Parts) {
8004 OS << Sep << Part;
8005 Sep = Separator;
8006 }
8007 return OS.str().str();
8008 }
8009
8010 std::string
8011 OpenMPIRBuilder::createPlatformSpecificName(ArrayRef<StringRef> Parts) const {
8012 return OpenMPIRBuilder::getNameWithSeparators(Parts, Config.firstSeparator(),
8013 Config.separator());
8014 }
8015
8016 GlobalVariable *
8017 OpenMPIRBuilder::getOrCreateInternalVariable(Type *Ty, const StringRef &Name,
8018 unsigned AddressSpace) {
8019 auto &Elem = *InternalVars.try_emplace(Name, nullptr).first;
8020 if (Elem.second) {
8021 assert(Elem.second->getValueType() == Ty &&
8022 "OMP internal variable has different type than requested");
8023 } else {
8024 // TODO: investigate the appropriate linkage type used for the global
8025 // variable for possibly changing that to internal or private, or maybe
8026 // create different versions of the function for different OMP internal
8027 // variables.
8028 auto Linkage = this->M.getTargetTriple().getArch() == Triple::wasm32
8029 ? GlobalValue::InternalLinkage
8030 : GlobalValue::CommonLinkage;
8031 auto *GV = new GlobalVariable(M, Ty, /*IsConstant=*/false, Linkage,
8032 Constant::getNullValue(Ty), Elem.first(),
8033 /*InsertBefore=*/nullptr,
8034 GlobalValue::NotThreadLocal, AddressSpace);
8035 const DataLayout &DL = M.getDataLayout();
8036 const llvm::Align TypeAlign = DL.getABITypeAlign(Ty);
8037 const llvm::Align PtrAlign = DL.getPointerABIAlignment(AddressSpace);
8038 GV->setAlignment(std::max(TypeAlign, PtrAlign));
8039 Elem.second = GV;
8040 }
8041
8042 return Elem.second;
8043 }
8044
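// For example (illustrative), a `#pragma omp critical (foo)` directive
// yields the lock variable ".gomp_critical_user_foo.var"; an unnamed
// critical would typically be given an empty CriticalName here.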
8045 Value *OpenMPIRBuilder::getOMPCriticalRegionLock(StringRef CriticalName) {
8046 std::string Prefix = Twine("gomp_critical_user_", CriticalName).str();
8047 std::string Name = getNameWithSeparators({Prefix, "var"}, ".", ".");
8048 return getOrCreateInternalVariable(KmpCriticalNameTy, Name);
8049 }
8050
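// This is the classic "sizeof via GEP from null" idiom; with opaque
// pointers the emitted IR is roughly:
//   %gep  = getelementptr ptr, ptr null, i32 1
//   %size = ptrtoint ptr %gep to i64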
8051 Value *OpenMPIRBuilder::getSizeInBytes(Value *BasePtr) {
8052 LLVMContext &Ctx = Builder.getContext();
8053 Value *Null =
8054 Constant::getNullValue(PointerType::getUnqual(BasePtr->getContext()));
8055 Value *SizeGep =
8056 Builder.CreateGEP(BasePtr->getType(), Null, Builder.getInt32(1));
8057 Value *SizePtrToInt = Builder.CreatePtrToInt(SizeGep, Type::getInt64Ty(Ctx));
8058 return SizePtrToInt;
8059 }
8060
8061 GlobalVariable *
8062 OpenMPIRBuilder::createOffloadMaptypes(SmallVectorImpl<uint64_t> &Mappings,
8063 std::string VarName) {
8064 llvm::Constant *MaptypesArrayInit =
8065 llvm::ConstantDataArray::get(M.getContext(), Mappings);
8066 auto *MaptypesArrayGlobal = new llvm::GlobalVariable(
8067 M, MaptypesArrayInit->getType(),
8068 /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MaptypesArrayInit,
8069 VarName);
8070 MaptypesArrayGlobal->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
8071 return MaptypesArrayGlobal;
8072 }
8073
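// A sketch of the allocas created below for NumOperands == N:
//   %.offload_baseptrs = alloca [N x ptr]
//   %.offload_ptrs     = alloca [N x ptr]
//   %.offload_sizes    = alloca [N x i64]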
8074 void OpenMPIRBuilder::createMapperAllocas(const LocationDescription &Loc,
8075 InsertPointTy AllocaIP,
8076 unsigned NumOperands,
8077 struct MapperAllocas &MapperAllocas) {
8078 if (!updateToLocation(Loc))
8079 return;
8080
8081 auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
8082 auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
8083 Builder.restoreIP(AllocaIP);
8084 AllocaInst *ArgsBase = Builder.CreateAlloca(
8085 ArrI8PtrTy, /* ArraySize = */ nullptr, ".offload_baseptrs");
8086 AllocaInst *Args = Builder.CreateAlloca(ArrI8PtrTy, /* ArraySize = */ nullptr,
8087 ".offload_ptrs");
8088 AllocaInst *ArgSizes = Builder.CreateAlloca(
8089 ArrI64Ty, /* ArraySize = */ nullptr, ".offload_sizes");
8090 Builder.restoreIP(Loc.IP);
8091 MapperAllocas.ArgsBase = ArgsBase;
8092 MapperAllocas.Args = Args;
8093 MapperAllocas.ArgSizes = ArgSizes;
8094 }
8095
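// The emitted call uses the standard nine-argument mapper signature, e.g.
// (a sketch, assuming MapperFunc is @__tgt_target_data_begin_mapper):
//   call void @__tgt_target_data_begin_mapper(ptr %loc, i64 %device_id,
//       i32 %num_args, ptr %baseptrs, ptr %ptrs, ptr %sizes, ptr %maptypes,
//       ptr %mapnames, ptr null)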
8096 void OpenMPIRBuilder::emitMapperCall(const LocationDescription &Loc,
8097 Function *MapperFunc, Value *SrcLocInfo,
8098 Value *MaptypesArg, Value *MapnamesArg,
8099 struct MapperAllocas &MapperAllocas,
8100 int64_t DeviceID, unsigned NumOperands) {
8101 if (!updateToLocation(Loc))
8102 return;
8103
8104 auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
8105 auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
8106 Value *ArgsBaseGEP =
8107 Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.ArgsBase,
8108 {Builder.getInt32(0), Builder.getInt32(0)});
8109 Value *ArgsGEP =
8110 Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.Args,
8111 {Builder.getInt32(0), Builder.getInt32(0)});
8112 Value *ArgSizesGEP =
8113 Builder.CreateInBoundsGEP(ArrI64Ty, MapperAllocas.ArgSizes,
8114 {Builder.getInt32(0), Builder.getInt32(0)});
8115 Value *NullPtr =
8116 Constant::getNullValue(PointerType::getUnqual(Int8Ptr->getContext()));
8117 Builder.CreateCall(MapperFunc,
8118 {SrcLocInfo, Builder.getInt64(DeviceID),
8119 Builder.getInt32(NumOperands), ArgsBaseGEP, ArgsGEP,
8120 ArgSizesGEP, MaptypesArg, MapnamesArg, NullPtr});
8121 }
8122
8123 void OpenMPIRBuilder::emitOffloadingArraysArgument(IRBuilderBase &Builder,
8124 TargetDataRTArgs &RTArgs,
8125 TargetDataInfo &Info,
8126 bool ForEndCall) {
8127 assert((!ForEndCall || Info.separateBeginEndCalls()) &&
8128 "expected region end call to runtime only when end call is separate");
8129 auto UnqualPtrTy = PointerType::getUnqual(M.getContext());
8130 auto VoidPtrTy = UnqualPtrTy;
8131 auto VoidPtrPtrTy = UnqualPtrTy;
8132 auto Int64Ty = Type::getInt64Ty(M.getContext());
8133 auto Int64PtrTy = UnqualPtrTy;
8134
8135 if (!Info.NumberOfPtrs) {
8136 RTArgs.BasePointersArray = ConstantPointerNull::get(VoidPtrPtrTy);
8137 RTArgs.PointersArray = ConstantPointerNull::get(VoidPtrPtrTy);
8138 RTArgs.SizesArray = ConstantPointerNull::get(Int64PtrTy);
8139 RTArgs.MapTypesArray = ConstantPointerNull::get(Int64PtrTy);
8140 RTArgs.MapNamesArray = ConstantPointerNull::get(VoidPtrPtrTy);
8141 RTArgs.MappersArray = ConstantPointerNull::get(VoidPtrPtrTy);
8142 return;
8143 }
8144
8145 RTArgs.BasePointersArray = Builder.CreateConstInBoundsGEP2_32(
8146 ArrayType::get(VoidPtrTy, Info.NumberOfPtrs),
8147 Info.RTArgs.BasePointersArray,
8148 /*Idx0=*/0, /*Idx1=*/0);
8149 RTArgs.PointersArray = Builder.CreateConstInBoundsGEP2_32(
8150 ArrayType::get(VoidPtrTy, Info.NumberOfPtrs), Info.RTArgs.PointersArray,
8151 /*Idx0=*/0,
8152 /*Idx1=*/0);
8153 RTArgs.SizesArray = Builder.CreateConstInBoundsGEP2_32(
8154 ArrayType::get(Int64Ty, Info.NumberOfPtrs), Info.RTArgs.SizesArray,
8155 /*Idx0=*/0, /*Idx1=*/0);
8156 RTArgs.MapTypesArray = Builder.CreateConstInBoundsGEP2_32(
8157 ArrayType::get(Int64Ty, Info.NumberOfPtrs),
8158 ForEndCall && Info.RTArgs.MapTypesArrayEnd ? Info.RTArgs.MapTypesArrayEnd
8159 : Info.RTArgs.MapTypesArray,
8160 /*Idx0=*/0,
8161 /*Idx1=*/0);
8162
8163 // Only emit the mapper information arrays if debug information is
8164 // requested.
8165 if (!Info.EmitDebug)
8166 RTArgs.MapNamesArray = ConstantPointerNull::get(VoidPtrPtrTy);
8167 else
8168 RTArgs.MapNamesArray = Builder.CreateConstInBoundsGEP2_32(
8169 ArrayType::get(VoidPtrTy, Info.NumberOfPtrs), Info.RTArgs.MapNamesArray,
8170 /*Idx0=*/0,
8171 /*Idx1=*/0);
8172 // If there is no user-defined mapper, set the mapper array to nullptr to
8173 // avoid an unnecessary data privatization
8174 if (!Info.HasMapper)
8175 RTArgs.MappersArray = ConstantPointerNull::get(VoidPtrPtrTy);
8176 else
8177 RTArgs.MappersArray =
8178 Builder.CreatePointerCast(Info.RTArgs.MappersArray, VoidPtrPtrTy);
8179 }
8180
8181 void OpenMPIRBuilder::emitNonContiguousDescriptor(InsertPointTy AllocaIP,
8182 InsertPointTy CodeGenIP,
8183 MapInfosTy &CombinedInfo,
8184 TargetDataInfo &Info) {
8185 MapInfosTy::StructNonContiguousInfo &NonContigInfo =
8186 CombinedInfo.NonContigInfo;
8187
8188 // Build an array of struct descriptor_dim and then assign it to
8189 // offload_args.
8190 //
8191 // struct descriptor_dim {
8192 // uint64_t offset;
8193 // uint64_t count;
8194 // uint64_t stride
8195 // };
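  // Note that the loop below stores each dimension's {offset, count, stride}
  // triple in reverse index order (see RevIdx below).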
8196 Type *Int64Ty = Builder.getInt64Ty();
8197 StructType *DimTy = StructType::create(
8198 M.getContext(), ArrayRef<Type *>({Int64Ty, Int64Ty, Int64Ty}),
8199 "struct.descriptor_dim");
8200
8201 enum { OffsetFD = 0, CountFD, StrideFD };
8202   // We need two index variables here since the size of "Dims" is the same as
8203   // the size of Components; however, the sizes of offset, count, and stride
8204   // match the number of non-contiguous base declarations.
8205 for (unsigned I = 0, L = 0, E = NonContigInfo.Dims.size(); I < E; ++I) {
8206     // Skip emitting IR if the dimension size is 1 since it cannot be
8207 // non-contiguous.
8208 if (NonContigInfo.Dims[I] == 1)
8209 continue;
8210 Builder.restoreIP(AllocaIP);
8211 ArrayType *ArrayTy = ArrayType::get(DimTy, NonContigInfo.Dims[I]);
8212 AllocaInst *DimsAddr =
8213 Builder.CreateAlloca(ArrayTy, /* ArraySize = */ nullptr, "dims");
8214 Builder.restoreIP(CodeGenIP);
8215 for (unsigned II = 0, EE = NonContigInfo.Dims[I]; II < EE; ++II) {
8216 unsigned RevIdx = EE - II - 1;
8217 Value *DimsLVal = Builder.CreateInBoundsGEP(
8218 DimsAddr->getAllocatedType(), DimsAddr,
8219 {Builder.getInt64(0), Builder.getInt64(II)});
8220 // Offset
8221 Value *OffsetLVal = Builder.CreateStructGEP(DimTy, DimsLVal, OffsetFD);
8222 Builder.CreateAlignedStore(
8223 NonContigInfo.Offsets[L][RevIdx], OffsetLVal,
8224 M.getDataLayout().getPrefTypeAlign(OffsetLVal->getType()));
8225 // Count
8226 Value *CountLVal = Builder.CreateStructGEP(DimTy, DimsLVal, CountFD);
8227 Builder.CreateAlignedStore(
8228 NonContigInfo.Counts[L][RevIdx], CountLVal,
8229 M.getDataLayout().getPrefTypeAlign(CountLVal->getType()));
8230 // Stride
8231 Value *StrideLVal = Builder.CreateStructGEP(DimTy, DimsLVal, StrideFD);
8232 Builder.CreateAlignedStore(
8233 NonContigInfo.Strides[L][RevIdx], StrideLVal,
8234 M.getDataLayout().getPrefTypeAlign(CountLVal->getType()));
8235 }
8236 // args[I] = &dims
8237 Builder.restoreIP(CodeGenIP);
8238 Value *DAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
8239 DimsAddr, Builder.getPtrTy());
8240 Value *P = Builder.CreateConstInBoundsGEP2_32(
8241 ArrayType::get(Builder.getPtrTy(), Info.NumberOfPtrs),
8242 Info.RTArgs.PointersArray, 0, I);
8243 Builder.CreateAlignedStore(
8244 DAddr, P, M.getDataLayout().getPrefTypeAlign(Builder.getPtrTy()));
8245 ++L;
8246 }
8247 }
8248
8249 void OpenMPIRBuilder::emitUDMapperArrayInitOrDel(
8250 Function *MapperFn, Value *MapperHandle, Value *Base, Value *Begin,
8251 Value *Size, Value *MapType, Value *MapName, TypeSize ElementSize,
8252 BasicBlock *ExitBB, bool IsInit) {
8253 StringRef Prefix = IsInit ? ".init" : ".del";
8254
8255 // Evaluate if this is an array section.
8256 BasicBlock *BodyBB = BasicBlock::Create(
8257 M.getContext(), createPlatformSpecificName({"omp.array", Prefix}));
8258 Value *IsArray =
8259 Builder.CreateICmpSGT(Size, Builder.getInt64(1), "omp.arrayinit.isarray");
8260 Value *DeleteBit = Builder.CreateAnd(
8261 MapType,
8262 Builder.getInt64(
8263 static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8264 OpenMPOffloadMappingFlags::OMP_MAP_DELETE)));
8265 Value *DeleteCond;
8266 Value *Cond;
8267 if (IsInit) {
8268 // base != begin?
8269 Value *BaseIsBegin = Builder.CreateICmpNE(Base, Begin);
8270 // IsPtrAndObj?
8271 Value *PtrAndObjBit = Builder.CreateAnd(
8272 MapType,
8273 Builder.getInt64(
8274 static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8275 OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ)));
8276 PtrAndObjBit = Builder.CreateIsNotNull(PtrAndObjBit);
8277 BaseIsBegin = Builder.CreateAnd(BaseIsBegin, PtrAndObjBit);
8278 Cond = Builder.CreateOr(IsArray, BaseIsBegin);
8279 DeleteCond = Builder.CreateIsNull(
8280 DeleteBit,
8281 createPlatformSpecificName({"omp.array", Prefix, ".delete"}));
8282 } else {
8283 Cond = IsArray;
8284 DeleteCond = Builder.CreateIsNotNull(
8285 DeleteBit,
8286 createPlatformSpecificName({"omp.array", Prefix, ".delete"}));
8287 }
8288 Cond = Builder.CreateAnd(Cond, DeleteCond);
8289 Builder.CreateCondBr(Cond, BodyBB, ExitBB);
8290
8291 emitBlock(BodyBB, MapperFn);
8292   // Get the array size in bytes by multiplying the element size by the
8293   // number of elements (i.e., \p Size).
8294 Value *ArraySize = Builder.CreateNUWMul(Size, Builder.getInt64(ElementSize));
8295 // Remove OMP_MAP_TO and OMP_MAP_FROM from the map type, so that it achieves
8296 // memory allocation/deletion purpose only.
8297 Value *MapTypeArg = Builder.CreateAnd(
8298 MapType,
8299 Builder.getInt64(
8300 ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8301 OpenMPOffloadMappingFlags::OMP_MAP_TO |
8302 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
8303 MapTypeArg = Builder.CreateOr(
8304 MapTypeArg,
8305 Builder.getInt64(
8306 static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8307 OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT)));
8308
8309 // Call the runtime API __tgt_push_mapper_component to fill up the runtime
8310 // data structure.
8311 Value *OffloadingArgs[] = {MapperHandle, Base, Begin,
8312 ArraySize, MapTypeArg, MapName};
8313 Builder.CreateCall(
8314 getOrCreateRuntimeFunction(M, OMPRTL___tgt_push_mapper_component),
8315 OffloadingArgs);
8316 }
8317
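// The generated mapper function has roughly the following shape (an
// illustrative sketch using the block names created below):
//   define internal void @<FuncName>(ptr %handle, ptr %base, ptr %begin,
//                                    i64 %size, i64 %type, ptr %name) {
//     ; init section: allocate if this is an array section
//   omp.arraymap.head:
//     br i1 %omp.arraymap.isempty, label %omp.done, label %omp.arraymap.body
//   omp.arraymap.body:                          ; one iteration per element
//     call void @__tgt_push_mapper_component(...)   ; or a nested mapper
//     br i1 %omp.arraymap.isdone, label %omp.arraymap.exit,
//           label %omp.arraymap.body
//   omp.arraymap.exit:
//     ; delete section: deallocate if required
//   omp.done:
//     ret void
//   }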
8318 Expected<Function *> OpenMPIRBuilder::emitUserDefinedMapper(
8319 function_ref<MapInfosOrErrorTy(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
8320 llvm::Value *BeginArg)>
8321 GenMapInfoCB,
8322 Type *ElemTy, StringRef FuncName, CustomMapperCallbackTy CustomMapperCB) {
8323 SmallVector<Type *> Params;
8324 Params.emplace_back(Builder.getPtrTy());
8325 Params.emplace_back(Builder.getPtrTy());
8326 Params.emplace_back(Builder.getPtrTy());
8327 Params.emplace_back(Builder.getInt64Ty());
8328 Params.emplace_back(Builder.getInt64Ty());
8329 Params.emplace_back(Builder.getPtrTy());
8330
8331 auto *FnTy =
8332 FunctionType::get(Builder.getVoidTy(), Params, /* IsVarArg */ false);
8333
8334 SmallString<64> TyStr;
8335 raw_svector_ostream Out(TyStr);
8336 Function *MapperFn =
8337 Function::Create(FnTy, GlobalValue::InternalLinkage, FuncName, M);
8338 MapperFn->addFnAttr(Attribute::NoInline);
8339 MapperFn->addFnAttr(Attribute::NoUnwind);
8340 MapperFn->addParamAttr(0, Attribute::NoUndef);
8341 MapperFn->addParamAttr(1, Attribute::NoUndef);
8342 MapperFn->addParamAttr(2, Attribute::NoUndef);
8343 MapperFn->addParamAttr(3, Attribute::NoUndef);
8344 MapperFn->addParamAttr(4, Attribute::NoUndef);
8345 MapperFn->addParamAttr(5, Attribute::NoUndef);
8346
8347 // Start the mapper function code generation.
8348 BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", MapperFn);
8349 auto SavedIP = Builder.saveIP();
8350 Builder.SetInsertPoint(EntryBB);
8351
8352 Value *MapperHandle = MapperFn->getArg(0);
8353 Value *BaseIn = MapperFn->getArg(1);
8354 Value *BeginIn = MapperFn->getArg(2);
8355 Value *Size = MapperFn->getArg(3);
8356 Value *MapType = MapperFn->getArg(4);
8357 Value *MapName = MapperFn->getArg(5);
8358
8359 // Compute the starting and end addresses of array elements.
8360   // Prepare common arguments for array initialization and deletion.
8361 // Convert the size in bytes into the number of array elements.
8362 TypeSize ElementSize = M.getDataLayout().getTypeStoreSize(ElemTy);
8363 Size = Builder.CreateExactUDiv(Size, Builder.getInt64(ElementSize));
8364 Value *PtrBegin = BeginIn;
8365 Value *PtrEnd = Builder.CreateGEP(ElemTy, PtrBegin, Size);
8366
8367   // Emit array initialization if this is an array section and \p MapType
8368   // indicates that memory allocation is required.
8369 BasicBlock *HeadBB = BasicBlock::Create(M.getContext(), "omp.arraymap.head");
8370 emitUDMapperArrayInitOrDel(MapperFn, MapperHandle, BaseIn, BeginIn, Size,
8371 MapType, MapName, ElementSize, HeadBB,
8372 /*IsInit=*/true);
8373
8374   // Emit a for loop that iterates through \p Size elements and maps each of them.
8375
8376 // Emit the loop header block.
8377 emitBlock(HeadBB, MapperFn);
8378 BasicBlock *BodyBB = BasicBlock::Create(M.getContext(), "omp.arraymap.body");
8379 BasicBlock *DoneBB = BasicBlock::Create(M.getContext(), "omp.done");
8380 // Evaluate whether the initial condition is satisfied.
8381 Value *IsEmpty =
8382 Builder.CreateICmpEQ(PtrBegin, PtrEnd, "omp.arraymap.isempty");
8383 Builder.CreateCondBr(IsEmpty, DoneBB, BodyBB);
8384
8385 // Emit the loop body block.
8386 emitBlock(BodyBB, MapperFn);
8387 BasicBlock *LastBB = BodyBB;
8388 PHINode *PtrPHI =
8389 Builder.CreatePHI(PtrBegin->getType(), 2, "omp.arraymap.ptrcurrent");
8390 PtrPHI->addIncoming(PtrBegin, HeadBB);
8391
8392 // Get map clause information. Fill up the arrays with all mapped variables.
8393 MapInfosOrErrorTy Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn);
8394 if (!Info)
8395 return Info.takeError();
8396
8397 // Call the runtime API __tgt_mapper_num_components to get the number of
8398 // pre-existing components.
8399 Value *OffloadingArgs[] = {MapperHandle};
8400 Value *PreviousSize = Builder.CreateCall(
8401 getOrCreateRuntimeFunction(M, OMPRTL___tgt_mapper_num_components),
8402 OffloadingArgs);
8403 Value *ShiftedPreviousSize =
8404 Builder.CreateShl(PreviousSize, Builder.getInt64(getFlagMemberOffset()));
8405
8406 // Fill up the runtime mapper handle for all components.
8407 for (unsigned I = 0; I < Info->BasePointers.size(); ++I) {
8408 Value *CurBaseArg = Info->BasePointers[I];
8409 Value *CurBeginArg = Info->Pointers[I];
8410 Value *CurSizeArg = Info->Sizes[I];
8411 Value *CurNameArg = Info->Names.size()
8412 ? Info->Names[I]
8413 : Constant::getNullValue(Builder.getPtrTy());
8414
8415 // Extract the MEMBER_OF field from the map type.
8416 Value *OriMapType = Builder.getInt64(
8417 static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8418 Info->Types[I]));
8419 Value *MemberMapType =
8420 Builder.CreateNUWAdd(OriMapType, ShiftedPreviousSize);
8421
8422 // Combine the map type inherited from user-defined mapper with that
8423 // specified in the program. According to the OMP_MAP_TO and OMP_MAP_FROM
8424 // bits of the \a MapType, which is the input argument of the mapper
8425 // function, the following code will set the OMP_MAP_TO and OMP_MAP_FROM
8426 // bits of MemberMapType.
8427 // [OpenMP 5.0], 1.2.6. map-type decay.
8428 // | alloc | to | from | tofrom | release | delete
8429 // ----------------------------------------------------------
8430 // alloc | alloc | alloc | alloc | alloc | release | delete
8431 // to | alloc | to | alloc | to | release | delete
8432 // from | alloc | alloc | from | from | release | delete
8433 // tofrom | alloc | to | from | tofrom | release | delete
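    // Reading the table: a mapper declared `to` that is invoked from a
    // `tofrom` construct decays to `to`, while `alloc` on either side forces
    // `alloc`; the branches below implement exactly this decay.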
8434 Value *LeftToFrom = Builder.CreateAnd(
8435 MapType,
8436 Builder.getInt64(
8437 static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8438 OpenMPOffloadMappingFlags::OMP_MAP_TO |
8439 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
8440 BasicBlock *AllocBB = BasicBlock::Create(M.getContext(), "omp.type.alloc");
8441 BasicBlock *AllocElseBB =
8442 BasicBlock::Create(M.getContext(), "omp.type.alloc.else");
8443 BasicBlock *ToBB = BasicBlock::Create(M.getContext(), "omp.type.to");
8444 BasicBlock *ToElseBB =
8445 BasicBlock::Create(M.getContext(), "omp.type.to.else");
8446 BasicBlock *FromBB = BasicBlock::Create(M.getContext(), "omp.type.from");
8447 BasicBlock *EndBB = BasicBlock::Create(M.getContext(), "omp.type.end");
8448 Value *IsAlloc = Builder.CreateIsNull(LeftToFrom);
8449 Builder.CreateCondBr(IsAlloc, AllocBB, AllocElseBB);
8450 // In case of alloc, clear OMP_MAP_TO and OMP_MAP_FROM.
8451 emitBlock(AllocBB, MapperFn);
8452 Value *AllocMapType = Builder.CreateAnd(
8453 MemberMapType,
8454 Builder.getInt64(
8455 ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8456 OpenMPOffloadMappingFlags::OMP_MAP_TO |
8457 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
8458 Builder.CreateBr(EndBB);
8459 emitBlock(AllocElseBB, MapperFn);
8460 Value *IsTo = Builder.CreateICmpEQ(
8461 LeftToFrom,
8462 Builder.getInt64(
8463 static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8464 OpenMPOffloadMappingFlags::OMP_MAP_TO)));
8465 Builder.CreateCondBr(IsTo, ToBB, ToElseBB);
8466 // In case of to, clear OMP_MAP_FROM.
8467 emitBlock(ToBB, MapperFn);
8468 Value *ToMapType = Builder.CreateAnd(
8469 MemberMapType,
8470 Builder.getInt64(
8471 ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8472 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
8473 Builder.CreateBr(EndBB);
8474 emitBlock(ToElseBB, MapperFn);
8475 Value *IsFrom = Builder.CreateICmpEQ(
8476 LeftToFrom,
8477 Builder.getInt64(
8478 static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8479 OpenMPOffloadMappingFlags::OMP_MAP_FROM)));
8480 Builder.CreateCondBr(IsFrom, FromBB, EndBB);
8481 // In case of from, clear OMP_MAP_TO.
8482 emitBlock(FromBB, MapperFn);
8483 Value *FromMapType = Builder.CreateAnd(
8484 MemberMapType,
8485 Builder.getInt64(
8486 ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8487 OpenMPOffloadMappingFlags::OMP_MAP_TO)));
8488 // In case of tofrom, do nothing.
8489 emitBlock(EndBB, MapperFn);
8490 LastBB = EndBB;
8491 PHINode *CurMapType =
8492 Builder.CreatePHI(Builder.getInt64Ty(), 4, "omp.maptype");
8493 CurMapType->addIncoming(AllocMapType, AllocBB);
8494 CurMapType->addIncoming(ToMapType, ToBB);
8495 CurMapType->addIncoming(FromMapType, FromBB);
8496 CurMapType->addIncoming(MemberMapType, ToElseBB);
8497
8498 Value *OffloadingArgs[] = {MapperHandle, CurBaseArg, CurBeginArg,
8499 CurSizeArg, CurMapType, CurNameArg};
8500
8501 auto ChildMapperFn = CustomMapperCB(I);
8502 if (!ChildMapperFn)
8503 return ChildMapperFn.takeError();
8504 if (*ChildMapperFn) {
8505 // Call the corresponding mapper function.
8506 Builder.CreateCall(*ChildMapperFn, OffloadingArgs)->setDoesNotThrow();
8507 } else {
8508 // Call the runtime API __tgt_push_mapper_component to fill up the runtime
8509 // data structure.
8510 Builder.CreateCall(
8511 getOrCreateRuntimeFunction(M, OMPRTL___tgt_push_mapper_component),
8512 OffloadingArgs);
8513 }
8514 }
8515
8516 // Update the pointer to point to the next element that needs to be mapped,
8517 // and check whether we have mapped all elements.
8518 Value *PtrNext = Builder.CreateConstGEP1_32(ElemTy, PtrPHI, /*Idx0=*/1,
8519 "omp.arraymap.next");
8520 PtrPHI->addIncoming(PtrNext, LastBB);
8521 Value *IsDone = Builder.CreateICmpEQ(PtrNext, PtrEnd, "omp.arraymap.isdone");
8522 BasicBlock *ExitBB = BasicBlock::Create(M.getContext(), "omp.arraymap.exit");
8523 Builder.CreateCondBr(IsDone, ExitBB, BodyBB);
8524
8525 emitBlock(ExitBB, MapperFn);
8526 // Emit array deletion if this is an array section and \p MapType indicates
8527 // that deletion is required.
8528 emitUDMapperArrayInitOrDel(MapperFn, MapperHandle, BaseIn, BeginIn, Size,
8529 MapType, MapName, ElementSize, DoneBB,
8530 /*IsInit=*/false);
8531
8532 // Emit the function exit block.
8533 emitBlock(DoneBB, MapperFn, /*IsFinished=*/true);
8534
8535 Builder.CreateRetVoid();
8536 Builder.restoreIP(SavedIP);
8537 return MapperFn;
8538 }
8539
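// A sketch of what this materializes for two constant-size map operands
// (values illustrative):
//   %.offload_baseptrs = alloca [2 x ptr]
//   %.offload_ptrs     = alloca [2 x ptr]
//   %.offload_mappers  = alloca [2 x ptr]
//   @.offload_sizes    = private unnamed_addr constant [2 x i64] [i64 4, i64 8]
//   @.offload_maptypes = private unnamed_addr constant [2 x i64] [i64 35, i64 35]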
8540 Error OpenMPIRBuilder::emitOffloadingArrays(
8541 InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
8542 TargetDataInfo &Info, CustomMapperCallbackTy CustomMapperCB,
8543 bool IsNonContiguous,
8544 function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
8545
8546 // Reset the array information.
8547 Info.clearArrayInfo();
8548 Info.NumberOfPtrs = CombinedInfo.BasePointers.size();
8549
8550 if (Info.NumberOfPtrs == 0)
8551 return Error::success();
8552
8553 Builder.restoreIP(AllocaIP);
8554   // Detect whether any mapped size requires runtime evaluation, so that a
8555   // constant array can be used where possible.
8556 ArrayType *PointerArrayType =
8557 ArrayType::get(Builder.getPtrTy(), Info.NumberOfPtrs);
8558
8559 Info.RTArgs.BasePointersArray = Builder.CreateAlloca(
8560 PointerArrayType, /* ArraySize = */ nullptr, ".offload_baseptrs");
8561
8562 Info.RTArgs.PointersArray = Builder.CreateAlloca(
8563 PointerArrayType, /* ArraySize = */ nullptr, ".offload_ptrs");
8564 AllocaInst *MappersArray = Builder.CreateAlloca(
8565 PointerArrayType, /* ArraySize = */ nullptr, ".offload_mappers");
8566 Info.RTArgs.MappersArray = MappersArray;
8567
8568 // If we don't have any VLA types or other types that require runtime
8569 // evaluation, we can use a constant array for the map sizes, otherwise we
8570 // need to fill up the arrays as we do for the pointers.
8571 Type *Int64Ty = Builder.getInt64Ty();
8572 SmallVector<Constant *> ConstSizes(CombinedInfo.Sizes.size(),
8573 ConstantInt::get(Int64Ty, 0));
8574 SmallBitVector RuntimeSizes(CombinedInfo.Sizes.size());
8575 for (unsigned I = 0, E = CombinedInfo.Sizes.size(); I < E; ++I) {
8576 if (auto *CI = dyn_cast<Constant>(CombinedInfo.Sizes[I])) {
8577 if (!isa<ConstantExpr>(CI) && !isa<GlobalValue>(CI)) {
8578 if (IsNonContiguous &&
8579 static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8580 CombinedInfo.Types[I] &
8581 OpenMPOffloadMappingFlags::OMP_MAP_NON_CONTIG))
8582 ConstSizes[I] =
8583 ConstantInt::get(Int64Ty, CombinedInfo.NonContigInfo.Dims[I]);
8584 else
8585 ConstSizes[I] = CI;
8586 continue;
8587 }
8588 }
8589 RuntimeSizes.set(I);
8590 }
8591
8592 if (RuntimeSizes.all()) {
8593 ArrayType *SizeArrayType = ArrayType::get(Int64Ty, Info.NumberOfPtrs);
8594 Info.RTArgs.SizesArray = Builder.CreateAlloca(
8595 SizeArrayType, /* ArraySize = */ nullptr, ".offload_sizes");
8596 Builder.restoreIP(CodeGenIP);
8597 } else {
8598 auto *SizesArrayInit = ConstantArray::get(
8599 ArrayType::get(Int64Ty, ConstSizes.size()), ConstSizes);
8600 std::string Name = createPlatformSpecificName({"offload_sizes"});
8601 auto *SizesArrayGbl =
8602 new GlobalVariable(M, SizesArrayInit->getType(), /*isConstant=*/true,
8603 GlobalValue::PrivateLinkage, SizesArrayInit, Name);
8604 SizesArrayGbl->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
8605
8606 if (!RuntimeSizes.any()) {
8607 Info.RTArgs.SizesArray = SizesArrayGbl;
8608 } else {
8609 unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(0);
8610 Align OffloadSizeAlign = M.getDataLayout().getABIIntegerTypeAlignment(64);
8611 ArrayType *SizeArrayType = ArrayType::get(Int64Ty, Info.NumberOfPtrs);
8612 AllocaInst *Buffer = Builder.CreateAlloca(
8613 SizeArrayType, /* ArraySize = */ nullptr, ".offload_sizes");
8614 Buffer->setAlignment(OffloadSizeAlign);
8615 Builder.restoreIP(CodeGenIP);
8616 Builder.CreateMemCpy(
8617 Buffer, M.getDataLayout().getPrefTypeAlign(Buffer->getType()),
8618 SizesArrayGbl, OffloadSizeAlign,
8619 Builder.getIntN(
8620 IndexSize,
8621 Buffer->getAllocationSize(M.getDataLayout())->getFixedValue()));
8622
8623 Info.RTArgs.SizesArray = Buffer;
8624 }
8625 Builder.restoreIP(CodeGenIP);
8626 }
8627
8628 // The map types are always constant so we don't need to generate code to
8629 // fill arrays. Instead, we create an array constant.
8630 SmallVector<uint64_t, 4> Mapping;
8631 for (auto mapFlag : CombinedInfo.Types)
8632 Mapping.push_back(
8633 static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8634 mapFlag));
8635 std::string MaptypesName = createPlatformSpecificName({"offload_maptypes"});
8636 auto *MapTypesArrayGbl = createOffloadMaptypes(Mapping, MaptypesName);
8637 Info.RTArgs.MapTypesArray = MapTypesArrayGbl;
8638
8639 // The information types are only built if provided.
8640 if (!CombinedInfo.Names.empty()) {
8641 auto *MapNamesArrayGbl = createOffloadMapnames(
8642 CombinedInfo.Names, createPlatformSpecificName({"offload_mapnames"}));
8643 Info.RTArgs.MapNamesArray = MapNamesArrayGbl;
8644 Info.EmitDebug = true;
8645 } else {
8646 Info.RTArgs.MapNamesArray =
8647 Constant::getNullValue(PointerType::getUnqual(Builder.getContext()));
8648 Info.EmitDebug = false;
8649 }
8650
8651 // If there's a present map type modifier, it must not be applied to the end
8652 // of a region, so generate a separate map type array in that case.
8653 if (Info.separateBeginEndCalls()) {
8654 bool EndMapTypesDiffer = false;
8655 for (uint64_t &Type : Mapping) {
8656 if (Type & static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8657 OpenMPOffloadMappingFlags::OMP_MAP_PRESENT)) {
8658 Type &= ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
8659 OpenMPOffloadMappingFlags::OMP_MAP_PRESENT);
8660 EndMapTypesDiffer = true;
8661 }
8662 }
8663 if (EndMapTypesDiffer) {
8664 MapTypesArrayGbl = createOffloadMaptypes(Mapping, MaptypesName);
8665 Info.RTArgs.MapTypesArrayEnd = MapTypesArrayGbl;
8666 }
8667 }
8668
8669 PointerType *PtrTy = Builder.getPtrTy();
8670 for (unsigned I = 0; I < Info.NumberOfPtrs; ++I) {
8671 Value *BPVal = CombinedInfo.BasePointers[I];
8672 Value *BP = Builder.CreateConstInBoundsGEP2_32(
8673 ArrayType::get(PtrTy, Info.NumberOfPtrs), Info.RTArgs.BasePointersArray,
8674 0, I);
8675 Builder.CreateAlignedStore(BPVal, BP,
8676 M.getDataLayout().getPrefTypeAlign(PtrTy));
8677
8678 if (Info.requiresDevicePointerInfo()) {
8679 if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Pointer) {
8680 CodeGenIP = Builder.saveIP();
8681 Builder.restoreIP(AllocaIP);
8682 Info.DevicePtrInfoMap[BPVal] = {BP, Builder.CreateAlloca(PtrTy)};
8683 Builder.restoreIP(CodeGenIP);
8684 if (DeviceAddrCB)
8685 DeviceAddrCB(I, Info.DevicePtrInfoMap[BPVal].second);
8686 } else if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Address) {
8687 Info.DevicePtrInfoMap[BPVal] = {BP, BP};
8688 if (DeviceAddrCB)
8689 DeviceAddrCB(I, BP);
8690 }
8691 }
8692
8693 Value *PVal = CombinedInfo.Pointers[I];
8694 Value *P = Builder.CreateConstInBoundsGEP2_32(
8695 ArrayType::get(PtrTy, Info.NumberOfPtrs), Info.RTArgs.PointersArray, 0,
8696 I);
8697     // TODO: Check that the alignment is correct.
8698 Builder.CreateAlignedStore(PVal, P,
8699 M.getDataLayout().getPrefTypeAlign(PtrTy));
8700
8701 if (RuntimeSizes.test(I)) {
8702 Value *S = Builder.CreateConstInBoundsGEP2_32(
8703 ArrayType::get(Int64Ty, Info.NumberOfPtrs), Info.RTArgs.SizesArray,
8704 /*Idx0=*/0,
8705 /*Idx1=*/I);
8706 Builder.CreateAlignedStore(Builder.CreateIntCast(CombinedInfo.Sizes[I],
8707 Int64Ty,
8708 /*isSigned=*/true),
8709 S, M.getDataLayout().getPrefTypeAlign(PtrTy));
8710 }
8711 // Fill up the mapper array.
8712 unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(0);
8713 Value *MFunc = ConstantPointerNull::get(PtrTy);
8714
8715 auto CustomMFunc = CustomMapperCB(I);
8716 if (!CustomMFunc)
8717 return CustomMFunc.takeError();
8718 if (*CustomMFunc)
8719 MFunc = Builder.CreatePointerCast(*CustomMFunc, PtrTy);
8720
8721 Value *MAddr = Builder.CreateInBoundsGEP(
8722 MappersArray->getAllocatedType(), MappersArray,
8723 {Builder.getIntN(IndexSize, 0), Builder.getIntN(IndexSize, I)});
8724 Builder.CreateAlignedStore(
8725 MFunc, MAddr, M.getDataLayout().getPrefTypeAlign(MAddr->getType()));
8726 }
8727
8728 if (!IsNonContiguous || CombinedInfo.NonContigInfo.Offsets.empty() ||
8729 Info.NumberOfPtrs == 0)
8730 return Error::success();
8731 emitNonContiguousDescriptor(AllocaIP, CodeGenIP, CombinedInfo, Info);
8732 return Error::success();
8733 }
8734
8735 void OpenMPIRBuilder::emitBranch(BasicBlock *Target) {
8736 BasicBlock *CurBB = Builder.GetInsertBlock();
8737
8738 if (!CurBB || CurBB->getTerminator()) {
8739 // If there is no insert point or the previous block is already
8740 // terminated, don't touch it.
8741 } else {
8742 // Otherwise, create a fall-through branch.
8743 Builder.CreateBr(Target);
8744 }
8745
8746 Builder.ClearInsertionPoint();
8747 }
8748
8749 void OpenMPIRBuilder::emitBlock(BasicBlock *BB, Function *CurFn,
8750 bool IsFinished) {
8751 BasicBlock *CurBB = Builder.GetInsertBlock();
8752
8753 // Fall out of the current block (if necessary).
8754 emitBranch(BB);
8755
8756 if (IsFinished && BB->use_empty()) {
8757 BB->eraseFromParent();
8758 return;
8759 }
8760
8761 // Place the block after the current block, if possible, or else at
8762 // the end of the function.
8763 if (CurBB && CurBB->getParent())
8764 CurFn->insert(std::next(CurBB->getIterator()), BB);
8765 else
8766 CurFn->insert(CurFn->end(), BB);
8767 Builder.SetInsertPoint(BB);
8768 }
8769
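// For a non-constant condition the generated structure is, roughly:
//   br i1 %cond, label %omp_if.then, label %omp_if.else
// omp_if.then:
//   ...                                  ; ThenGen
//   br label %omp_if.end
// omp_if.else:
//   ...                                  ; ElseGen
//   br label %omp_if.end
// omp_if.end: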
8770 Error OpenMPIRBuilder::emitIfClause(Value *Cond, BodyGenCallbackTy ThenGen,
8771 BodyGenCallbackTy ElseGen,
8772 InsertPointTy AllocaIP) {
8773 // If the condition constant folds and can be elided, try to avoid emitting
8774 // the condition and the dead arm of the if/else.
8775 if (auto *CI = dyn_cast<ConstantInt>(Cond)) {
8776 auto CondConstant = CI->getSExtValue();
8777 if (CondConstant)
8778 return ThenGen(AllocaIP, Builder.saveIP());
8779
8780 return ElseGen(AllocaIP, Builder.saveIP());
8781 }
8782
8783 Function *CurFn = Builder.GetInsertBlock()->getParent();
8784
8785 // Otherwise, the condition did not fold, or we couldn't elide it. Just
8786 // emit the conditional branch.
8787 BasicBlock *ThenBlock = BasicBlock::Create(M.getContext(), "omp_if.then");
8788 BasicBlock *ElseBlock = BasicBlock::Create(M.getContext(), "omp_if.else");
8789 BasicBlock *ContBlock = BasicBlock::Create(M.getContext(), "omp_if.end");
8790 Builder.CreateCondBr(Cond, ThenBlock, ElseBlock);
8791 // Emit the 'then' code.
8792 emitBlock(ThenBlock, CurFn);
8793 if (Error Err = ThenGen(AllocaIP, Builder.saveIP()))
8794 return Err;
8795 emitBranch(ContBlock);
8796 // Emit the 'else' code if present.
8797 // There is no need to emit line number for unconditional branch.
8798 emitBlock(ElseBlock, CurFn);
8799 if (Error Err = ElseGen(AllocaIP, Builder.saveIP()))
8800 return Err;
8801 // There is no need to emit line number for unconditional branch.
8802 emitBranch(ContBlock);
8803 // Emit the continuation block for code after the if.
8804 emitBlock(ContBlock, CurFn, /*IsFinished=*/true);
8805 return Error::success();
8806 }
8807
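// As a summary of the switch below: a seq_cst atomic read is followed by a
// flush (recording an acquire-flavoured ordering for later use), while a
// monotonic access needs no flush at all.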
8808 bool OpenMPIRBuilder::checkAndEmitFlushAfterAtomic(
8809 const LocationDescription &Loc, llvm::AtomicOrdering AO, AtomicKind AK) {
8810 assert(!(AO == AtomicOrdering::NotAtomic ||
8811 AO == llvm::AtomicOrdering::Unordered) &&
8812 "Unexpected Atomic Ordering.");
8813
8814 bool Flush = false;
8815 llvm::AtomicOrdering FlushAO = AtomicOrdering::Monotonic;
8816
8817 switch (AK) {
8818 case Read:
8819 if (AO == AtomicOrdering::Acquire || AO == AtomicOrdering::AcquireRelease ||
8820 AO == AtomicOrdering::SequentiallyConsistent) {
8821 FlushAO = AtomicOrdering::Acquire;
8822 Flush = true;
8823 }
8824 break;
8825 case Write:
8826 case Compare:
8827 case Update:
8828 if (AO == AtomicOrdering::Release || AO == AtomicOrdering::AcquireRelease ||
8829 AO == AtomicOrdering::SequentiallyConsistent) {
8830 FlushAO = AtomicOrdering::Release;
8831 Flush = true;
8832 }
8833 break;
8834 case Capture:
8835 switch (AO) {
8836 case AtomicOrdering::Acquire:
8837 FlushAO = AtomicOrdering::Acquire;
8838 Flush = true;
8839 break;
8840 case AtomicOrdering::Release:
8841 FlushAO = AtomicOrdering::Release;
8842 Flush = true;
8843 break;
8844 case AtomicOrdering::AcquireRelease:
8845 case AtomicOrdering::SequentiallyConsistent:
8846 FlushAO = AtomicOrdering::AcquireRelease;
8847 Flush = true;
8848 break;
8849 default:
8850 // do nothing - leave silently.
8851 break;
8852 }
8853 }
8854
8855 if (Flush) {
8856     // The flush runtime call does not take a memory_ordering argument yet,
8857     // so this resolves which atomic ordering to use for when that support
8858     // arrives, but issues the plain flush call for now.
8859     // TODO: pass `FlushAO` after memory ordering support is added
8860 (void)FlushAO;
8861 emitFlush(Loc);
8862 }
8863
8864   // For AO == AtomicOrdering::Monotonic and all other combinations, do
8865   // nothing.
8866 return Flush;
8867 }
8868
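// For an integer X, the emitted sequence is roughly (a sketch, with AO
// shown as monotonic):
//   %omp.atomic.read = load atomic i32, ptr %x monotonic, align 4
//   store i32 %omp.atomic.read, ptr %v
// Floating-point values are loaded via an integer of the same width and
// bitcast back; structs go through an atomic libcall.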
8869 OpenMPIRBuilder::InsertPointTy
8870 OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
8871 AtomicOpValue &X, AtomicOpValue &V,
8872 AtomicOrdering AO, InsertPointTy AllocaIP) {
8873 if (!updateToLocation(Loc))
8874 return Loc.IP;
8875
8876 assert(X.Var->getType()->isPointerTy() &&
8877 "OMP Atomic expects a pointer to target memory");
8878 Type *XElemTy = X.ElemTy;
8879 assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
8880 XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
8881 "OMP atomic read expected a scalar type");
8882
8883 Value *XRead = nullptr;
8884
8885 if (XElemTy->isIntegerTy()) {
8886 LoadInst *XLD =
8887 Builder.CreateLoad(XElemTy, X.Var, X.IsVolatile, "omp.atomic.read");
8888 XLD->setAtomic(AO);
8889 XRead = cast<Value>(XLD);
8890 } else if (XElemTy->isStructTy()) {
8891 // FIXME: Add checks to ensure __atomic_load is emitted iff the
8892 // target does not support `atomicrmw` of the size of the struct
8893 LoadInst *OldVal = Builder.CreateLoad(XElemTy, X.Var, "omp.atomic.read");
8894 OldVal->setAtomic(AO);
8895 const DataLayout &LoadDL = OldVal->getModule()->getDataLayout();
8896 unsigned LoadSize =
8897 LoadDL.getTypeStoreSize(OldVal->getPointerOperand()->getType());
8898 OpenMPIRBuilder::AtomicInfo atomicInfo(
8899 &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
8900 OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X.Var);
8901 auto AtomicLoadRes = atomicInfo.EmitAtomicLoadLibcall(AO);
8902 XRead = AtomicLoadRes.first;
8903 OldVal->eraseFromParent();
8904 } else {
8905     // We need to perform the atomic operation as an integer.
8906 IntegerType *IntCastTy =
8907 IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
8908 LoadInst *XLoad =
8909 Builder.CreateLoad(IntCastTy, X.Var, X.IsVolatile, "omp.atomic.load");
8910 XLoad->setAtomic(AO);
8911 if (XElemTy->isFloatingPointTy()) {
8912 XRead = Builder.CreateBitCast(XLoad, XElemTy, "atomic.flt.cast");
8913 } else {
8914 XRead = Builder.CreateIntToPtr(XLoad, XElemTy, "atomic.ptr.cast");
8915 }
8916 }
8917 checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Read);
8918 Builder.CreateStore(XRead, V.Var, V.IsVolatile);
8919 return Builder.saveIP();
8920 }
8921
8922 OpenMPIRBuilder::InsertPointTy
8923 OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
8924 AtomicOpValue &X, Value *Expr,
8925 AtomicOrdering AO, InsertPointTy AllocaIP) {
8926 if (!updateToLocation(Loc))
8927 return Loc.IP;
8928
8929 assert(X.Var->getType()->isPointerTy() &&
8930 "OMP Atomic expects a pointer to target memory");
8931 Type *XElemTy = X.ElemTy;
8932 assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
8933 XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
8934 "OMP atomic write expected a scalar type");
8935
8936 if (XElemTy->isIntegerTy()) {
8937 StoreInst *XSt = Builder.CreateStore(Expr, X.Var, X.IsVolatile);
8938 XSt->setAtomic(AO);
8939 } else if (XElemTy->isStructTy()) {
8940 LoadInst *OldVal = Builder.CreateLoad(XElemTy, X.Var, "omp.atomic.read");
8941 const DataLayout &LoadDL = OldVal->getModule()->getDataLayout();
8942 unsigned LoadSize =
8943 LoadDL.getTypeStoreSize(OldVal->getPointerOperand()->getType());
8944 OpenMPIRBuilder::AtomicInfo atomicInfo(
8945 &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
8946 OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X.Var);
8947 atomicInfo.EmitAtomicStoreLibcall(AO, Expr);
8948 OldVal->eraseFromParent();
8949 } else {
8950     // We need to bitcast and perform the atomic operation as an integer.
8951 IntegerType *IntCastTy =
8952 IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
8953 Value *ExprCast =
8954 Builder.CreateBitCast(Expr, IntCastTy, "atomic.src.int.cast");
8955 StoreInst *XSt = Builder.CreateStore(ExprCast, X.Var, X.IsVolatile);
8956 XSt->setAtomic(AO);
8957 }
8958
8959 checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Write);
8960 return Builder.saveIP();
8961 }
8962
8963 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicUpdate(
8964 const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
8965 Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
8966 AtomicUpdateCallbackTy &UpdateOp, bool IsXBinopExpr) {
8967 assert(!isConflictIP(Loc.IP, AllocaIP) && "IPs must not be ambiguous");
8968 if (!updateToLocation(Loc))
8969 return Loc.IP;
8970
8971 LLVM_DEBUG({
8972 Type *XTy = X.Var->getType();
8973 assert(XTy->isPointerTy() &&
8974 "OMP Atomic expects a pointer to target memory");
8975 Type *XElemTy = X.ElemTy;
8976 assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
8977 XElemTy->isPointerTy()) &&
8978 "OMP atomic update expected a scalar type");
8979 assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
8980 (RMWOp != AtomicRMWInst::UMax) && (RMWOp != AtomicRMWInst::UMin) &&
8981 "OpenMP atomic does not support LT or GT operations");
8982 });
8983
8984 Expected<std::pair<Value *, Value *>> AtomicResult =
8985 emitAtomicUpdate(AllocaIP, X.Var, X.ElemTy, Expr, AO, RMWOp, UpdateOp,
8986 X.IsVolatile, IsXBinopExpr);
8987 if (!AtomicResult)
8988 return AtomicResult.takeError();
8989 checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Update);
8990 return Builder.saveIP();
8991 }
8992
8993 // FIXME: Duplicating AtomicExpand
8994 Value *OpenMPIRBuilder::emitRMWOpAsInstruction(Value *Src1, Value *Src2,
8995 AtomicRMWInst::BinOp RMWOp) {
8996 switch (RMWOp) {
8997 case AtomicRMWInst::Add:
8998 return Builder.CreateAdd(Src1, Src2);
8999 case AtomicRMWInst::Sub:
9000 return Builder.CreateSub(Src1, Src2);
9001 case AtomicRMWInst::And:
9002 return Builder.CreateAnd(Src1, Src2);
9003 case AtomicRMWInst::Nand:
9004 return Builder.CreateNeg(Builder.CreateAnd(Src1, Src2));
9005 case AtomicRMWInst::Or:
9006 return Builder.CreateOr(Src1, Src2);
9007 case AtomicRMWInst::Xor:
9008 return Builder.CreateXor(Src1, Src2);
9009 case AtomicRMWInst::Xchg:
9010 case AtomicRMWInst::FAdd:
9011 case AtomicRMWInst::FSub:
9012 case AtomicRMWInst::BAD_BINOP:
9013 case AtomicRMWInst::Max:
9014 case AtomicRMWInst::Min:
9015 case AtomicRMWInst::UMax:
9016 case AtomicRMWInst::UMin:
9017 case AtomicRMWInst::FMax:
9018 case AtomicRMWInst::FMin:
9019 case AtomicRMWInst::FMaximum:
9020 case AtomicRMWInst::FMinimum:
9021 case AtomicRMWInst::UIncWrap:
9022 case AtomicRMWInst::UDecWrap:
9023 case AtomicRMWInst::USubCond:
9024 case AtomicRMWInst::USubSat:
9025 llvm_unreachable("Unsupported atomic update operation");
9026 }
9027 llvm_unreachable("Unsupported atomic update operation");
9028 }
9029
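// Two strategies are used below (a sketch): when the operation maps directly
// to an atomicrmw instruction, e.g. an integer add,
//   %old = atomicrmw add ptr %x, i32 %expr monotonic
// is emitted; otherwise a cmpxchg retry loop loads X, applies UpdateOp, and
// attempts to publish the result until the exchange succeeds.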
9030 Expected<std::pair<Value *, Value *>> OpenMPIRBuilder::emitAtomicUpdate(
9031 InsertPointTy AllocaIP, Value *X, Type *XElemTy, Value *Expr,
9032 AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
9033 AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, bool IsXBinopExpr) {
9034 // TODO: handle the case where XElemTy is not byte-sized or not a power of 2
9035 // or a complex datatype.
9036 bool emitRMWOp = false;
9037 switch (RMWOp) {
9038 case AtomicRMWInst::Add:
9039 case AtomicRMWInst::And:
9040 case AtomicRMWInst::Nand:
9041 case AtomicRMWInst::Or:
9042 case AtomicRMWInst::Xor:
9043 case AtomicRMWInst::Xchg:
9044 emitRMWOp = XElemTy;
9045 break;
9046 case AtomicRMWInst::Sub:
9047 emitRMWOp = (IsXBinopExpr && XElemTy);
9048 break;
9049 default:
9050 emitRMWOp = false;
9051 }
9052 emitRMWOp &= XElemTy->isIntegerTy();
9053
9054 std::pair<Value *, Value *> Res;
9055 if (emitRMWOp) {
9056 Res.first = Builder.CreateAtomicRMW(RMWOp, X, Expr, llvm::MaybeAlign(), AO);
9057     // Res.second is only needed for postfix captures. Generate it anyway for
9058     // consistency with the else branch; any DCE pass will remove it.
9059     // AtomicRMWInst::Xchg does not have a corresponding instruction.
9060 if (RMWOp == AtomicRMWInst::Xchg)
9061 Res.second = Res.first;
9062 else
9063 Res.second = emitRMWOpAsInstruction(Res.first, Expr, RMWOp);
9064 } else if (RMWOp == llvm::AtomicRMWInst::BinOp::BAD_BINOP &&
9065 XElemTy->isStructTy()) {
9066 LoadInst *OldVal =
9067 Builder.CreateLoad(XElemTy, X, X->getName() + ".atomic.load");
9068 OldVal->setAtomic(AO);
9069 const DataLayout &LoadDL = OldVal->getModule()->getDataLayout();
9070 unsigned LoadSize =
9071 LoadDL.getTypeStoreSize(OldVal->getPointerOperand()->getType());
9072
9073 OpenMPIRBuilder::AtomicInfo atomicInfo(
9074 &Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
9075 OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X);
9076 auto AtomicLoadRes = atomicInfo.EmitAtomicLoadLibcall(AO);
9077 BasicBlock *CurBB = Builder.GetInsertBlock();
9078 Instruction *CurBBTI = CurBB->getTerminator();
9079 CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
9080 BasicBlock *ExitBB =
9081 CurBB->splitBasicBlock(CurBBTI, X->getName() + ".atomic.exit");
9082 BasicBlock *ContBB = CurBB->splitBasicBlock(CurBB->getTerminator(),
9083 X->getName() + ".atomic.cont");
9084 ContBB->getTerminator()->eraseFromParent();
9085 Builder.restoreIP(AllocaIP);
9086 AllocaInst *NewAtomicAddr = Builder.CreateAlloca(XElemTy);
9087 NewAtomicAddr->setName(X->getName() + "x.new.val");
9088 Builder.SetInsertPoint(ContBB);
9089 llvm::PHINode *PHI = Builder.CreatePHI(OldVal->getType(), 2);
9090 PHI->addIncoming(AtomicLoadRes.first, CurBB);
9091 Value *OldExprVal = PHI;
9092 Expected<Value *> CBResult = UpdateOp(OldExprVal, Builder);
9093 if (!CBResult)
9094 return CBResult.takeError();
9095 Value *Upd = *CBResult;
9096 Builder.CreateStore(Upd, NewAtomicAddr);
9097 AtomicOrdering Failure =
9098 llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
9099 auto Result = atomicInfo.EmitAtomicCompareExchangeLibcall(
9100 AtomicLoadRes.second, NewAtomicAddr, AO, Failure);
9101 LoadInst *PHILoad = Builder.CreateLoad(XElemTy, Result.first);
9102 PHI->addIncoming(PHILoad, Builder.GetInsertBlock());
9103 Builder.CreateCondBr(Result.second, ExitBB, ContBB);
9104 OldVal->eraseFromParent();
9105 Res.first = OldExprVal;
9106 Res.second = Upd;
9107
9108 if (UnreachableInst *ExitTI =
9109 dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
9110 CurBBTI->eraseFromParent();
9111 Builder.SetInsertPoint(ExitBB);
9112 } else {
9113 Builder.SetInsertPoint(ExitBB->getTerminator());
9114 }
9115 } else {
9116 IntegerType *IntCastTy =
9117 IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
9118 LoadInst *OldVal =
9119 Builder.CreateLoad(IntCastTy, X, X->getName() + ".atomic.load");
9120 OldVal->setAtomic(AO);
9121 // CurBB
9122 // | /---\
9123 // ContBB |
9124 // | \---/
9125 // ExitBB
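// Each iteration of the loop reads the current value through the PHI,
// applies UpdateOp, and attempts to publish the result with a cmpxchg; on
// failure the observed value feeds the PHI and the loop retries.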
9126 BasicBlock *CurBB = Builder.GetInsertBlock();
9127 Instruction *CurBBTI = CurBB->getTerminator();
9128 CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
9129 BasicBlock *ExitBB =
9130 CurBB->splitBasicBlock(CurBBTI, X->getName() + ".atomic.exit");
9131 BasicBlock *ContBB = CurBB->splitBasicBlock(CurBB->getTerminator(),
9132 X->getName() + ".atomic.cont");
9133 ContBB->getTerminator()->eraseFromParent();
9134 Builder.restoreIP(AllocaIP);
9135 AllocaInst *NewAtomicAddr = Builder.CreateAlloca(XElemTy);
9136 NewAtomicAddr->setName(X->getName() + "x.new.val");
9137 Builder.SetInsertPoint(ContBB);
9138 llvm::PHINode *PHI = Builder.CreatePHI(OldVal->getType(), 2);
9139 PHI->addIncoming(OldVal, CurBB);
9140 bool IsIntTy = XElemTy->isIntegerTy();
9141 Value *OldExprVal = PHI;
9142 if (!IsIntTy) {
9143 if (XElemTy->isFloatingPointTy()) {
9144 OldExprVal = Builder.CreateBitCast(PHI, XElemTy,
9145 X->getName() + ".atomic.fltCast");
9146 } else {
9147 OldExprVal = Builder.CreateIntToPtr(PHI, XElemTy,
9148 X->getName() + ".atomic.ptrCast");
9149 }
9150 }
9151
9152 Expected<Value *> CBResult = UpdateOp(OldExprVal, Builder);
9153 if (!CBResult)
9154 return CBResult.takeError();
9155 Value *Upd = *CBResult;
9156 Builder.CreateStore(Upd, NewAtomicAddr);
9157 LoadInst *DesiredVal = Builder.CreateLoad(IntCastTy, NewAtomicAddr);
9158 AtomicOrdering Failure =
9159 llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
9160 AtomicCmpXchgInst *Result = Builder.CreateAtomicCmpXchg(
9161 X, PHI, DesiredVal, llvm::MaybeAlign(), AO, Failure);
9162 Result->setVolatile(VolatileX);
9163 Value *PreviousVal = Builder.CreateExtractValue(Result, /*Idxs=*/0);
9164 Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
9165 PHI->addIncoming(PreviousVal, Builder.GetInsertBlock());
9166 Builder.CreateCondBr(SuccessFailureVal, ExitBB, ContBB);
9167
9168 Res.first = OldExprVal;
9169 Res.second = Upd;
9170
9171 // Set the insertion point in the exit block.
9172 if (UnreachableInst *ExitTI =
9173 dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
9174 CurBBTI->eraseFromParent();
9175 Builder.SetInsertPoint(ExitBB);
9176 } else {
9177 Builder.SetInsertPoint(ExitBB->getTerminator());
9178 }
9179 }
9180
9181 return Res;
9182 }
9183
9184 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicCapture(
9185 const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
9186 AtomicOpValue &V, Value *Expr, AtomicOrdering AO,
9187 AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp,
9188 bool UpdateExpr, bool IsPostfixUpdate, bool IsXBinopExpr) {
9189 if (!updateToLocation(Loc))
9190 return Loc.IP;
9191
9192 LLVM_DEBUG({
9193 Type *XTy = X.Var->getType();
9194 assert(XTy->isPointerTy() &&
9195 "OMP Atomic expects a pointer to target memory");
9196 Type *XElemTy = X.ElemTy;
9197 assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
9198 XElemTy->isPointerTy()) &&
9199 "OMP atomic capture expected a scalar type");
9200 assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
9201 "OpenMP atomic does not support LT or GT operations");
9202 });
9203
9204 // If the update is of the form 'x = expr' with `expr` not based on 'x',
9205 // 'x' is simply atomically overwritten with 'expr'.
9206 AtomicRMWInst::BinOp AtomicOp = (UpdateExpr ? RMWOp : AtomicRMWInst::Xchg);
9207 Expected<std::pair<Value *, Value *>> AtomicResult =
9208 emitAtomicUpdate(AllocaIP, X.Var, X.ElemTy, Expr, AO, AtomicOp, UpdateOp,
9209 X.IsVolatile, IsXBinopExpr);
9210 if (!AtomicResult)
9211 return AtomicResult.takeError();
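// Postfix capture ('v = x; x = update') stores the old value into 'v';
// prefix capture ('x = update; v = x') stores the updated value.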
9212 Value *CapturedVal =
9213 (IsPostfixUpdate ? AtomicResult->first : AtomicResult->second);
9214 Builder.CreateStore(CapturedVal, V.Var, V.IsVolatile);
9215
9216 checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Capture);
9217 return Builder.saveIP();
9218 }
9219
9220 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
9221 const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
9222 AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
9223 omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
9224 bool IsFailOnly) {
9225
9226 AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
9227 return createAtomicCompare(Loc, X, V, R, E, D, AO, Op, IsXBinopExpr,
9228 IsPostfixUpdate, IsFailOnly, Failure);
9229 }
9230
9231 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
9232 const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
9233 AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
9234 omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
9235 bool IsFailOnly, AtomicOrdering Failure) {
9236
9237 if (!updateToLocation(Loc))
9238 return Loc.IP;
9239
9240 assert(X.Var->getType()->isPointerTy() &&
9241 "OMP atomic expects a pointer to target memory");
9242 // compare capture
9243 if (V.Var) {
9244 assert(V.Var->getType()->isPointerTy() && "v.var must be of pointer type");
9245 assert(V.ElemTy == X.ElemTy && "x and v must be of same type");
9246 }
9247
9248 bool IsInteger = E->getType()->isIntegerTy();
9249
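// An equality comparison lowers to cmpxchg; the min/max forms lower to
// atomicrmw further below.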
9250 if (Op == OMPAtomicCompareOp::EQ) {
9251 AtomicCmpXchgInst *Result = nullptr;
9252 if (!IsInteger) {
9253 IntegerType *IntCastTy =
9254 IntegerType::get(M.getContext(), X.ElemTy->getScalarSizeInBits());
9255 Value *EBCast = Builder.CreateBitCast(E, IntCastTy);
9256 Value *DBCast = Builder.CreateBitCast(D, IntCastTy);
9257 Result = Builder.CreateAtomicCmpXchg(X.Var, EBCast, DBCast, MaybeAlign(),
9258 AO, Failure);
9259 } else {
9260 Result =
9261 Builder.CreateAtomicCmpXchg(X.Var, E, D, MaybeAlign(), AO, Failure);
9262 }
9263
9264 if (V.Var) {
9265 Value *OldValue = Builder.CreateExtractValue(Result, /*Idxs=*/0);
9266 if (!IsInteger)
9267 OldValue = Builder.CreateBitCast(OldValue, X.ElemTy);
9268 assert(OldValue->getType() == V.ElemTy &&
9269 "OldValue and V must be of same type");
9270 if (IsPostfixUpdate) {
9271 Builder.CreateStore(OldValue, V.Var, V.IsVolatile);
9272 } else {
9273 Value *SuccessOrFail = Builder.CreateExtractValue(Result, /*Idxs=*/1);
9274 if (IsFailOnly) {
9275 // CurBB----
9276 // | |
9277 // v |
9278 // ContBB |
9279 // | |
9280 // v |
9281 // ExitBB <-
9282 //
9283 // where ContBB only contains the store of old value to 'v'.
9284 BasicBlock *CurBB = Builder.GetInsertBlock();
9285 Instruction *CurBBTI = CurBB->getTerminator();
9286 CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
9287 BasicBlock *ExitBB = CurBB->splitBasicBlock(
9288 CurBBTI, X.Var->getName() + ".atomic.exit");
9289 BasicBlock *ContBB = CurBB->splitBasicBlock(
9290 CurBB->getTerminator(), X.Var->getName() + ".atomic.cont");
9291 ContBB->getTerminator()->eraseFromParent();
9292 CurBB->getTerminator()->eraseFromParent();
9293
9294 Builder.CreateCondBr(SuccessOrFail, ExitBB, ContBB);
9295
9296 Builder.SetInsertPoint(ContBB);
9297 Builder.CreateStore(OldValue, V.Var);
9298 Builder.CreateBr(ExitBB);
9299
9300 if (UnreachableInst *ExitTI =
9301 dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
9302 CurBBTI->eraseFromParent();
9303 Builder.SetInsertPoint(ExitBB);
9304 } else {
9305 Builder.SetInsertPoint(ExitBB->getTerminator());
9306 }
9307 } else {
9308 Value *CapturedValue =
9309 Builder.CreateSelect(SuccessOrFail, E, OldValue);
9310 Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
9311 }
9312 }
9313 }
9314 // The comparison result has to be stored.
9315 if (R.Var) {
9316 assert(R.Var->getType()->isPointerTy() &&
9317 "r.var must be of pointer type");
9318 assert(R.ElemTy->isIntegerTy() && "r must be of integral type");
9319
9320 Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
9321 Value *ResultCast = R.IsSigned
9322 ? Builder.CreateSExt(SuccessFailureVal, R.ElemTy)
9323 : Builder.CreateZExt(SuccessFailureVal, R.ElemTy);
9324 Builder.CreateStore(ResultCast, R.Var, R.IsVolatile);
9325 }
9326 } else {
9327 assert((Op == OMPAtomicCompareOp::MAX || Op == OMPAtomicCompareOp::MIN) &&
9328 "Op should be either max or min at this point");
9329 assert(!IsFailOnly && "IsFailOnly is only valid when the comparison is ==");
9330
9331 // Reverse the comparison op, as the OpenMP forms differ from the LLVM
9332 // forms. Take max as an example. The OpenMP form is:
9333 //   x = x > expr ? expr : x;
9334 // while LLVM's atomicrmw max computes:
9335 //   *ptr = *ptr > val ? *ptr : val;
9336 // so we first rewrite the OpenMP form into the equivalent
9337 //   x = x <= expr ? x : expr;
9338 // which matches LLVM's min operation.
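// Net effect with IsXBinopExpr: MAX becomes Min/UMin/FMin and MIN becomes
// Max/UMax/FMax; without IsXBinopExpr the ops keep their direction.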
9339 AtomicRMWInst::BinOp NewOp;
9340 if (IsXBinopExpr) {
9341 if (IsInteger) {
9342 if (X.IsSigned)
9343 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
9344 : AtomicRMWInst::Max;
9345 else
9346 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
9347 : AtomicRMWInst::UMax;
9348 } else {
9349 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMin
9350 : AtomicRMWInst::FMax;
9351 }
9352 } else {
9353 if (IsInteger) {
9354 if (X.IsSigned)
9355 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
9356 : AtomicRMWInst::Min;
9357 else
9358 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
9359 : AtomicRMWInst::UMin;
9360 } else {
9361 NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMax
9362 : AtomicRMWInst::FMin;
9363 }
9364 }
9365
9366 AtomicRMWInst *OldValue =
9367 Builder.CreateAtomicRMW(NewOp, X.Var, E, MaybeAlign(), AO);
9368 if (V.Var) {
9369 Value *CapturedValue = nullptr;
9370 if (IsPostfixUpdate) {
9371 CapturedValue = OldValue;
9372 } else {
9373 CmpInst::Predicate Pred;
9374 switch (NewOp) {
9375 case AtomicRMWInst::Max:
9376 Pred = CmpInst::ICMP_SGT;
9377 break;
9378 case AtomicRMWInst::UMax:
9379 Pred = CmpInst::ICMP_UGT;
9380 break;
9381 case AtomicRMWInst::FMax:
9382 Pred = CmpInst::FCMP_OGT;
9383 break;
9384 case AtomicRMWInst::Min:
9385 Pred = CmpInst::ICMP_SLT;
9386 break;
9387 case AtomicRMWInst::UMin:
9388 Pred = CmpInst::ICMP_ULT;
9389 break;
9390 case AtomicRMWInst::FMin:
9391 Pred = CmpInst::FCMP_OLT;
9392 break;
9393 default:
9394 llvm_unreachable("unexpected comparison op");
9395 }
9396 Value *NonAtomicCmp = Builder.CreateCmp(Pred, OldValue, E);
9397 CapturedValue = Builder.CreateSelect(NonAtomicCmp, E, OldValue);
9398 }
9399 Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
9400 }
9401 }
9402
9403 checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Compare);
9404
9405 return Builder.saveIP();
9406 }
9407
9408 OpenMPIRBuilder::InsertPointOrErrorTy
9409 OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
9410 BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower,
9411 Value *NumTeamsUpper, Value *ThreadLimit,
9412 Value *IfExpr) {
9413 if (!updateToLocation(Loc))
9414 return InsertPointTy();
9415
9416 uint32_t SrcLocStrSize;
9417 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
9418 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
9419 Function *CurrentFunction = Builder.GetInsertBlock()->getParent();
9420
9421 // Outer allocation basicblock is the entry block of the current function.
9422 BasicBlock &OuterAllocaBB = CurrentFunction->getEntryBlock();
9423 if (&OuterAllocaBB == Builder.GetInsertBlock()) {
9424 BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.entry");
9425 Builder.SetInsertPoint(BodyBB, BodyBB->begin());
9426 }
9427
9428 // The current basic block is split into four basic blocks. After outlining,
9429 // they will be mapped as follows:
9430 // ```
9431 // def current_fn() {
9432 // current_basic_block:
9433 // br label %teams.exit
9434 // teams.exit:
9435 // ; instructions after teams
9436 // }
9437 //
9438 // def outlined_fn() {
9439 // teams.alloca:
9440 // br label %teams.body
9441 // teams.body:
9442 // ; instructions within teams body
9443 // }
9444 // ```
9445 BasicBlock *ExitBB = splitBB(Builder, /*CreateBranch=*/true, "teams.exit");
9446 BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.body");
9447 BasicBlock *AllocaBB =
9448 splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");
9449
9450 bool SubClausesPresent =
9451 (NumTeamsLower || NumTeamsUpper || ThreadLimit || IfExpr);
9452 // Push num_teams
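// On the host, the clause values are forwarded to the runtime via
//   __kmpc_push_num_teams_51(ident, tid, num_teams_lo, num_teams_hi, thread_limit)
// below, with 0 passed for any clause that was not specified.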
9453 if (!Config.isTargetDevice() && SubClausesPresent) {
9454 assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) &&
9455 "if lowerbound is non-null, then upperbound must also be non-null "
9456 "for bounds on num_teams");
9457
9458 if (NumTeamsUpper == nullptr)
9459 NumTeamsUpper = Builder.getInt32(0);
9460
9461 if (NumTeamsLower == nullptr)
9462 NumTeamsLower = NumTeamsUpper;
9463
9464 if (IfExpr) {
9465 assert(IfExpr->getType()->isIntegerTy() &&
9466 "argument to if clause must be an integer value");
9467
9468 // upper = ifexpr ? upper : 1
9469 if (IfExpr->getType() != Int1)
9470 IfExpr = Builder.CreateICmpNE(IfExpr,
9471 ConstantInt::get(IfExpr->getType(), 0));
9472 NumTeamsUpper = Builder.CreateSelect(
9473 IfExpr, NumTeamsUpper, Builder.getInt32(1), "numTeamsUpper");
9474
9475 // lower = ifexpr ? lower : 1
9476 NumTeamsLower = Builder.CreateSelect(
9477 IfExpr, NumTeamsLower, Builder.getInt32(1), "numTeamsLower");
9478 }
9479
9480 if (ThreadLimit == nullptr)
9481 ThreadLimit = Builder.getInt32(0);
9482
9483 Value *ThreadNum = getOrCreateThreadID(Ident);
9484 Builder.CreateCall(
9485 getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams_51),
9486 {Ident, ThreadNum, NumTeamsLower, NumTeamsUpper, ThreadLimit});
9487 }
9488 // Generate the body of teams.
9489 InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
9490 InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
9491 if (Error Err = BodyGenCB(AllocaIP, CodeGenIP))
9492 return Err;
9493
9494 OutlineInfo OI;
9495 OI.EntryBB = AllocaBB;
9496 OI.ExitBB = ExitBB;
9497 OI.OuterAllocaBB = &OuterAllocaBB;
9498
9499 // Insert fake values for global tid and bound tid.
9500 SmallVector<Instruction *, 8> ToBeDeleted;
9501 InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
9502 OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
9503 Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "gid", true));
9504 OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
9505 Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "tid", true));
9506
9507 auto HostPostOutlineCB = [this, Ident,
9508 ToBeDeleted](Function &OutlinedFn) mutable {
9509 // The stale call instruction will be replaced with a new call instruction
9510 // for runtime call with the outlined function.
9511
9512 assert(OutlinedFn.hasOneUse() &&
9513 "there must be a single user for the outlined function");
9514 CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
9515 ToBeDeleted.push_back(StaleCI);
9516
9517 assert((OutlinedFn.arg_size() == 2 || OutlinedFn.arg_size() == 3) &&
9518 "Outlined function must have two or three arguments only");
9519
9520 bool HasShared = OutlinedFn.arg_size() == 3;
9521
9522 OutlinedFn.getArg(0)->setName("global.tid.ptr");
9523 OutlinedFn.getArg(1)->setName("bound.tid.ptr");
9524 if (HasShared)
9525 OutlinedFn.getArg(2)->setName("data");
9526
9527 // Call to the runtime function for teams in the current function.
9528 assert(StaleCI && "Error while outlining - no CallInst user found for the "
9529 "outlined function.");
9530 Builder.SetInsertPoint(StaleCI);
9531 SmallVector<Value *> Args = {
9532 Ident, Builder.getInt32(StaleCI->arg_size() - 2), &OutlinedFn};
9533 if (HasShared)
9534 Args.push_back(StaleCI->getArgOperand(2));
9535 Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
9536 omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
9537 Args);
9538
9539 for (Instruction *I : llvm::reverse(ToBeDeleted))
9540 I->eraseFromParent();
9541 };
9542
9543 if (!Config.isTargetDevice())
9544 OI.PostOutlineCB = HostPostOutlineCB;
9545
9546 addOutlineInfo(std::move(OI));
9547
9548 Builder.SetInsertPoint(ExitBB, ExitBB->begin());
9549
9550 return Builder.saveIP();
9551 }
9552
9553 OpenMPIRBuilder::InsertPointOrErrorTy
9554 OpenMPIRBuilder::createDistribute(const LocationDescription &Loc,
9555 InsertPointTy OuterAllocaIP,
9556 BodyGenCallbackTy BodyGenCB) {
9557 if (!updateToLocation(Loc))
9558 return InsertPointTy();
9559
9560 BasicBlock *OuterAllocaBB = OuterAllocaIP.getBlock();
9561
9562 if (OuterAllocaBB == Builder.GetInsertBlock()) {
9563 BasicBlock *BodyBB =
9564 splitBB(Builder, /*CreateBranch=*/true, "distribute.entry");
9565 Builder.SetInsertPoint(BodyBB, BodyBB->begin());
9566 }
9567 BasicBlock *ExitBB =
9568 splitBB(Builder, /*CreateBranch=*/true, "distribute.exit");
9569 BasicBlock *BodyBB =
9570 splitBB(Builder, /*CreateBranch=*/true, "distribute.body");
9571 BasicBlock *AllocaBB =
9572 splitBB(Builder, /*CreateBranch=*/true, "distribute.alloca");
9573
9574 // Generate the body of distribute clause
9575 InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
9576 InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
9577 if (Error Err = BodyGenCB(AllocaIP, CodeGenIP))
9578 return Err;
9579
9580 OutlineInfo OI;
9581 OI.OuterAllocaBB = OuterAllocaIP.getBlock();
9582 OI.EntryBB = AllocaBB;
9583 OI.ExitBB = ExitBB;
9584
9585 addOutlineInfo(std::move(OI));
9586 Builder.SetInsertPoint(ExitBB, ExitBB->begin());
9587
9588 return Builder.saveIP();
9589 }
9590
9591 GlobalVariable *
9592 OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
9593 std::string VarName) {
9594 llvm::Constant *MapNamesArrayInit = llvm::ConstantArray::get(
9595 llvm::ArrayType::get(llvm::PointerType::getUnqual(M.getContext()),
9596 Names.size()),
9597 Names);
9598 auto *MapNamesArrayGlobal = new llvm::GlobalVariable(
9599 M, MapNamesArrayInit->getType(),
9600 /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MapNamesArrayInit,
9601 VarName);
9602 return MapNamesArrayGlobal;
9603 }
9604
9605 // Create all simple and struct types exposed by the runtime and remember
9606 // the llvm::PointerTypes of them for easy access later.
9607 void OpenMPIRBuilder::initializeTypes(Module &M) {
9608 LLVMContext &Ctx = M.getContext();
9609 StructType *T;
9610 #define OMP_TYPE(VarName, InitValue) VarName = InitValue;
9611 #define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize) \
9612 VarName##Ty = ArrayType::get(ElemTy, ArraySize); \
9613 VarName##PtrTy = PointerType::getUnqual(Ctx);
9614 #define OMP_FUNCTION_TYPE(VarName, IsVarArg, ReturnType, ...) \
9615 VarName = FunctionType::get(ReturnType, {__VA_ARGS__}, IsVarArg); \
9616 VarName##Ptr = PointerType::getUnqual(Ctx);
9617 #define OMP_STRUCT_TYPE(VarName, StructName, Packed, ...) \
9618 T = StructType::getTypeByName(Ctx, StructName); \
9619 if (!T) \
9620 T = StructType::create(Ctx, {__VA_ARGS__}, StructName, Packed); \
9621 VarName = T; \
9622 VarName##Ptr = PointerType::getUnqual(Ctx);
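// For illustration, a hypothetical OMPKinds.def entry such as
//   OMP_STRUCT_TYPE(Ident, "struct.ident_t", false, Int32, Int32, ...)
// reuses the module's existing %struct.ident_t if present, creates it
// otherwise, and also defines the matching unqualified pointer type.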
9623 #include "llvm/Frontend/OpenMP/OMPKinds.def"
9624 }
9625
9626 void OpenMPIRBuilder::OutlineInfo::collectBlocks(
9627 SmallPtrSetImpl<BasicBlock *> &BlockSet,
9628 SmallVectorImpl<BasicBlock *> &BlockVector) {
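// Forward worklist traversal from EntryBB; pre-inserting ExitBB into the
// set keeps the walk from exploring beyond the region's exit.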
9629 SmallVector<BasicBlock *, 32> Worklist;
9630 BlockSet.insert(EntryBB);
9631 BlockSet.insert(ExitBB);
9632
9633 Worklist.push_back(EntryBB);
9634 while (!Worklist.empty()) {
9635 BasicBlock *BB = Worklist.pop_back_val();
9636 BlockVector.push_back(BB);
9637 for (BasicBlock *SuccBB : successors(BB))
9638 if (BlockSet.insert(SuccBB).second)
9639 Worklist.push_back(SuccBB);
9640 }
9641 }
9642
9643 void OpenMPIRBuilder::createOffloadEntry(Constant *ID, Constant *Addr,
9644 uint64_t Size, int32_t Flags,
9645 GlobalValue::LinkageTypes,
9646 StringRef Name) {
9647 if (!Config.isGPU()) {
9648 llvm::offloading::emitOffloadingEntry(
9649 M, object::OffloadKind::OFK_OpenMP, ID,
9650 Name.empty() ? Addr->getName() : Name, Size, Flags, /*Data=*/0);
9651 return;
9652 }
9653 // TODO: Add support for global variables on the device after declare target
9654 // support.
9655 Function *Fn = dyn_cast<Function>(Addr);
9656 if (!Fn)
9657 return;
9658
9659 // Add a function attribute for the kernel.
9660 Fn->addFnAttr("kernel");
9661 if (T.isAMDGCN())
9662 Fn->addFnAttr("uniform-work-group-size", "true");
9663 Fn->addFnAttr(Attribute::MustProgress);
9664 }
9665
9666 // We only generate metadata for functions that contain target regions.
9667 void OpenMPIRBuilder::createOffloadEntriesAndInfoMetadata(
9668 EmitMetadataErrorReportFunctionTy &ErrorFn) {
9669
9670 // If there are no entries, we don't need to do anything.
9671 if (OffloadInfoManager.empty())
9672 return;
9673
9674 LLVMContext &C = M.getContext();
9675 SmallVector<std::pair<const OffloadEntriesInfoManager::OffloadEntryInfo *,
9676 TargetRegionEntryInfo>,
9677 16>
9678 OrderedEntries(OffloadInfoManager.size());
9679
9680 // Auxiliary methods to create metadata values and strings.
9681 auto &&GetMDInt = [this](unsigned V) {
9682 return ConstantAsMetadata::get(ConstantInt::get(Builder.getInt32Ty(), V));
9683 };
9684
9685 auto &&GetMDString = [&C](StringRef V) { return MDString::get(C, V); };
9686
9687 // Create the offloading info metadata node.
9688 NamedMDNode *MD = M.getOrInsertNamedMetadata("omp_offload.info");
9689 auto &&TargetRegionMetadataEmitter =
9690 [&C, MD, &OrderedEntries, &GetMDInt, &GetMDString](
9691 const TargetRegionEntryInfo &EntryInfo,
9692 const OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion &E) {
9693 // Generate metadata for target regions. Each entry of this metadata
9694 // contains:
9695 // - Entry 0 -> Kind of this type of metadata (0).
9696 // - Entry 1 -> Device ID of the file where the entry was identified.
9697 // - Entry 2 -> File ID of the file where the entry was identified.
9698 // - Entry 3 -> Mangled name of the function where the entry was
9699 // identified.
9700 // - Entry 4 -> Line in the file where the entry was identified.
9701 // - Entry 5 -> Count of regions at this DeviceID/FilesID/Line.
9702 // - Entry 6 -> Order the entry was created.
9703 // The first element of the metadata node is the kind.
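// A resulting node might look like (operand values are illustrative):
//   !{i32 0, i32 <device-id>, i32 <file-id>, !"parent", i32 <line>,
//     i32 <count>, i32 <order>}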
9704 Metadata *Ops[] = {
9705 GetMDInt(E.getKind()), GetMDInt(EntryInfo.DeviceID),
9706 GetMDInt(EntryInfo.FileID), GetMDString(EntryInfo.ParentName),
9707 GetMDInt(EntryInfo.Line), GetMDInt(EntryInfo.Count),
9708 GetMDInt(E.getOrder())};
9709
9710 // Save this entry in the right position of the ordered entries array.
9711 OrderedEntries[E.getOrder()] = std::make_pair(&E, EntryInfo);
9712
9713 // Add metadata to the named metadata node.
9714 MD->addOperand(MDNode::get(C, Ops));
9715 };
9716
9717 OffloadInfoManager.actOnTargetRegionEntriesInfo(TargetRegionMetadataEmitter);
9718
9719 // Create a function that emits metadata for each device global variable entry.
9720 auto &&DeviceGlobalVarMetadataEmitter =
9721 [&C, &OrderedEntries, &GetMDInt, &GetMDString, MD](
9722 StringRef MangledName,
9723 const OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar &E) {
9724 // Generate metadata for global variables. Each entry of this metadata
9725 // contains:
9726 // - Entry 0 -> Kind of this type of metadata (1).
9727 // - Entry 1 -> Mangled name of the variable.
9728 // - Entry 2 -> Declare target kind.
9729 // - Entry 3 -> Order the entry was created.
9730 // The first element of the metadata node is the kind.
9731 Metadata *Ops[] = {GetMDInt(E.getKind()), GetMDString(MangledName),
9732 GetMDInt(E.getFlags()), GetMDInt(E.getOrder())};
9733
9734 // Save this entry in the right position of the ordered entries array.
9735 TargetRegionEntryInfo varInfo(MangledName, 0, 0, 0);
9736 OrderedEntries[E.getOrder()] = std::make_pair(&E, varInfo);
9737
9738 // Add metadata to the named metadata node.
9739 MD->addOperand(MDNode::get(C, Ops));
9740 };
9741
9742 OffloadInfoManager.actOnDeviceGlobalVarEntriesInfo(
9743 DeviceGlobalVarMetadataEmitter);
9744
9745 for (const auto &E : OrderedEntries) {
9746 assert(E.first && "All ordered entries must exist!");
9747 if (const auto *CE =
9748 dyn_cast<OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion>(
9749 E.first)) {
9750 if (!CE->getID() || !CE->getAddress()) {
9751 // Do not blame the entry if the parent function is not emitted.
9752 TargetRegionEntryInfo EntryInfo = E.second;
9753 StringRef FnName = EntryInfo.ParentName;
9754 if (!M.getNamedValue(FnName))
9755 continue;
9756 ErrorFn(EMIT_MD_TARGET_REGION_ERROR, EntryInfo);
9757 continue;
9758 }
9759 createOffloadEntry(CE->getID(), CE->getAddress(),
9760 /*Size=*/0, CE->getFlags(),
9761 GlobalValue::WeakAnyLinkage);
9762 } else if (const auto *CE = dyn_cast<
9763 OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar>(
9764 E.first)) {
9765 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags =
9766 static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
9767 CE->getFlags());
9768 switch (Flags) {
9769 case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter:
9770 case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo:
9771 if (Config.isTargetDevice() && Config.hasRequiresUnifiedSharedMemory())
9772 continue;
9773 if (!CE->getAddress()) {
9774 ErrorFn(EMIT_MD_DECLARE_TARGET_ERROR, E.second);
9775 continue;
9776 }
9777 // The variable has no definition - no need to add the entry.
9778 if (CE->getVarSize() == 0)
9779 continue;
9780 break;
9781 case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink:
9782 assert(((Config.isTargetDevice() && !CE->getAddress()) ||
9783 (!Config.isTargetDevice() && CE->getAddress())) &&
9784 "Declaret target link address is set.");
9785 if (Config.isTargetDevice())
9786 continue;
9787 if (!CE->getAddress()) {
9788 ErrorFn(EMIT_MD_GLOBAL_VAR_LINK_ERROR, TargetRegionEntryInfo());
9789 continue;
9790 }
9791 break;
9792 default:
9793 break;
9794 }
9795
9796 // Hidden or internal symbols on the device are not externally visible.
9797 // We should not attempt to register them by creating an offloading
9798 // entry. Indirect variables are handled separately on the device.
9799 if (auto *GV = dyn_cast<GlobalValue>(CE->getAddress()))
9800 if ((GV->hasLocalLinkage() || GV->hasHiddenVisibility()) &&
9801 Flags != OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
9802 continue;
9803
9804 // Indirect globals need to use a special name that doesn't match the name
9805 // of the associated host global.
9806 if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
9807 createOffloadEntry(CE->getAddress(), CE->getAddress(), CE->getVarSize(),
9808 Flags, CE->getLinkage(), CE->getVarName());
9809 else
9810 createOffloadEntry(CE->getAddress(), CE->getAddress(), CE->getVarSize(),
9811 Flags, CE->getLinkage());
9812
9813 } else {
9814 llvm_unreachable("Unsupported entry kind.");
9815 }
9816 }
9817
9818 // Emit requires directive globals to a special entry so the runtime can
9819 // register them when the device image is loaded.
9820 // TODO: This reduces the offloading entries to a 32-bit integer. Offloading
9821 // entries should be redesigned to better suit this use-case.
9822 if (Config.hasRequiresFlags() && !Config.isTargetDevice())
9823 offloading::emitOffloadingEntry(
9824 M, object::OffloadKind::OFK_OpenMP,
9825 Constant::getNullValue(PointerType::getUnqual(M.getContext())),
9826 ".requires", /*Size=*/0,
9827 OffloadEntriesInfoManager::OMPTargetGlobalRegisterRequires,
9828 Config.getRequiresFlags());
9829 }
9830
9831 void TargetRegionEntryInfo::getTargetRegionEntryFnName(
9832 SmallVectorImpl<char> &Name, StringRef ParentName, unsigned DeviceID,
9833 unsigned FileID, unsigned Line, unsigned Count) {
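// The emitted name has the form
//   <KernelNamePrefix><DeviceID:hex>_<FileID:hex>_<ParentName>_l<Line>[_<Count>]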
9834 raw_svector_ostream OS(Name);
9835 OS << KernelNamePrefix << llvm::format("%x", DeviceID)
9836 << llvm::format("_%x_", FileID) << ParentName << "_l" << Line;
9837 if (Count)
9838 OS << "_" << Count;
9839 }
9840
9841 void OffloadEntriesInfoManager::getTargetRegionEntryFnName(
9842 SmallVectorImpl<char> &Name, const TargetRegionEntryInfo &EntryInfo) {
9843 unsigned NewCount = getTargetRegionEntryInfoCount(EntryInfo);
9844 TargetRegionEntryInfo::getTargetRegionEntryFnName(
9845 Name, EntryInfo.ParentName, EntryInfo.DeviceID, EntryInfo.FileID,
9846 EntryInfo.Line, NewCount);
9847 }
9848
9849 TargetRegionEntryInfo
9850 OpenMPIRBuilder::getTargetEntryUniqueInfo(FileIdentifierInfoCallbackTy CallBack,
9851 StringRef ParentName) {
9852 sys::fs::UniqueID ID(0xdeadf17e, 0);
9853 auto FileIDInfo = CallBack();
9854 uint64_t FileID = 0;
9855 std::error_code EC = sys::fs::getUniqueID(std::get<0>(FileIDInfo), ID);
9856 // If the inode ID could not be determined, create a hash value of
9857 // the current file name and use that as an ID.
9858 if (EC)
9859 FileID = hash_value(std::get<0>(FileIDInfo));
9860 else
9861 FileID = ID.getFile();
9862
9863 return TargetRegionEntryInfo(ParentName, ID.getDevice(), FileID,
9864 std::get<1>(FileIDInfo));
9865 }
9866
9867 unsigned OpenMPIRBuilder::getFlagMemberOffset() {
9868 unsigned Offset = 0;
9869 for (uint64_t Remain =
9870 static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
9871 omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF);
9872 !(Remain & 1); Remain = Remain >> 1)
9873 Offset++;
9874 return Offset;
9875 }
9876
9877 omp::OpenMPOffloadMappingFlags
9878 OpenMPIRBuilder::getMemberOfFlag(unsigned Position) {
9879 // Shift the 1-based position left by getFlagMemberOffset() bits.
9880 return static_cast<omp::OpenMPOffloadMappingFlags>(((uint64_t)Position + 1)
9881 << getFlagMemberOffset());
9882 }
9883
9884 void OpenMPIRBuilder::setCorrectMemberOfFlag(
9885 omp::OpenMPOffloadMappingFlags &Flags,
9886 omp::OpenMPOffloadMappingFlags MemberOfFlag) {
9887 // If the entry is PTR_AND_OBJ but has not been marked with the special
9888 // placeholder value 0xFFFF in the MEMBER_OF field, then it should not be
9889 // marked as MEMBER_OF.
9890 if (static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
9891 Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ) &&
9892 static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
9893 (Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF) !=
9894 omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF))
9895 return;
9896
9897 // Reset the placeholder value to prepare the flag for the assignment of the
9898 // proper MEMBER_OF value.
9899 Flags &= ~omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
9900 Flags |= MemberOfFlag;
9901 }
9902
9903 Constant *OpenMPIRBuilder::getAddrOfDeclareTargetVar(
9904 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
9905 OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
9906 bool IsDeclaration, bool IsExternallyVisible,
9907 TargetRegionEntryInfo EntryInfo, StringRef MangledName,
9908 std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
9909 std::vector<Triple> TargetTriple, Type *LlvmPtrTy,
9910 std::function<Constant *()> GlobalInitializer,
9911 std::function<GlobalValue::LinkageTypes()> VariableLinkage) {
9912 // TODO: convert this to utilise the IRBuilder Config rather than
9913 // a passed down argument.
9914 if (OpenMPSIMD)
9915 return nullptr;
9916
9917 if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink ||
9918 ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
9919 CaptureClause ==
9920 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
9921 Config.hasRequiresUnifiedSharedMemory())) {
9922 SmallString<64> PtrName;
9923 {
9924 raw_svector_ostream OS(PtrName);
9925 OS << MangledName;
9926 if (!IsExternallyVisible)
9927 OS << format("_%x", EntryInfo.FileID);
9928 OS << "_decl_tgt_ref_ptr";
9929 }
9930
9931 Value *Ptr = M.getNamedValue(PtrName);
9932
9933 if (!Ptr) {
9934 GlobalValue *GlobalValue = M.getNamedValue(MangledName);
9935 Ptr = getOrCreateInternalVariable(LlvmPtrTy, PtrName);
9936
9937 auto *GV = cast<GlobalVariable>(Ptr);
9938 GV->setLinkage(GlobalValue::WeakAnyLinkage);
9939
9940 if (!Config.isTargetDevice()) {
9941 if (GlobalInitializer)
9942 GV->setInitializer(GlobalInitializer());
9943 else
9944 GV->setInitializer(GlobalValue);
9945 }
9946
9947 registerTargetGlobalVariable(
9948 CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
9949 EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
9950 GlobalInitializer, VariableLinkage, LlvmPtrTy, cast<Constant>(Ptr));
9951 }
9952
9953 return cast<Constant>(Ptr);
9954 }
9955
9956 return nullptr;
9957 }
9958
9959 void OpenMPIRBuilder::registerTargetGlobalVariable(
9960 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
9961 OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
9962 bool IsDeclaration, bool IsExternallyVisible,
9963 TargetRegionEntryInfo EntryInfo, StringRef MangledName,
9964 std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
9965 std::vector<Triple> TargetTriple,
9966 std::function<Constant *()> GlobalInitializer,
9967 std::function<GlobalValue::LinkageTypes()> VariableLinkage, Type *LlvmPtrTy,
9968 Constant *Addr) {
9969 if (DeviceClause != OffloadEntriesInfoManager::OMPTargetDeviceClauseAny ||
9970 (TargetTriple.empty() && !Config.isTargetDevice()))
9971 return;
9972
9973 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags;
9974 StringRef VarName;
9975 int64_t VarSize;
9976 GlobalValue::LinkageTypes Linkage;
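// Two regimes follow: 'to'/'enter' without unified shared memory registers
// the variable itself; 'link', or 'to'/'enter' under unified shared memory,
// registers the indirection pointer created by getAddrOfDeclareTargetVar.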
9977
9978 if ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
9979 CaptureClause ==
9980 OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
9981 !Config.hasRequiresUnifiedSharedMemory()) {
9982 Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
9983 VarName = MangledName;
9984 GlobalValue *LlvmVal = M.getNamedValue(VarName);
9985
9986 if (!IsDeclaration)
9987 VarSize = divideCeil(
9988 M.getDataLayout().getTypeSizeInBits(LlvmVal->getValueType()), 8);
9989 else
9990 VarSize = 0;
9991 Linkage = (VariableLinkage) ? VariableLinkage() : LlvmVal->getLinkage();
9992
9993 // This is a workaround carried over from Clang which prevents undesired
9994 // optimisation of internal variables.
9995 if (Config.isTargetDevice() &&
9996 (!IsExternallyVisible || Linkage == GlobalValue::LinkOnceODRLinkage)) {
9997 // Do not create a "ref-variable" if the original is not also available
9998 // on the host.
9999 if (!OffloadInfoManager.hasDeviceGlobalVarEntryInfo(VarName))
10000 return;
10001
10002 std::string RefName = createPlatformSpecificName({VarName, "ref"});
10003
10004 if (!M.getNamedValue(RefName)) {
10005 Constant *AddrRef =
10006 getOrCreateInternalVariable(Addr->getType(), RefName);
10007 auto *GvAddrRef = cast<GlobalVariable>(AddrRef);
10008 GvAddrRef->setConstant(true);
10009 GvAddrRef->setLinkage(GlobalValue::InternalLinkage);
10010 GvAddrRef->setInitializer(Addr);
10011 GeneratedRefs.push_back(GvAddrRef);
10012 }
10013 }
10014 } else {
10015 if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink)
10016 Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
10017 else
10018 Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
10019
10020 if (Config.isTargetDevice()) {
10021 VarName = (Addr) ? Addr->getName() : "";
10022 Addr = nullptr;
10023 } else {
10024 Addr = getAddrOfDeclareTargetVar(
10025 CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
10026 EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
10027 LlvmPtrTy, GlobalInitializer, VariableLinkage);
10028 VarName = (Addr) ? Addr->getName() : "";
10029 }
10030 VarSize = M.getDataLayout().getPointerSize();
10031 Linkage = GlobalValue::WeakAnyLinkage;
10032 }
10033
10034 OffloadInfoManager.registerDeviceGlobalVarEntryInfo(VarName, Addr, VarSize,
10035 Flags, Linkage);
10036 }
10037
10038 /// Loads all the offload entries information from the host IR
10039 /// metadata.
10040 void OpenMPIRBuilder::loadOffloadInfoMetadata(Module &M) {
10041 // If we are in target mode, load the metadata from the host IR. This code has
10042 // to match the metadata creation in createOffloadEntriesAndInfoMetadata().
10043
10044 NamedMDNode *MD = M.getNamedMetadata(ompOffloadInfoName);
10045 if (!MD)
10046 return;
10047
10048 for (MDNode *MN : MD->operands()) {
10049 auto &&GetMDInt = [MN](unsigned Idx) {
10050 auto *V = cast<ConstantAsMetadata>(MN->getOperand(Idx));
10051 return cast<ConstantInt>(V->getValue())->getZExtValue();
10052 };
10053
10054 auto &&GetMDString = [MN](unsigned Idx) {
10055 auto *V = cast<MDString>(MN->getOperand(Idx));
10056 return V->getString();
10057 };
10058
10059 switch (GetMDInt(0)) {
10060 default:
10061 llvm_unreachable("Unexpected metadata!");
10062 break;
10063 case OffloadEntriesInfoManager::OffloadEntryInfo::
10064 OffloadingEntryInfoTargetRegion: {
10065 TargetRegionEntryInfo EntryInfo(/*ParentName=*/GetMDString(3),
10066 /*DeviceID=*/GetMDInt(1),
10067 /*FileID=*/GetMDInt(2),
10068 /*Line=*/GetMDInt(4),
10069 /*Count=*/GetMDInt(5));
10070 OffloadInfoManager.initializeTargetRegionEntryInfo(EntryInfo,
10071 /*Order=*/GetMDInt(6));
10072 break;
10073 }
10074 case OffloadEntriesInfoManager::OffloadEntryInfo::
10075 OffloadingEntryInfoDeviceGlobalVar:
10076 OffloadInfoManager.initializeDeviceGlobalVarEntryInfo(
10077 /*MangledName=*/GetMDString(1),
10078 static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
10079 /*Flags=*/GetMDInt(2)),
10080 /*Order=*/GetMDInt(3));
10081 break;
10082 }
10083 }
10084 }
10085
10086 void OpenMPIRBuilder::loadOffloadInfoMetadata(StringRef HostFilePath) {
10087 if (HostFilePath.empty())
10088 return;
10089
10090 auto Buf = MemoryBuffer::getFile(HostFilePath);
10091 if (std::error_code Err = Buf.getError()) {
10092 report_fatal_error(("error opening host file from host file path inside of "
10093 "OpenMPIRBuilder: " +
10094 Err.message())
10095 .c_str());
10096 }
10097
10098 LLVMContext Ctx;
10099 auto M = expectedToErrorOrAndEmitErrors(
10100 Ctx, parseBitcodeFile(Buf.get()->getMemBufferRef(), Ctx));
10101 if (std::error_code Err = M.getError()) {
10102 report_fatal_error(
10103 ("error parsing host file inside of OpenMPIRBuilder: " + Err.message())
10104 .c_str());
10105 }
10106
10107 loadOffloadInfoMetadata(*M.get());
10108 }
10109
10110 //===----------------------------------------------------------------------===//
10111 // OffloadEntriesInfoManager
10112 //===----------------------------------------------------------------------===//
10113
10114 bool OffloadEntriesInfoManager::empty() const {
10115 return OffloadEntriesTargetRegion.empty() &&
10116 OffloadEntriesDeviceGlobalVar.empty();
10117 }
10118
10119 unsigned OffloadEntriesInfoManager::getTargetRegionEntryInfoCount(
10120 const TargetRegionEntryInfo &EntryInfo) const {
10121 auto It = OffloadEntriesTargetRegionCount.find(
10122 getTargetRegionEntryCountKey(EntryInfo));
10123 if (It == OffloadEntriesTargetRegionCount.end())
10124 return 0;
10125 return It->second;
10126 }
10127
10128 void OffloadEntriesInfoManager::incrementTargetRegionEntryInfoCount(
10129 const TargetRegionEntryInfo &EntryInfo) {
10130 OffloadEntriesTargetRegionCount[getTargetRegionEntryCountKey(EntryInfo)] =
10131 EntryInfo.Count + 1;
10132 }
10133
10134 /// Initialize target region entry.
10135 void OffloadEntriesInfoManager::initializeTargetRegionEntryInfo(
10136 const TargetRegionEntryInfo &EntryInfo, unsigned Order) {
10137 OffloadEntriesTargetRegion[EntryInfo] =
10138 OffloadEntryInfoTargetRegion(Order, /*Addr=*/nullptr, /*ID=*/nullptr,
10139 OMPTargetRegionEntryTargetRegion);
10140 ++OffloadingEntriesNum;
10141 }
10142
10143 void OffloadEntriesInfoManager::registerTargetRegionEntryInfo(
10144 TargetRegionEntryInfo EntryInfo, Constant *Addr, Constant *ID,
10145 OMPTargetRegionEntryKind Flags) {
10146 assert(EntryInfo.Count == 0 && "expected default EntryInfo");
10147
10148 // Update the EntryInfo with the next available count for this location.
10149 EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);
10150
10151 // If we are emitting code for a target device, the entry is already
10152 // initialized and only has to be registered.
10153 if (OMPBuilder->Config.isTargetDevice()) {
10154 // This could happen if the device compilation is invoked standalone.
10155 if (!hasTargetRegionEntryInfo(EntryInfo)) {
10156 return;
10157 }
10158 auto &Entry = OffloadEntriesTargetRegion[EntryInfo];
10159 Entry.setAddress(Addr);
10160 Entry.setID(ID);
10161 Entry.setFlags(Flags);
10162 } else {
10163 if (Flags == OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion &&
10164 hasTargetRegionEntryInfo(EntryInfo, /*IgnoreAddressId*/ true))
10165 return;
10166 assert(!hasTargetRegionEntryInfo(EntryInfo) &&
10167 "Target region entry already registered!");
10168 OffloadEntryInfoTargetRegion Entry(OffloadingEntriesNum, Addr, ID, Flags);
10169 OffloadEntriesTargetRegion[EntryInfo] = Entry;
10170 ++OffloadingEntriesNum;
10171 }
10172 incrementTargetRegionEntryInfoCount(EntryInfo);
10173 }
10174
10175 bool OffloadEntriesInfoManager::hasTargetRegionEntryInfo(
10176 TargetRegionEntryInfo EntryInfo, bool IgnoreAddressId) const {
10177
10178 // Update the EntryInfo with the next available count for this location.
10179 EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);
10180
10181 auto It = OffloadEntriesTargetRegion.find(EntryInfo);
10182 if (It == OffloadEntriesTargetRegion.end()) {
10183 return false;
10184 }
10185 // Fail if this entry is already registered.
10186 if (!IgnoreAddressId && (It->second.getAddress() || It->second.getID()))
10187 return false;
10188 return true;
10189 }
10190
10191 void OffloadEntriesInfoManager::actOnTargetRegionEntriesInfo(
10192 const OffloadTargetRegionEntryInfoActTy &Action) {
10193 // Scan all target region entries and perform the provided action.
10194 for (const auto &It : OffloadEntriesTargetRegion) {
10195 Action(It.first, It.second);
10196 }
10197 }
10198
10199 void OffloadEntriesInfoManager::initializeDeviceGlobalVarEntryInfo(
10200 StringRef Name, OMPTargetGlobalVarEntryKind Flags, unsigned Order) {
10201 OffloadEntriesDeviceGlobalVar.try_emplace(Name, Order, Flags);
10202 ++OffloadingEntriesNum;
10203 }
10204
10205 void OffloadEntriesInfoManager::registerDeviceGlobalVarEntryInfo(
10206 StringRef VarName, Constant *Addr, int64_t VarSize,
10207 OMPTargetGlobalVarEntryKind Flags, GlobalValue::LinkageTypes Linkage) {
10208 if (OMPBuilder->Config.isTargetDevice()) {
10209 // This could happen if the device compilation is invoked standalone.
10210 if (!hasDeviceGlobalVarEntryInfo(VarName))
10211 return;
10212 auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
10213 if (Entry.getAddress() && hasDeviceGlobalVarEntryInfo(VarName)) {
10214 if (Entry.getVarSize() == 0) {
10215 Entry.setVarSize(VarSize);
10216 Entry.setLinkage(Linkage);
10217 }
10218 return;
10219 }
10220 Entry.setVarSize(VarSize);
10221 Entry.setLinkage(Linkage);
10222 Entry.setAddress(Addr);
10223 } else {
10224 if (hasDeviceGlobalVarEntryInfo(VarName)) {
10225 auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
10226 assert(Entry.isValid() && Entry.getFlags() == Flags &&
10227 "Entry not initialized!");
10228 if (Entry.getVarSize() == 0) {
10229 Entry.setVarSize(VarSize);
10230 Entry.setLinkage(Linkage);
10231 }
10232 return;
10233 }
10234 if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
10235 OffloadEntriesDeviceGlobalVar.try_emplace(VarName, OffloadingEntriesNum,
10236 Addr, VarSize, Flags, Linkage,
10237 VarName.str());
10238 else
10239 OffloadEntriesDeviceGlobalVar.try_emplace(
10240 VarName, OffloadingEntriesNum, Addr, VarSize, Flags, Linkage, "");
10241 ++OffloadingEntriesNum;
10242 }
10243 }
10244
10245 void OffloadEntriesInfoManager::actOnDeviceGlobalVarEntriesInfo(
10246 const OffloadDeviceGlobalVarEntryInfoActTy &Action) {
10247 // Scan all target region entries and perform the provided action.
10248 for (const auto &E : OffloadEntriesDeviceGlobalVar)
10249 Action(E.getKey(), E.getValue());
10250 }
10251
10252 //===----------------------------------------------------------------------===//
10253 // CanonicalLoopInfo
10254 //===----------------------------------------------------------------------===//
10255
10256 void CanonicalLoopInfo::collectControlBlocks(
10257 SmallVectorImpl<BasicBlock *> &BBs) {
10258 // We only count those BBs as control blocks for which we do not need to
10259 // reverse the CFG, i.e. not the loop body which can contain arbitrary control
10260 // flow. For consistency, this also means we do not add the Body block, which
10261 // is just the entry to the body code.
10262 BBs.reserve(BBs.size() + 6);
10263 BBs.append({getPreheader(), Header, Cond, Latch, Exit, getAfter()});
10264 }
10265
10266 BasicBlock *CanonicalLoopInfo::getPreheader() const {
10267 assert(isValid() && "Requires a valid canonical loop");
10268 for (BasicBlock *Pred : predecessors(Header)) {
10269 if (Pred != Latch)
10270 return Pred;
10271 }
10272 llvm_unreachable("Missing preheader");
10273 }
10274
10275 void CanonicalLoopInfo::setTripCount(Value *TripCount) {
10276 assert(isValid() && "Requires a valid canonical loop");
10277
10278 Instruction *CmpI = &getCond()->front();
10279 assert(isa<CmpInst>(CmpI) && "First inst must compare IV with TripCount");
10280 CmpI->setOperand(1, TripCount);
10281
10282 #ifndef NDEBUG
10283 assertOK();
10284 #endif
10285 }
10286
10287 void CanonicalLoopInfo::mapIndVar(
10288 llvm::function_ref<Value *(Instruction *)> Updater) {
10289 assert(isValid() && "Requires a valid canonical loop");
10290
10291 Instruction *OldIV = getIndVar();
10292
10293 // Record all uses excluding those introduced by the updater. Uses by the
10294 // CanonicalLoopInfo itself to keep track of the number of iterations are
10295 // excluded.
10296 SmallVector<Use *> ReplacableUses;
10297 for (Use &U : OldIV->uses()) {
10298 auto *User = dyn_cast<Instruction>(U.getUser());
10299 if (!User)
10300 continue;
10301 if (User->getParent() == getCond())
10302 continue;
10303 if (User->getParent() == getLatch())
10304 continue;
10305 ReplacableUses.push_back(&U);
10306 }
10307
10308 // Run the updater that may introduce new uses
10309 Value *NewIV = Updater(OldIV);
10310
10311 // Replace the old uses with the value returned by the updater.
10312 for (Use *U : ReplacableUses)
10313 U->set(NewIV);
10314
10315 #ifndef NDEBUG
10316 assertOK();
10317 #endif
10318 }
10319
10320 void CanonicalLoopInfo::assertOK() const {
10321 #ifndef NDEBUG
10322 // No constraints if this object currently does not describe a loop.
10323 if (!isValid())
10324 return;
10325
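// The control flow being verified is:
//
//   preheader -> header -> cond -> body -> ... -> latch -> header
//                            |
//                            +--> exit -> after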
10326 BasicBlock *Preheader = getPreheader();
10327 BasicBlock *Body = getBody();
10328 BasicBlock *After = getAfter();
10329
10330 // Verify standard control-flow we use for OpenMP loops.
10331 assert(Preheader);
10332 assert(isa<BranchInst>(Preheader->getTerminator()) &&
10333 "Preheader must terminate with unconditional branch");
10334 assert(Preheader->getSingleSuccessor() == Header &&
10335 "Preheader must jump to header");
10336
10337 assert(Header);
10338 assert(isa<BranchInst>(Header->getTerminator()) &&
10339 "Header must terminate with unconditional branch");
10340 assert(Header->getSingleSuccessor() == Cond &&
10341 "Header must jump to exiting block");
10342
10343 assert(Cond);
10344 assert(Cond->getSinglePredecessor() == Header &&
10345 "Exiting block only reachable from header");
10346
10347 assert(isa<BranchInst>(Cond->getTerminator()) &&
10348 "Exiting block must terminate with conditional branch");
10349 assert(size(successors(Cond)) == 2 &&
10350 "Exiting block must have two successors");
10351 assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(0) == Body &&
10352 "Exiting block's first successor jump to the body");
10353 assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(1) == Exit &&
10354 "Exiting block's second successor must exit the loop");
10355
10356 assert(Body);
10357 assert(Body->getSinglePredecessor() == Cond &&
10358 "Body only reachable from exiting block");
10359 assert(!isa<PHINode>(Body->front()));
10360
10361 assert(Latch);
10362 assert(isa<BranchInst>(Latch->getTerminator()) &&
10363 "Latch must terminate with unconditional branch");
10364 assert(Latch->getSingleSuccessor() == Header && "Latch must jump to header");
10365 // TODO: To support simply redirecting the end of body code that has multiple
10366 // predecessors, introduce another auxiliary basic block like preheader and after.
10367 assert(Latch->getSinglePredecessor() != nullptr);
10368 assert(!isa<PHINode>(Latch->front()));
10369
10370 assert(Exit);
10371 assert(isa<BranchInst>(Exit->getTerminator()) &&
10372 "Exit block must terminate with unconditional branch");
10373 assert(Exit->getSingleSuccessor() == After &&
10374 "Exit block must jump to after block");
10375
10376 assert(After);
10377 assert(After->getSinglePredecessor() == Exit &&
10378 "After block only reachable from exit block");
10379 assert(After->empty() || !isa<PHINode>(After->front()));
10380
10381 Instruction *IndVar = getIndVar();
10382 assert(IndVar && "Canonical induction variable not found?");
10383 assert(isa<IntegerType>(IndVar->getType()) &&
10384 "Induction variable must be an integer");
10385 assert(cast<PHINode>(IndVar)->getParent() == Header &&
10386 "Induction variable must be a PHI in the loop header");
10387 assert(cast<PHINode>(IndVar)->getIncomingBlock(0) == Preheader);
10388 assert(
10389 cast<ConstantInt>(cast<PHINode>(IndVar)->getIncomingValue(0))->isZero());
10390 assert(cast<PHINode>(IndVar)->getIncomingBlock(1) == Latch);
10391
10392 auto *NextIndVar = cast<PHINode>(IndVar)->getIncomingValue(1);
10393 assert(cast<Instruction>(NextIndVar)->getParent() == Latch);
10394 assert(cast<BinaryOperator>(NextIndVar)->getOpcode() == BinaryOperator::Add);
10395 assert(cast<BinaryOperator>(NextIndVar)->getOperand(0) == IndVar);
10396 assert(cast<ConstantInt>(cast<BinaryOperator>(NextIndVar)->getOperand(1))
10397 ->isOne());
10398
10399 Value *TripCount = getTripCount();
10400 assert(TripCount && "Loop trip count not found?");
10401 assert(IndVar->getType() == TripCount->getType() &&
10402 "Trip count and induction variable must have the same type");
10403
10404 auto *CmpI = cast<CmpInst>(&Cond->front());
10405 assert(CmpI->getPredicate() == CmpInst::ICMP_ULT &&
10406 "Exit condition must be an unsigned less-than comparison");
10407 assert(CmpI->getOperand(0) == IndVar &&
10408 "Exit condition must compare the induction variable");
10409 assert(CmpI->getOperand(1) == TripCount &&
10410 "Exit condition must compare with the trip count");
10411 #endif
10412 }
10413
10414 void CanonicalLoopInfo::invalidate() {
10415 Header = nullptr;
10416 Cond = nullptr;
10417 Latch = nullptr;
10418 Exit = nullptr;
10419 }
10420