xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp (revision b64c5a0ace59af62eff52bfe110a521dc73c937b)
1 //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpenMP specific optimizations:
10 //
11 // - Deduplication of runtime calls, e.g., omp_get_thread_num.
12 // - Replacing globalized device memory with stack memory.
13 // - Replacing globalized device memory with shared memory.
14 // - Parallel region merging.
15 // - Transforming generic-mode device kernels to SPMD mode.
16 // - Specializing the state machine for generic-mode device kernels.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "llvm/Transforms/IPO/OpenMPOpt.h"
21 
22 #include "llvm/ADT/EnumeratedArray.h"
23 #include "llvm/ADT/PostOrderIterator.h"
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/ADT/SmallPtrSet.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/Statistic.h"
28 #include "llvm/ADT/StringExtras.h"
29 #include "llvm/ADT/StringRef.h"
30 #include "llvm/Analysis/CallGraph.h"
31 #include "llvm/Analysis/CallGraphSCCPass.h"
32 #include "llvm/Analysis/MemoryLocation.h"
33 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
34 #include "llvm/Analysis/ValueTracking.h"
35 #include "llvm/Frontend/OpenMP/OMPConstants.h"
36 #include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
37 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
38 #include "llvm/IR/Assumptions.h"
39 #include "llvm/IR/BasicBlock.h"
40 #include "llvm/IR/Constants.h"
41 #include "llvm/IR/DiagnosticInfo.h"
42 #include "llvm/IR/Dominators.h"
43 #include "llvm/IR/Function.h"
44 #include "llvm/IR/GlobalValue.h"
45 #include "llvm/IR/GlobalVariable.h"
46 #include "llvm/IR/InstrTypes.h"
47 #include "llvm/IR/Instruction.h"
48 #include "llvm/IR/Instructions.h"
49 #include "llvm/IR/IntrinsicInst.h"
50 #include "llvm/IR/IntrinsicsAMDGPU.h"
51 #include "llvm/IR/IntrinsicsNVPTX.h"
52 #include "llvm/IR/LLVMContext.h"
53 #include "llvm/Support/Casting.h"
54 #include "llvm/Support/CommandLine.h"
55 #include "llvm/Support/Debug.h"
56 #include "llvm/Transforms/IPO/Attributor.h"
57 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
58 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
59 
60 #include <algorithm>
61 #include <optional>
62 #include <string>
63 
64 using namespace llvm;
65 using namespace omp;
66 
67 #define DEBUG_TYPE "openmp-opt"
68 
69 static cl::opt<bool> DisableOpenMPOptimizations(
70     "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
71     cl::Hidden, cl::init(false));
72 
73 static cl::opt<bool> EnableParallelRegionMerging(
74     "openmp-opt-enable-merging",
75     cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
76     cl::init(false));
77 
78 static cl::opt<bool>
79     DisableInternalization("openmp-opt-disable-internalization",
80                            cl::desc("Disable function internalization."),
81                            cl::Hidden, cl::init(false));
82 
83 static cl::opt<bool> DeduceICVValues("openmp-deduce-icv-values",
84                                      cl::init(false), cl::Hidden);
85 static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
86                                     cl::Hidden);
87 static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
88                                         cl::init(false), cl::Hidden);
89 
90 static cl::opt<bool> HideMemoryTransferLatency(
91     "openmp-hide-memory-transfer-latency",
92     cl::desc("[WIP] Tries to hide the latency of host to device memory"
93              " transfers"),
94     cl::Hidden, cl::init(false));
95 
96 static cl::opt<bool> DisableOpenMPOptDeglobalization(
97     "openmp-opt-disable-deglobalization",
98     cl::desc("Disable OpenMP optimizations involving deglobalization."),
99     cl::Hidden, cl::init(false));
100 
101 static cl::opt<bool> DisableOpenMPOptSPMDization(
102     "openmp-opt-disable-spmdization",
103     cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
104     cl::Hidden, cl::init(false));
105 
106 static cl::opt<bool> DisableOpenMPOptFolding(
107     "openmp-opt-disable-folding",
108     cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
109     cl::init(false));
110 
111 static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
112     "openmp-opt-disable-state-machine-rewrite",
113     cl::desc("Disable OpenMP optimizations that replace the state machine."),
114     cl::Hidden, cl::init(false));
115 
116 static cl::opt<bool> DisableOpenMPOptBarrierElimination(
117     "openmp-opt-disable-barrier-elimination",
118     cl::desc("Disable OpenMP optimizations that eliminate barriers."),
119     cl::Hidden, cl::init(false));
120 
121 static cl::opt<bool> PrintModuleAfterOptimizations(
122     "openmp-opt-print-module-after",
123     cl::desc("Print the current module after OpenMP optimizations."),
124     cl::Hidden, cl::init(false));
125 
126 static cl::opt<bool> PrintModuleBeforeOptimizations(
127     "openmp-opt-print-module-before",
128     cl::desc("Print the current module before OpenMP optimizations."),
129     cl::Hidden, cl::init(false));
130 
131 static cl::opt<bool> AlwaysInlineDeviceFunctions(
132     "openmp-opt-inline-device",
133     cl::desc("Inline all applicible functions on the device."), cl::Hidden,
134     cl::init(false));
135 
136 static cl::opt<bool>
137     EnableVerboseRemarks("openmp-opt-verbose-remarks",
138                          cl::desc("Enables more verbose remarks."), cl::Hidden,
139                          cl::init(false));
140 
141 static cl::opt<unsigned>
142     SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
143                           cl::desc("Maximal number of attributor iterations."),
144                           cl::init(256));
145 
146 static cl::opt<unsigned>
147     SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
148                       cl::desc("Maximum amount of shared memory to use."),
149                       cl::init(std::numeric_limits<unsigned>::max()));
150 
151 STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
152           "Number of OpenMP runtime calls deduplicated");
153 STATISTIC(NumOpenMPParallelRegionsDeleted,
154           "Number of OpenMP parallel regions deleted");
155 STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
156           "Number of OpenMP runtime functions identified");
157 STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
158           "Number of OpenMP runtime function uses identified");
159 STATISTIC(NumOpenMPTargetRegionKernels,
160           "Number of OpenMP target region entry points (=kernels) identified");
161 STATISTIC(NumNonOpenMPTargetRegionKernels,
162           "Number of non-OpenMP target region kernels identified");
163 STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
164           "Number of OpenMP target region entry points (=kernels) executed in "
165           "SPMD-mode instead of generic-mode");
166 STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
167           "Number of OpenMP target region entry points (=kernels) executed in "
168           "generic-mode without a state machines");
169 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
170           "Number of OpenMP target region entry points (=kernels) executed in "
171           "generic-mode with customized state machines with fallback");
172 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
173           "Number of OpenMP target region entry points (=kernels) executed in "
174           "generic-mode with customized state machines without fallback");
175 STATISTIC(
176     NumOpenMPParallelRegionsReplacedInGPUStateMachine,
177     "Number of OpenMP parallel regions replaced with ID in GPU state machines");
178 STATISTIC(NumOpenMPParallelRegionsMerged,
179           "Number of OpenMP parallel regions merged");
180 STATISTIC(NumBytesMovedToSharedMemory,
181           "Amount of memory pushed to shared memory");
182 STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");
183 
184 #if !defined(NDEBUG)
185 static constexpr auto TAG = "[" DEBUG_TYPE "]";
186 #endif
187 
188 namespace KernelInfo {
189 
190 // struct ConfigurationEnvironmentTy {
191 //   uint8_t UseGenericStateMachine;
192 //   uint8_t MayUseNestedParallelism;
193 //   llvm::omp::OMPTgtExecModeFlags ExecMode;
194 //   int32_t MinThreads;
195 //   int32_t MaxThreads;
196 //   int32_t MinTeams;
197 //   int32_t MaxTeams;
198 // };
199 
200 // struct DynamicEnvironmentTy {
201 //   uint16_t DebugIndentionLevel;
202 // };
203 
204 // struct KernelEnvironmentTy {
205 //   ConfigurationEnvironmentTy Configuration;
206 //   IdentTy *Ident;
207 //   DynamicEnvironmentTy *DynamicEnv;
208 // };
209 
210 #define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX)                                    \
211   constexpr const unsigned MEMBER##Idx = IDX;
212 
213 KERNEL_ENVIRONMENT_IDX(Configuration, 0)
214 KERNEL_ENVIRONMENT_IDX(Ident, 1)
215 
216 #undef KERNEL_ENVIRONMENT_IDX
217 
218 #define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX)                      \
219   constexpr const unsigned MEMBER##Idx = IDX;
220 
221 KERNEL_ENVIRONMENT_CONFIGURATION_IDX(UseGenericStateMachine, 0)
222 KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MayUseNestedParallelism, 1)
223 KERNEL_ENVIRONMENT_CONFIGURATION_IDX(ExecMode, 2)
224 KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinThreads, 3)
225 KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxThreads, 4)
226 KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinTeams, 5)
227 KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxTeams, 6)
228 
229 #undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX
230 
231 #define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE)                          \
232   RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \
233     return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx));     \
234   }
235 
236 KERNEL_ENVIRONMENT_GETTER(Ident, Constant)
237 KERNEL_ENVIRONMENT_GETTER(Configuration, ConstantStruct)
238 
239 #undef KERNEL_ENVIRONMENT_GETTER
240 
241 #define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER)                        \
242   ConstantInt *get##MEMBER##FromKernelEnvironment(                             \
243       ConstantStruct *KernelEnvC) {                                            \
244     ConstantStruct *ConfigC =                                                  \
245         getConfigurationFromKernelEnvironment(KernelEnvC);                     \
246     return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx));   \
247   }
248 
249 KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine)
250 KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MayUseNestedParallelism)
251 KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(ExecMode)
252 KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinThreads)
253 KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxThreads)
254 KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinTeams)
255 KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxTeams)
256 
257 #undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER
258 
259 GlobalVariable *
260 getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB) {
261   constexpr const int InitKernelEnvironmentArgNo = 0;
262   return cast<GlobalVariable>(
263       KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo)
264           ->stripPointerCasts());
265 }
266 
267 ConstantStruct *getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB) {
268   GlobalVariable *KernelEnvGV =
269       getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
270   return cast<ConstantStruct>(KernelEnvGV->getInitializer());
271 }
272 } // namespace KernelInfo
273 
274 namespace {
275 
// Forward declarations of abstract attributes whose definitions are
// presumably later in this file.
struct AAHeapToShared;
struct AAICVTracker;
279 
280 /// OpenMP specific information. For now, stores RFIs and ICVs also needed for
281 /// Attributor runs.
282 struct OMPInformationCache : public InformationCache {
283   OMPInformationCache(Module &M, AnalysisGetter &AG,
284                       BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
285                       bool OpenMPPostLink)
286       : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
287         OpenMPPostLink(OpenMPPostLink) {
288 
289     OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M);
290     OMPBuilder.initialize();
291     initializeRuntimeFunctions(M);
292     initializeInternalControlVars();
293   }
294 
295   /// Generic information that describes an internal control variable.
296   struct InternalControlVarInfo {
297     /// The kind, as described by InternalControlVar enum.
298     InternalControlVar Kind;
299 
300     /// The name of the ICV.
301     StringRef Name;
302 
303     /// Environment variable associated with this ICV.
304     StringRef EnvVarName;
305 
306     /// Initial value kind.
307     ICVInitValue InitKind;
308 
309     /// Initial value.
310     ConstantInt *InitValue;
311 
312     /// Setter RTL function associated with this ICV.
313     RuntimeFunction Setter;
314 
315     /// Getter RTL function associated with this ICV.
316     RuntimeFunction Getter;
317 
318     /// RTL Function corresponding to the override clause of this ICV
319     RuntimeFunction Clause;
320   };
321 
322   /// Generic information that describes a runtime function
323   struct RuntimeFunctionInfo {
324 
325     /// The kind, as described by the RuntimeFunction enum.
326     RuntimeFunction Kind;
327 
328     /// The name of the function.
329     StringRef Name;
330 
331     /// Flag to indicate a variadic function.
332     bool IsVarArg;
333 
334     /// The return type of the function.
335     Type *ReturnType;
336 
337     /// The argument types of the function.
338     SmallVector<Type *, 8> ArgumentTypes;
339 
340     /// The declaration if available.
341     Function *Declaration = nullptr;
342 
343     /// Uses of this runtime function per function containing the use.
344     using UseVector = SmallVector<Use *, 16>;
345 
346     /// Clear UsesMap for runtime function.
347     void clearUsesMap() { UsesMap.clear(); }
348 
349     /// Boolean conversion that is true if the runtime function was found.
350     operator bool() const { return Declaration; }
351 
352     /// Return the vector of uses in function \p F.
353     UseVector &getOrCreateUseVector(Function *F) {
354       std::shared_ptr<UseVector> &UV = UsesMap[F];
355       if (!UV)
356         UV = std::make_shared<UseVector>();
357       return *UV;
358     }
359 
360     /// Return the vector of uses in function \p F or `nullptr` if there are
361     /// none.
362     const UseVector *getUseVector(Function &F) const {
363       auto I = UsesMap.find(&F);
364       if (I != UsesMap.end())
365         return I->second.get();
366       return nullptr;
367     }
368 
369     /// Return how many functions contain uses of this runtime function.
370     size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
371 
372     /// Return the number of arguments (or the minimal number for variadic
373     /// functions).
374     size_t getNumArgs() const { return ArgumentTypes.size(); }
375 
376     /// Run the callback \p CB on each use and forget the use if the result is
377     /// true. The callback will be fed the function in which the use was
378     /// encountered as second argument.
379     void foreachUse(SmallVectorImpl<Function *> &SCC,
380                     function_ref<bool(Use &, Function &)> CB) {
381       for (Function *F : SCC)
382         foreachUse(CB, F);
383     }
384 
385     /// Run the callback \p CB on each use within the function \p F and forget
386     /// the use if the result is true.
387     void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
388       SmallVector<unsigned, 8> ToBeDeleted;
389       ToBeDeleted.clear();
390 
391       unsigned Idx = 0;
392       UseVector &UV = getOrCreateUseVector(F);
393 
394       for (Use *U : UV) {
395         if (CB(*U, *F))
396           ToBeDeleted.push_back(Idx);
397         ++Idx;
398       }
399 
400       // Remove the to-be-deleted indices in reverse order as prior
401       // modifications will not modify the smaller indices.
402       while (!ToBeDeleted.empty()) {
403         unsigned Idx = ToBeDeleted.pop_back_val();
404         UV[Idx] = UV.back();
405         UV.pop_back();
406       }
407     }
408 
409   private:
410     /// Map from functions to all uses of this runtime function contained in
411     /// them.
412     DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
413 
414   public:
415     /// Iterators for the uses of this runtime function.
416     decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
417     decltype(UsesMap)::iterator end() { return UsesMap.end(); }
418   };
419 
420   /// An OpenMP-IR-Builder instance
421   OpenMPIRBuilder OMPBuilder;
422 
423   /// Map from runtime function kind to the runtime function description.
424   EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
425                   RuntimeFunction::OMPRTL___last>
426       RFIs;
427 
428   /// Map from function declarations/definitions to their runtime enum type.
429   DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
430 
431   /// Map from ICV kind to the ICV description.
432   EnumeratedArray<InternalControlVarInfo, InternalControlVar,
433                   InternalControlVar::ICV___last>
434       ICVs;
435 
436   /// Helper to initialize all internal control variable information for those
437   /// defined in OMPKinds.def.
438   void initializeInternalControlVars() {
439 #define ICV_RT_SET(_Name, RTL)                                                 \
440   {                                                                            \
441     auto &ICV = ICVs[_Name];                                                   \
442     ICV.Setter = RTL;                                                          \
443   }
444 #define ICV_RT_GET(Name, RTL)                                                  \
445   {                                                                            \
446     auto &ICV = ICVs[Name];                                                    \
447     ICV.Getter = RTL;                                                          \
448   }
449 #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
450   {                                                                            \
451     auto &ICV = ICVs[Enum];                                                    \
452     ICV.Name = _Name;                                                          \
453     ICV.Kind = Enum;                                                           \
454     ICV.InitKind = Init;                                                       \
455     ICV.EnvVarName = _EnvVarName;                                              \
456     switch (ICV.InitKind) {                                                    \
457     case ICV_IMPLEMENTATION_DEFINED:                                           \
458       ICV.InitValue = nullptr;                                                 \
459       break;                                                                   \
460     case ICV_ZERO:                                                             \
461       ICV.InitValue = ConstantInt::get(                                        \
462           Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
463       break;                                                                   \
464     case ICV_FALSE:                                                            \
465       ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
466       break;                                                                   \
467     case ICV_LAST:                                                             \
468       break;                                                                   \
469     }                                                                          \
470   }
471 #include "llvm/Frontend/OpenMP/OMPKinds.def"
472   }
473 
474   /// Returns true if the function declaration \p F matches the runtime
475   /// function types, that is, return type \p RTFRetType, and argument types
476   /// \p RTFArgTypes.
477   static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
478                                   SmallVector<Type *, 8> &RTFArgTypes) {
479     // TODO: We should output information to the user (under debug output
480     //       and via remarks).
481 
482     if (!F)
483       return false;
484     if (F->getReturnType() != RTFRetType)
485       return false;
486     if (F->arg_size() != RTFArgTypes.size())
487       return false;
488 
489     auto *RTFTyIt = RTFArgTypes.begin();
490     for (Argument &Arg : F->args()) {
491       if (Arg.getType() != *RTFTyIt)
492         return false;
493 
494       ++RTFTyIt;
495     }
496 
497     return true;
498   }
499 
500   // Helper to collect all uses of the declaration in the UsesMap.
501   unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
502     unsigned NumUses = 0;
503     if (!RFI.Declaration)
504       return NumUses;
505     OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
506 
507     if (CollectStats) {
508       NumOpenMPRuntimeFunctionsIdentified += 1;
509       NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
510     }
511 
512     // TODO: We directly convert uses into proper calls and unknown uses.
513     for (Use &U : RFI.Declaration->uses()) {
514       if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
515         if (!CGSCC || CGSCC->empty() || CGSCC->contains(UserI->getFunction())) {
516           RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
517           ++NumUses;
518         }
519       } else {
520         RFI.getOrCreateUseVector(nullptr).push_back(&U);
521         ++NumUses;
522       }
523     }
524     return NumUses;
525   }
526 
527   // Helper function to recollect uses of a runtime function.
528   void recollectUsesForFunction(RuntimeFunction RTF) {
529     auto &RFI = RFIs[RTF];
530     RFI.clearUsesMap();
531     collectUses(RFI, /*CollectStats*/ false);
532   }
533 
534   // Helper function to recollect uses of all runtime functions.
535   void recollectUses() {
536     for (int Idx = 0; Idx < RFIs.size(); ++Idx)
537       recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
538   }
539 
540   // Helper function to inherit the calling convention of the function callee.
541   void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
542     if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
543       CI->setCallingConv(Fn->getCallingConv());
544   }
545 
546   // Helper function to determine if it's legal to create a call to the runtime
547   // functions.
548   bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
549     // We can always emit calls if we haven't yet linked in the runtime.
550     if (!OpenMPPostLink)
551       return true;
552 
553     // Once the runtime has been already been linked in we cannot emit calls to
554     // any undefined functions.
555     for (RuntimeFunction Fn : Fns) {
556       RuntimeFunctionInfo &RFI = RFIs[Fn];
557 
558       if (RFI.Declaration && RFI.Declaration->isDeclaration())
559         return false;
560     }
561     return true;
562   }
563 
564   /// Helper to initialize all runtime function information for those defined
565   /// in OpenMPKinds.def.
566   void initializeRuntimeFunctions(Module &M) {
567 
568     // Helper macros for handling __VA_ARGS__ in OMP_RTL
569 #define OMP_TYPE(VarName, ...)                                                 \
570   Type *VarName = OMPBuilder.VarName;                                          \
571   (void)VarName;
572 
573 #define OMP_ARRAY_TYPE(VarName, ...)                                           \
574   ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
575   (void)VarName##Ty;                                                           \
576   PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
577   (void)VarName##PtrTy;
578 
579 #define OMP_FUNCTION_TYPE(VarName, ...)                                        \
580   FunctionType *VarName = OMPBuilder.VarName;                                  \
581   (void)VarName;                                                               \
582   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
583   (void)VarName##Ptr;
584 
585 #define OMP_STRUCT_TYPE(VarName, ...)                                          \
586   StructType *VarName = OMPBuilder.VarName;                                    \
587   (void)VarName;                                                               \
588   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
589   (void)VarName##Ptr;
590 
591 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
592   {                                                                            \
593     SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
594     Function *F = M.getFunction(_Name);                                        \
595     RTLFunctions.insert(F);                                                    \
596     if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
597       RuntimeFunctionIDMap[F] = _Enum;                                         \
598       auto &RFI = RFIs[_Enum];                                                 \
599       RFI.Kind = _Enum;                                                        \
600       RFI.Name = _Name;                                                        \
601       RFI.IsVarArg = _IsVarArg;                                                \
602       RFI.ReturnType = OMPBuilder._ReturnType;                                 \
603       RFI.ArgumentTypes = std::move(ArgsTypes);                                \
604       RFI.Declaration = F;                                                     \
605       unsigned NumUses = collectUses(RFI);                                     \
606       (void)NumUses;                                                           \
607       LLVM_DEBUG({                                                             \
608         dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
609                << " found\n";                                                  \
610         if (RFI.Declaration)                                                   \
611           dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
612                  << RFI.getNumFunctionsWithUses()                              \
613                  << " different functions.\n";                                 \
614       });                                                                      \
615     }                                                                          \
616   }
617 #include "llvm/Frontend/OpenMP/OMPKinds.def"
618 
619     // Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_`
620     // functions, except if `optnone` is present.
621     if (isOpenMPDevice(M)) {
622       for (Function &F : M) {
623         for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
624           if (F.hasFnAttribute(Attribute::NoInline) &&
625               F.getName().starts_with(Prefix) &&
626               !F.hasFnAttribute(Attribute::OptimizeNone))
627             F.removeFnAttr(Attribute::NoInline);
628       }
629     }
630 
631     // TODO: We should attach the attributes defined in OMPKinds.def.
632   }
633 
634   /// Collection of known OpenMP runtime functions..
635   DenseSet<const Function *> RTLFunctions;
636 
637   /// Indicates if we have already linked in the OpenMP device library.
638   bool OpenMPPostLink = false;
639 };
640 
641 template <typename Ty, bool InsertInvalidates = true>
642 struct BooleanStateWithSetVector : public BooleanState {
643   bool contains(const Ty &Elem) const { return Set.contains(Elem); }
644   bool insert(const Ty &Elem) {
645     if (InsertInvalidates)
646       BooleanState::indicatePessimisticFixpoint();
647     return Set.insert(Elem);
648   }
649 
650   const Ty &operator[](int Idx) const { return Set[Idx]; }
651   bool operator==(const BooleanStateWithSetVector &RHS) const {
652     return BooleanState::operator==(RHS) && Set == RHS.Set;
653   }
654   bool operator!=(const BooleanStateWithSetVector &RHS) const {
655     return !(*this == RHS);
656   }
657 
658   bool empty() const { return Set.empty(); }
659   size_t size() const { return Set.size(); }
660 
661   /// "Clamp" this state with \p RHS.
662   BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
663     BooleanState::operator^=(RHS);
664     Set.insert(RHS.Set.begin(), RHS.Set.end());
665     return *this;
666   }
667 
668 private:
669   /// A set to keep track of elements.
670   SetVector<Ty> Set;
671 
672 public:
673   typename decltype(Set)::iterator begin() { return Set.begin(); }
674   typename decltype(Set)::iterator end() { return Set.end(); }
675   typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
676   typename decltype(Set)::const_iterator end() const { return Set.end(); }
677 };
678 
679 template <typename Ty, bool InsertInvalidates = true>
680 using BooleanStateWithPtrSetVector =
681     BooleanStateWithSetVector<Ty *, InsertInvalidates>;
682 
683 struct KernelInfoState : AbstractState {
684   /// Flag to track if we reached a fixpoint.
685   bool IsAtFixpoint = false;
686 
687   /// The parallel regions (identified by the outlined parallel functions) that
688   /// can be reached from the associated function.
689   BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
690       ReachedKnownParallelRegions;
691 
692   /// State to track what parallel region we might reach.
693   BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
694 
695   /// State to track if we are in SPMD-mode, assumed or know, and why we decided
696   /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
697   /// false.
698   BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
699 
700   /// The __kmpc_target_init call in this kernel, if any. If we find more than
701   /// one we abort as the kernel is malformed.
702   CallBase *KernelInitCB = nullptr;
703 
704   /// The constant kernel environement as taken from and passed to
705   /// __kmpc_target_init.
706   ConstantStruct *KernelEnvC = nullptr;
707 
708   /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
709   /// one we abort as the kernel is malformed.
710   CallBase *KernelDeinitCB = nullptr;
711 
712   /// Flag to indicate if the associated function is a kernel entry.
713   bool IsKernelEntry = false;
714 
715   /// State to track what kernel entries can reach the associated function.
716   BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
717 
718   /// State to indicate if we can track parallel level of the associated
719   /// function. We will give up tracking if we encounter unknown caller or the
720   /// caller is __kmpc_parallel_51.
721   BooleanStateWithSetVector<uint8_t> ParallelLevels;
722 
723   /// Flag that indicates if the kernel has nested Parallelism
724   bool NestedParallelism = false;
725 
726   /// Abstract State interface
727   ///{
728 
729   KernelInfoState() = default;
730   KernelInfoState(bool BestState) {
731     if (!BestState)
732       indicatePessimisticFixpoint();
733   }
734 
735   /// See AbstractState::isValidState(...)
736   bool isValidState() const override { return true; }
737 
738   /// See AbstractState::isAtFixpoint(...)
739   bool isAtFixpoint() const override { return IsAtFixpoint; }
740 
741   /// See AbstractState::indicatePessimisticFixpoint(...)
742   ChangeStatus indicatePessimisticFixpoint() override {
743     IsAtFixpoint = true;
744     ParallelLevels.indicatePessimisticFixpoint();
745     ReachingKernelEntries.indicatePessimisticFixpoint();
746     SPMDCompatibilityTracker.indicatePessimisticFixpoint();
747     ReachedKnownParallelRegions.indicatePessimisticFixpoint();
748     ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
749     NestedParallelism = true;
750     return ChangeStatus::CHANGED;
751   }
752 
753   /// See AbstractState::indicateOptimisticFixpoint(...)
754   ChangeStatus indicateOptimisticFixpoint() override {
755     IsAtFixpoint = true;
756     ParallelLevels.indicateOptimisticFixpoint();
757     ReachingKernelEntries.indicateOptimisticFixpoint();
758     SPMDCompatibilityTracker.indicateOptimisticFixpoint();
759     ReachedKnownParallelRegions.indicateOptimisticFixpoint();
760     ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
761     return ChangeStatus::UNCHANGED;
762   }
763 
764   /// Return the assumed state
765   KernelInfoState &getAssumed() { return *this; }
766   const KernelInfoState &getAssumed() const { return *this; }
767 
768   bool operator==(const KernelInfoState &RHS) const {
769     if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
770       return false;
771     if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
772       return false;
773     if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
774       return false;
775     if (ReachingKernelEntries != RHS.ReachingKernelEntries)
776       return false;
777     if (ParallelLevels != RHS.ParallelLevels)
778       return false;
779     if (NestedParallelism != RHS.NestedParallelism)
780       return false;
781     return true;
782   }
783 
784   /// Returns true if this kernel contains any OpenMP parallel regions.
785   bool mayContainParallelRegion() {
786     return !ReachedKnownParallelRegions.empty() ||
787            !ReachedUnknownParallelRegions.empty();
788   }
789 
790   /// Return empty set as the best state of potential values.
791   static KernelInfoState getBestState() { return KernelInfoState(true); }
792 
793   static KernelInfoState getBestState(KernelInfoState &KIS) {
794     return getBestState();
795   }
796 
797   /// Return full set as the worst state of potential values.
798   static KernelInfoState getWorstState() { return KernelInfoState(false); }
799 
800   /// "Clamp" this state with \p KIS.
801   KernelInfoState operator^=(const KernelInfoState &KIS) {
802     // Do not merge two different _init and _deinit call sites.
803     if (KIS.KernelInitCB) {
804       if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
805         llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
806                          "assumptions.");
807       KernelInitCB = KIS.KernelInitCB;
808     }
809     if (KIS.KernelDeinitCB) {
810       if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
811         llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
812                          "assumptions.");
813       KernelDeinitCB = KIS.KernelDeinitCB;
814     }
815     if (KIS.KernelEnvC) {
816       if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
817         llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
818                          "assumptions.");
819       KernelEnvC = KIS.KernelEnvC;
820     }
821     SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
822     ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
823     ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
824     NestedParallelism |= KIS.NestedParallelism;
825     return *this;
826   }
827 
828   KernelInfoState operator&=(const KernelInfoState &KIS) {
829     return (*this ^= KIS);
830   }
831 
832   ///}
833 };
834 
835 /// Used to map the values physically (in the IR) stored in an offload
836 /// array, to a vector in memory.
837 struct OffloadArray {
838   /// Physical array (in the IR).
839   AllocaInst *Array = nullptr;
840   /// Mapped values.
841   SmallVector<Value *, 8> StoredValues;
842   /// Last stores made in the offload array.
843   SmallVector<StoreInst *, 8> LastAccesses;
844 
845   OffloadArray() = default;
846 
847   /// Initializes the OffloadArray with the values stored in \p Array before
848   /// instruction \p Before is reached. Returns false if the initialization
849   /// fails.
850   /// This MUST be used immediately after the construction of the object.
851   bool initialize(AllocaInst &Array, Instruction &Before) {
852     if (!Array.getAllocatedType()->isArrayTy())
853       return false;
854 
855     if (!getValues(Array, Before))
856       return false;
857 
858     this->Array = &Array;
859     return true;
860   }
861 
862   static const unsigned DeviceIDArgNum = 1;
863   static const unsigned BasePtrsArgNum = 3;
864   static const unsigned PtrsArgNum = 4;
865   static const unsigned SizesArgNum = 5;
866 
867 private:
868   /// Traverses the BasicBlock where \p Array is, collecting the stores made to
869   /// \p Array, leaving StoredValues with the values stored before the
870   /// instruction \p Before is reached.
871   bool getValues(AllocaInst &Array, Instruction &Before) {
872     // Initialize container.
873     const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
874     StoredValues.assign(NumValues, nullptr);
875     LastAccesses.assign(NumValues, nullptr);
876 
877     // TODO: This assumes the instruction \p Before is in the same
878     //  BasicBlock as Array. Make it general, for any control flow graph.
879     BasicBlock *BB = Array.getParent();
880     if (BB != Before.getParent())
881       return false;
882 
883     const DataLayout &DL = Array.getDataLayout();
884     const unsigned int PointerSize = DL.getPointerSize();
885 
886     for (Instruction &I : *BB) {
887       if (&I == &Before)
888         break;
889 
890       if (!isa<StoreInst>(&I))
891         continue;
892 
893       auto *S = cast<StoreInst>(&I);
894       int64_t Offset = -1;
895       auto *Dst =
896           GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
897       if (Dst == &Array) {
898         int64_t Idx = Offset / PointerSize;
899         StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
900         LastAccesses[Idx] = S;
901       }
902     }
903 
904     return isFilled();
905   }
906 
907   /// Returns true if all values in StoredValues and
908   /// LastAccesses are not nullptrs.
909   bool isFilled() {
910     const unsigned NumValues = StoredValues.size();
911     for (unsigned I = 0; I < NumValues; ++I) {
912       if (!StoredValues[I] || !LastAccesses[I])
913         return false;
914     }
915 
916     return true;
917   }
918 };
919 
920 struct OpenMPOpt {
921 
922   using OptimizationRemarkGetter =
923       function_ref<OptimizationRemarkEmitter &(Function *)>;
924 
925   OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
926             OptimizationRemarkGetter OREGetter,
927             OMPInformationCache &OMPInfoCache, Attributor &A)
928       : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
929         OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
930 
931   /// Check if any remarks are enabled for openmp-opt
932   bool remarksEnabled() {
933     auto &Ctx = M.getContext();
934     return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
935   }
936 
937   /// Run all OpenMP optimizations on the underlying SCC.
938   bool run(bool IsModulePass) {
939     if (SCC.empty())
940       return false;
941 
942     bool Changed = false;
943 
944     LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
945                       << " functions\n");
946 
947     if (IsModulePass) {
948       Changed |= runAttributor(IsModulePass);
949 
950       // Recollect uses, in case Attributor deleted any.
951       OMPInfoCache.recollectUses();
952 
953       // TODO: This should be folded into buildCustomStateMachine.
954       Changed |= rewriteDeviceCodeStateMachine();
955 
956       if (remarksEnabled())
957         analysisGlobalization();
958     } else {
959       if (PrintICVValues)
960         printICVs();
961       if (PrintOpenMPKernels)
962         printKernels();
963 
964       Changed |= runAttributor(IsModulePass);
965 
966       // Recollect uses, in case Attributor deleted any.
967       OMPInfoCache.recollectUses();
968 
969       Changed |= deleteParallelRegions();
970 
971       if (HideMemoryTransferLatency)
972         Changed |= hideMemTransfersLatency();
973       Changed |= deduplicateRuntimeCalls();
974       if (EnableParallelRegionMerging) {
975         if (mergeParallelRegions()) {
976           deduplicateRuntimeCalls();
977           Changed = true;
978         }
979       }
980     }
981 
982     if (OMPInfoCache.OpenMPPostLink)
983       Changed |= removeRuntimeSymbols();
984 
985     return Changed;
986   }
987 
988   /// Print initial ICV values for testing.
989   /// FIXME: This should be done from the Attributor once it is added.
990   void printICVs() const {
991     InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
992                                  ICV_proc_bind};
993 
994     for (Function *F : SCC) {
995       for (auto ICV : ICVs) {
996         auto ICVInfo = OMPInfoCache.ICVs[ICV];
997         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
998           return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
999                      << " Value: "
1000                      << (ICVInfo.InitValue
1001                              ? toString(ICVInfo.InitValue->getValue(), 10, true)
1002                              : "IMPLEMENTATION_DEFINED");
1003         };
1004 
1005         emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
1006       }
1007     }
1008   }
1009 
1010   /// Print OpenMP GPU kernels for testing.
1011   void printKernels() const {
1012     for (Function *F : SCC) {
1013       if (!omp::isOpenMPKernel(*F))
1014         continue;
1015 
1016       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1017         return ORA << "OpenMP GPU kernel "
1018                    << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
1019       };
1020 
1021       emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
1022     }
1023   }
1024 
1025   /// Return the call if \p U is a callee use in a regular call. If \p RFI is
1026   /// given it has to be the callee or a nullptr is returned.
1027   static CallInst *getCallIfRegularCall(
1028       Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1029     CallInst *CI = dyn_cast<CallInst>(U.getUser());
1030     if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
1031         (!RFI ||
1032          (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1033       return CI;
1034     return nullptr;
1035   }
1036 
1037   /// Return the call if \p V is a regular call. If \p RFI is given it has to be
1038   /// the callee or a nullptr is returned.
1039   static CallInst *getCallIfRegularCall(
1040       Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1041     CallInst *CI = dyn_cast<CallInst>(&V);
1042     if (CI && !CI->hasOperandBundles() &&
1043         (!RFI ||
1044          (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1045       return CI;
1046     return nullptr;
1047   }
1048 
1049 private:
1050   /// Merge parallel regions when it is safe.
1051   bool mergeParallelRegions() {
1052     const unsigned CallbackCalleeOperand = 2;
1053     const unsigned CallbackFirstArgOperand = 3;
1054     using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1055 
1056     // Check if there are any __kmpc_fork_call calls to merge.
1057     OMPInformationCache::RuntimeFunctionInfo &RFI =
1058         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1059 
1060     if (!RFI.Declaration)
1061       return false;
1062 
1063     // Unmergable calls that prevent merging a parallel region.
1064     OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
1065         OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
1066         OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
1067     };
1068 
1069     bool Changed = false;
1070     LoopInfo *LI = nullptr;
1071     DominatorTree *DT = nullptr;
1072 
1073     SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
1074 
1075     BasicBlock *StartBB = nullptr, *EndBB = nullptr;
1076     auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1077       BasicBlock *CGStartBB = CodeGenIP.getBlock();
1078       BasicBlock *CGEndBB =
1079           SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1080       assert(StartBB != nullptr && "StartBB should not be null");
1081       CGStartBB->getTerminator()->setSuccessor(0, StartBB);
1082       assert(EndBB != nullptr && "EndBB should not be null");
1083       EndBB->getTerminator()->setSuccessor(0, CGEndBB);
1084     };
1085 
1086     auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
1087                       Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
1088       ReplacementValue = &Inner;
1089       return CodeGenIP;
1090     };
1091 
1092     auto FiniCB = [&](InsertPointTy CodeGenIP) {};
1093 
1094     /// Create a sequential execution region within a merged parallel region,
1095     /// encapsulated in a master construct with a barrier for synchronization.
1096     auto CreateSequentialRegion = [&](Function *OuterFn,
1097                                       BasicBlock *OuterPredBB,
1098                                       Instruction *SeqStartI,
1099                                       Instruction *SeqEndI) {
1100       // Isolate the instructions of the sequential region to a separate
1101       // block.
1102       BasicBlock *ParentBB = SeqStartI->getParent();
1103       BasicBlock *SeqEndBB =
1104           SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
1105       BasicBlock *SeqAfterBB =
1106           SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
1107       BasicBlock *SeqStartBB =
1108           SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
1109 
1110       assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
1111              "Expected a different CFG");
1112       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
1113       ParentBB->getTerminator()->eraseFromParent();
1114 
1115       auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1116         BasicBlock *CGStartBB = CodeGenIP.getBlock();
1117         BasicBlock *CGEndBB =
1118             SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1119         assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
1120         CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
1121         assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
1122         SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
1123       };
1124       auto FiniCB = [&](InsertPointTy CodeGenIP) {};
1125 
1126       // Find outputs from the sequential region to outside users and
1127       // broadcast their values to them.
1128       for (Instruction &I : *SeqStartBB) {
1129         SmallPtrSet<Instruction *, 4> OutsideUsers;
1130         for (User *Usr : I.users()) {
1131           Instruction &UsrI = *cast<Instruction>(Usr);
1132           // Ignore outputs to LT intrinsics, code extraction for the merged
1133           // parallel region will fix them.
1134           if (UsrI.isLifetimeStartOrEnd())
1135             continue;
1136 
1137           if (UsrI.getParent() != SeqStartBB)
1138             OutsideUsers.insert(&UsrI);
1139         }
1140 
1141         if (OutsideUsers.empty())
1142           continue;
1143 
1144         // Emit an alloca in the outer region to store the broadcasted
1145         // value.
1146         const DataLayout &DL = M.getDataLayout();
1147         AllocaInst *AllocaI = new AllocaInst(
1148             I.getType(), DL.getAllocaAddrSpace(), nullptr,
1149             I.getName() + ".seq.output.alloc", OuterFn->front().begin());
1150 
1151         // Emit a store instruction in the sequential BB to update the
1152         // value.
1153         new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()->getIterator());
1154 
1155         // Emit a load instruction and replace the use of the output value
1156         // with it.
1157         for (Instruction *UsrI : OutsideUsers) {
1158           LoadInst *LoadI = new LoadInst(I.getType(), AllocaI,
1159                                          I.getName() + ".seq.output.load",
1160                                          UsrI->getIterator());
1161           UsrI->replaceUsesOfWith(&I, LoadI);
1162         }
1163       }
1164 
1165       OpenMPIRBuilder::LocationDescription Loc(
1166           InsertPointTy(ParentBB, ParentBB->end()), DL);
1167       InsertPointTy SeqAfterIP =
1168           OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
1169 
1170       OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
1171 
1172       BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
1173 
1174       LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
1175                         << "\n");
1176     };
1177 
1178     // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
1179     // contained in BB and only separated by instructions that can be
1180     // redundantly executed in parallel. The block BB is split before the first
1181     // call (in MergableCIs) and after the last so the entire region we merge
1182     // into a single parallel region is contained in a single basic block
1183     // without any other instructions. We use the OpenMPIRBuilder to outline
1184     // that block and call the resulting function via __kmpc_fork_call.
1185     auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
1186                      BasicBlock *BB) {
1187       // TODO: Change the interface to allow single CIs expanded, e.g, to
1188       // include an outer loop.
1189       assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1190 
1191       auto Remark = [&](OptimizationRemark OR) {
1192         OR << "Parallel region merged with parallel region"
1193            << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1194         for (auto *CI : llvm::drop_begin(MergableCIs)) {
1195           OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1196           if (CI != MergableCIs.back())
1197             OR << ", ";
1198         }
1199         return OR << ".";
1200       };
1201 
1202       emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1203 
1204       Function *OriginalFn = BB->getParent();
1205       LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1206                         << " parallel regions in " << OriginalFn->getName()
1207                         << "\n");
1208 
1209       // Isolate the calls to merge in a separate block.
1210       EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1211       BasicBlock *AfterBB =
1212           SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1213       StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1214                            "omp.par.merged");
1215 
1216       assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1217       const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1218       BB->getTerminator()->eraseFromParent();
1219 
1220       // Create sequential regions for sequential instructions that are
1221       // in-between mergable parallel regions.
1222       for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1223            It != End; ++It) {
1224         Instruction *ForkCI = *It;
1225         Instruction *NextForkCI = *(It + 1);
1226 
1227         // Continue if there are not in-between instructions.
1228         if (ForkCI->getNextNode() == NextForkCI)
1229           continue;
1230 
1231         CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1232                                NextForkCI->getPrevNode());
1233       }
1234 
1235       OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1236                                                DL);
1237       IRBuilder<>::InsertPoint AllocaIP(
1238           &OriginalFn->getEntryBlock(),
1239           OriginalFn->getEntryBlock().getFirstInsertionPt());
1240       // Create the merged parallel region with default proc binding, to
1241       // avoid overriding binding settings, and without explicit cancellation.
1242       InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1243           Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1244           OMP_PROC_BIND_default, /* IsCancellable */ false);
1245       BranchInst::Create(AfterBB, AfterIP.getBlock());
1246 
1247       // Perform the actual outlining.
1248       OMPInfoCache.OMPBuilder.finalize(OriginalFn);
1249 
1250       Function *OutlinedFn = MergableCIs.front()->getCaller();
1251 
1252       // Replace the __kmpc_fork_call calls with direct calls to the outlined
1253       // callbacks.
1254       SmallVector<Value *, 8> Args;
1255       for (auto *CI : MergableCIs) {
1256         Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
1257         FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
1258         Args.clear();
1259         Args.push_back(OutlinedFn->getArg(0));
1260         Args.push_back(OutlinedFn->getArg(1));
1261         for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1262              ++U)
1263           Args.push_back(CI->getArgOperand(U));
1264 
1265         CallInst *NewCI =
1266             CallInst::Create(FT, Callee, Args, "", CI->getIterator());
1267         if (CI->getDebugLoc())
1268           NewCI->setDebugLoc(CI->getDebugLoc());
1269 
1270         // Forward parameter attributes from the callback to the callee.
1271         for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1272              ++U)
1273           for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1274             NewCI->addParamAttr(
1275                 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1276 
1277         // Emit an explicit barrier to replace the implicit fork-join barrier.
1278         if (CI != MergableCIs.back()) {
1279           // TODO: Remove barrier if the merged parallel region includes the
1280           // 'nowait' clause.
1281           OMPInfoCache.OMPBuilder.createBarrier(
1282               InsertPointTy(NewCI->getParent(),
1283                             NewCI->getNextNode()->getIterator()),
1284               OMPD_parallel);
1285         }
1286 
1287         CI->eraseFromParent();
1288       }
1289 
1290       assert(OutlinedFn != OriginalFn && "Outlining failed");
1291       CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1292       CGUpdater.reanalyzeFunction(*OriginalFn);
1293 
1294       NumOpenMPParallelRegionsMerged += MergableCIs.size();
1295 
1296       return true;
1297     };
1298 
1299     // Helper function that identifes sequences of
1300     // __kmpc_fork_call uses in a basic block.
1301     auto DetectPRsCB = [&](Use &U, Function &F) {
1302       CallInst *CI = getCallIfRegularCall(U, &RFI);
1303       BB2PRMap[CI->getParent()].insert(CI);
1304 
1305       return false;
1306     };
1307 
1308     BB2PRMap.clear();
1309     RFI.foreachUse(SCC, DetectPRsCB);
1310     SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1311     // Find mergable parallel regions within a basic block that are
1312     // safe to merge, that is any in-between instructions can safely
1313     // execute in parallel after merging.
1314     // TODO: support merging across basic-blocks.
1315     for (auto &It : BB2PRMap) {
1316       auto &CIs = It.getSecond();
1317       if (CIs.size() < 2)
1318         continue;
1319 
1320       BasicBlock *BB = It.getFirst();
1321       SmallVector<CallInst *, 4> MergableCIs;
1322 
1323       /// Returns true if the instruction is mergable, false otherwise.
1324       /// A terminator instruction is unmergable by definition since merging
1325       /// works within a BB. Instructions before the mergable region are
1326       /// mergable if they are not calls to OpenMP runtime functions that may
1327       /// set different execution parameters for subsequent parallel regions.
1328       /// Instructions in-between parallel regions are mergable if they are not
1329       /// calls to any non-intrinsic function since that may call a non-mergable
1330       /// OpenMP runtime function.
1331       auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1332         // We do not merge across BBs, hence return false (unmergable) if the
1333         // instruction is a terminator.
1334         if (I.isTerminator())
1335           return false;
1336 
1337         if (!isa<CallInst>(&I))
1338           return true;
1339 
1340         CallInst *CI = cast<CallInst>(&I);
1341         if (IsBeforeMergableRegion) {
1342           Function *CalledFunction = CI->getCalledFunction();
1343           if (!CalledFunction)
1344             return false;
1345           // Return false (unmergable) if the call before the parallel
1346           // region calls an explicit affinity (proc_bind) or number of
1347           // threads (num_threads) compiler-generated function. Those settings
1348           // may be incompatible with following parallel regions.
1349           // TODO: ICV tracking to detect compatibility.
1350           for (const auto &RFI : UnmergableCallsInfo) {
1351             if (CalledFunction == RFI.Declaration)
1352               return false;
1353           }
1354         } else {
1355           // Return false (unmergable) if there is a call instruction
1356           // in-between parallel regions when it is not an intrinsic. It
1357           // may call an unmergable OpenMP runtime function in its callpath.
1358           // TODO: Keep track of possible OpenMP calls in the callpath.
1359           if (!isa<IntrinsicInst>(CI))
1360             return false;
1361         }
1362 
1363         return true;
1364       };
1365       // Find maximal number of parallel region CIs that are safe to merge.
1366       for (auto It = BB->begin(), End = BB->end(); It != End;) {
1367         Instruction &I = *It;
1368         ++It;
1369 
1370         if (CIs.count(&I)) {
1371           MergableCIs.push_back(cast<CallInst>(&I));
1372           continue;
1373         }
1374 
1375         // Continue expanding if the instruction is mergable.
1376         if (IsMergable(I, MergableCIs.empty()))
1377           continue;
1378 
1379         // Forward the instruction iterator to skip the next parallel region
1380         // since there is an unmergable instruction which can affect it.
1381         for (; It != End; ++It) {
1382           Instruction &SkipI = *It;
1383           if (CIs.count(&SkipI)) {
1384             LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1385                               << " due to " << I << "\n");
1386             ++It;
1387             break;
1388           }
1389         }
1390 
1391         // Store mergable regions found.
1392         if (MergableCIs.size() > 1) {
1393           MergableCIsVector.push_back(MergableCIs);
1394           LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1395                             << " parallel regions in block " << BB->getName()
1396                             << " of function " << BB->getParent()->getName()
1397                             << "\n";);
1398         }
1399 
1400         MergableCIs.clear();
1401       }
1402 
1403       if (!MergableCIsVector.empty()) {
1404         Changed = true;
1405 
1406         for (auto &MergableCIs : MergableCIsVector)
1407           Merge(MergableCIs, BB);
1408         MergableCIsVector.clear();
1409       }
1410     }
1411 
1412     if (Changed) {
1413       /// Re-collect use for fork calls, emitted barrier calls, and
1414       /// any emitted master/end_master calls.
1415       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1416       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1417       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1418       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1419     }
1420 
1421     return Changed;
1422   }
1423 
1424   /// Try to delete parallel regions if possible.
1425   bool deleteParallelRegions() {
1426     const unsigned CallbackCalleeOperand = 2;
1427 
1428     OMPInformationCache::RuntimeFunctionInfo &RFI =
1429         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1430 
1431     if (!RFI.Declaration)
1432       return false;
1433 
1434     bool Changed = false;
1435     auto DeleteCallCB = [&](Use &U, Function &) {
1436       CallInst *CI = getCallIfRegularCall(U);
1437       if (!CI)
1438         return false;
1439       auto *Fn = dyn_cast<Function>(
1440           CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1441       if (!Fn)
1442         return false;
1443       if (!Fn->onlyReadsMemory())
1444         return false;
1445       if (!Fn->hasFnAttribute(Attribute::WillReturn))
1446         return false;
1447 
1448       LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1449                         << CI->getCaller()->getName() << "\n");
1450 
1451       auto Remark = [&](OptimizationRemark OR) {
1452         return OR << "Removing parallel region with no side-effects.";
1453       };
1454       emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1455 
1456       CI->eraseFromParent();
1457       Changed = true;
1458       ++NumOpenMPParallelRegionsDeleted;
1459       return true;
1460     };
1461 
1462     RFI.foreachUse(SCC, DeleteCallCB);
1463 
1464     return Changed;
1465   }
1466 
1467   /// Try to eliminate runtime calls by reusing existing ones.
1468   bool deduplicateRuntimeCalls() {
1469     bool Changed = false;
1470 
1471     RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1472         OMPRTL_omp_get_num_threads,
1473         OMPRTL_omp_in_parallel,
1474         OMPRTL_omp_get_cancellation,
1475         OMPRTL_omp_get_supported_active_levels,
1476         OMPRTL_omp_get_level,
1477         OMPRTL_omp_get_ancestor_thread_num,
1478         OMPRTL_omp_get_team_size,
1479         OMPRTL_omp_get_active_level,
1480         OMPRTL_omp_in_final,
1481         OMPRTL_omp_get_proc_bind,
1482         OMPRTL_omp_get_num_places,
1483         OMPRTL_omp_get_num_procs,
1484         OMPRTL_omp_get_place_num,
1485         OMPRTL_omp_get_partition_num_places,
1486         OMPRTL_omp_get_partition_place_nums};
1487 
1488     // Global-tid is handled separately.
1489     SmallSetVector<Value *, 16> GTIdArgs;
1490     collectGlobalThreadIdArguments(GTIdArgs);
1491     LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1492                       << " global thread ID arguments\n");
1493 
1494     for (Function *F : SCC) {
1495       for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1496         Changed |= deduplicateRuntimeCalls(
1497             *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1498 
1499       // __kmpc_global_thread_num is special as we can replace it with an
1500       // argument in enough cases to make it worth trying.
1501       Value *GTIdArg = nullptr;
1502       for (Argument &Arg : F->args())
1503         if (GTIdArgs.count(&Arg)) {
1504           GTIdArg = &Arg;
1505           break;
1506         }
1507       Changed |= deduplicateRuntimeCalls(
1508           *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1509     }
1510 
1511     return Changed;
1512   }
1513 
1514   /// Tries to remove known runtime symbols that are optional from the module.
1515   bool removeRuntimeSymbols() {
1516     // The RPC client symbol is defined in `libc` and indicates that something
1517     // required an RPC server. If its users were all optimized out then we can
1518     // safely remove it.
1519     // TODO: This should be somewhere more common in the future.
1520     if (GlobalVariable *GV = M.getNamedGlobal("__llvm_libc_rpc_client")) {
1521       if (!GV->getType()->isPointerTy())
1522         return false;
1523 
1524       Constant *C = GV->getInitializer();
1525       if (!C)
1526         return false;
1527 
1528       // Check to see if the only user of the RPC client is the external handle.
1529       GlobalVariable *Client = dyn_cast<GlobalVariable>(C->stripPointerCasts());
1530       if (!Client || Client->getNumUses() > 1 ||
1531           Client->user_back() != GV->getInitializer())
1532         return false;
1533 
1534       Client->replaceAllUsesWith(PoisonValue::get(Client->getType()));
1535       Client->eraseFromParent();
1536 
1537       GV->replaceAllUsesWith(PoisonValue::get(GV->getType()));
1538       GV->eraseFromParent();
1539 
1540       return true;
1541     }
1542     return false;
1543   }
1544 
1545   /// Tries to hide the latency of runtime calls that involve host to
1546   /// device memory transfers by splitting them into their "issue" and "wait"
1547   /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1548   /// moved downards as much as possible. The "issue" issues the memory transfer
1549   /// asynchronously, returning a handle. The "wait" waits in the returned
1550   /// handle for the memory transfer to finish.
1551   bool hideMemTransfersLatency() {
1552     auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1553     bool Changed = false;
1554     auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1555       auto *RTCall = getCallIfRegularCall(U, &RFI);
1556       if (!RTCall)
1557         return false;
1558 
1559       OffloadArray OffloadArrays[3];
1560       if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1561         return false;
1562 
1563       LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1564 
1565       // TODO: Check if can be moved upwards.
1566       bool WasSplit = false;
1567       Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1568       if (WaitMovementPoint)
1569         WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1570 
1571       Changed |= WasSplit;
1572       return WasSplit;
1573     };
1574     if (OMPInfoCache.runtimeFnsAvailable(
1575             {OMPRTL___tgt_target_data_begin_mapper_issue,
1576              OMPRTL___tgt_target_data_begin_mapper_wait}))
1577       RFI.foreachUse(SCC, SplitMemTransfers);
1578 
1579     return Changed;
1580   }
1581 
1582   void analysisGlobalization() {
1583     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1584 
1585     auto CheckGlobalization = [&](Use &U, Function &Decl) {
1586       if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1587         auto Remark = [&](OptimizationRemarkMissed ORM) {
1588           return ORM
1589                  << "Found thread data sharing on the GPU. "
1590                  << "Expect degraded performance due to data globalization.";
1591         };
1592         emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1593       }
1594 
1595       return false;
1596     };
1597 
1598     RFI.foreachUse(SCC, CheckGlobalization);
1599   }
1600 
1601   /// Maps the values stored in the offload arrays passed as arguments to
1602   /// \p RuntimeCall into the offload arrays in \p OAs.
1603   bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1604                                 MutableArrayRef<OffloadArray> OAs) {
1605     assert(OAs.size() == 3 && "Need space for three offload arrays!");
1606 
1607     // A runtime call that involves memory offloading looks something like:
1608     // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1609     //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1610     // ...)
1611     // So, the idea is to access the allocas that allocate space for these
1612     // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1613     // Therefore:
1614     // i8** %offload_baseptrs.
1615     Value *BasePtrsArg =
1616         RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1617     // i8** %offload_ptrs.
1618     Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1619     // i8** %offload_sizes.
1620     Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1621 
1622     // Get values stored in **offload_baseptrs.
1623     auto *V = getUnderlyingObject(BasePtrsArg);
1624     if (!isa<AllocaInst>(V))
1625       return false;
1626     auto *BasePtrsArray = cast<AllocaInst>(V);
1627     if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1628       return false;
1629 
1630     // Get values stored in **offload_baseptrs.
1631     V = getUnderlyingObject(PtrsArg);
1632     if (!isa<AllocaInst>(V))
1633       return false;
1634     auto *PtrsArray = cast<AllocaInst>(V);
1635     if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1636       return false;
1637 
1638     // Get values stored in **offload_sizes.
1639     V = getUnderlyingObject(SizesArg);
1640     // If it's a [constant] global array don't analyze it.
1641     if (isa<GlobalValue>(V))
1642       return isa<Constant>(V);
1643     if (!isa<AllocaInst>(V))
1644       return false;
1645 
1646     auto *SizesArray = cast<AllocaInst>(V);
1647     if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1648       return false;
1649 
1650     return true;
1651   }
1652 
1653   /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1654   /// For now this is a way to test that the function getValuesInOffloadArrays
1655   /// is working properly.
1656   /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1657   void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1658     assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1659 
1660     LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1661     std::string ValuesStr;
1662     raw_string_ostream Printer(ValuesStr);
1663     std::string Separator = " --- ";
1664 
1665     for (auto *BP : OAs[0].StoredValues) {
1666       BP->print(Printer);
1667       Printer << Separator;
1668     }
1669     LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << ValuesStr << "\n");
1670     ValuesStr.clear();
1671 
1672     for (auto *P : OAs[1].StoredValues) {
1673       P->print(Printer);
1674       Printer << Separator;
1675     }
1676     LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << ValuesStr << "\n");
1677     ValuesStr.clear();
1678 
1679     for (auto *S : OAs[2].StoredValues) {
1680       S->print(Printer);
1681       Printer << Separator;
1682     }
1683     LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << ValuesStr << "\n");
1684   }
1685 
1686   /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be
1687   /// moved. Returns nullptr if the movement is not possible, or not worth it.
1688   Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1689     // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1690     //  Make it traverse the CFG.
1691 
1692     Instruction *CurrentI = &RuntimeCall;
1693     bool IsWorthIt = false;
1694     while ((CurrentI = CurrentI->getNextNode())) {
1695 
1696       // TODO: Once we detect the regions to be offloaded we should use the
1697       //  alias analysis manager to check if CurrentI may modify one of
1698       //  the offloaded regions.
1699       if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1700         if (IsWorthIt)
1701           return CurrentI;
1702 
1703         return nullptr;
1704       }
1705 
1706       // FIXME: For now if we move it over anything without side effect
1707       //  is worth it.
1708       IsWorthIt = true;
1709     }
1710 
1711     // Return end of BasicBlock.
1712     return RuntimeCall.getParent()->getTerminator();
1713   }
1714 
1715   /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1716   bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1717                                Instruction &WaitMovementPoint) {
1718     // Create stack allocated handle (__tgt_async_info) at the beginning of the
1719     // function. Used for storing information of the async transfer, allowing to
1720     // wait on it later.
1721     auto &IRBuilder = OMPInfoCache.OMPBuilder;
1722     Function *F = RuntimeCall.getCaller();
1723     BasicBlock &Entry = F->getEntryBlock();
1724     IRBuilder.Builder.SetInsertPoint(&Entry,
1725                                      Entry.getFirstNonPHIOrDbgOrAlloca());
1726     Value *Handle = IRBuilder.Builder.CreateAlloca(
1727         IRBuilder.AsyncInfo, /*ArraySize=*/nullptr, "handle");
1728     Handle =
1729         IRBuilder.Builder.CreateAddrSpaceCast(Handle, IRBuilder.AsyncInfoPtr);
1730 
1731     // Add "issue" runtime call declaration:
1732     // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1733     //   i8**, i8**, i64*, i64*)
1734     FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1735         M, OMPRTL___tgt_target_data_begin_mapper_issue);
1736 
1737     // Change RuntimeCall call site for its asynchronous version.
1738     SmallVector<Value *, 16> Args;
1739     for (auto &Arg : RuntimeCall.args())
1740       Args.push_back(Arg.get());
1741     Args.push_back(Handle);
1742 
1743     CallInst *IssueCallsite = CallInst::Create(IssueDecl, Args, /*NameStr=*/"",
1744                                                RuntimeCall.getIterator());
1745     OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1746     RuntimeCall.eraseFromParent();
1747 
1748     // Add "wait" runtime call declaration:
1749     // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1750     FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1751         M, OMPRTL___tgt_target_data_begin_mapper_wait);
1752 
1753     Value *WaitParams[2] = {
1754         IssueCallsite->getArgOperand(
1755             OffloadArray::DeviceIDArgNum), // device_id.
1756         Handle                             // handle to wait on.
1757     };
1758     CallInst *WaitCallsite = CallInst::Create(
1759         WaitDecl, WaitParams, /*NameStr=*/"", WaitMovementPoint.getIterator());
1760     OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1761 
1762     return true;
1763   }
1764 
1765   static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1766                                     bool GlobalOnly, bool &SingleChoice) {
1767     if (CurrentIdent == NextIdent)
1768       return CurrentIdent;
1769 
1770     // TODO: Figure out how to actually combine multiple debug locations. For
1771     //       now we just keep an existing one if there is a single choice.
1772     if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1773       SingleChoice = !CurrentIdent;
1774       return NextIdent;
1775     }
1776     return nullptr;
1777   }
1778 
1779   /// Return an `struct ident_t*` value that represents the ones used in the
1780   /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1781   /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1782   /// return value we create one from scratch. We also do not yet combine
1783   /// information, e.g., the source locations, see combinedIdentStruct.
1784   Value *
1785   getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1786                                  Function &F, bool GlobalOnly) {
1787     bool SingleChoice = true;
1788     Value *Ident = nullptr;
1789     auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1790       CallInst *CI = getCallIfRegularCall(U, &RFI);
1791       if (!CI || &F != &Caller)
1792         return false;
1793       Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1794                                   /* GlobalOnly */ true, SingleChoice);
1795       return false;
1796     };
1797     RFI.foreachUse(SCC, CombineIdentStruct);
1798 
1799     if (!Ident || !SingleChoice) {
1800       // The IRBuilder uses the insertion block to get to the module, this is
1801       // unfortunate but we work around it for now.
1802       if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1803         OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1804             &F.getEntryBlock(), F.getEntryBlock().begin()));
1805       // Create a fallback location if non was found.
1806       // TODO: Use the debug locations of the calls instead.
1807       uint32_t SrcLocStrSize;
1808       Constant *Loc =
1809           OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1810       Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1811     }
1812     return Ident;
1813   }
1814 
1815   /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1816   /// \p ReplVal if given.
1817   bool deduplicateRuntimeCalls(Function &F,
1818                                OMPInformationCache::RuntimeFunctionInfo &RFI,
1819                                Value *ReplVal = nullptr) {
1820     auto *UV = RFI.getUseVector(F);
1821     if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1822       return false;
1823 
1824     LLVM_DEBUG(
1825         dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1826                << (ReplVal ? " with an existing value\n" : "\n") << "\n");
1827 
1828     assert((!ReplVal || (isa<Argument>(ReplVal) &&
1829                          cast<Argument>(ReplVal)->getParent() == &F)) &&
1830            "Unexpected replacement value!");
1831 
1832     // TODO: Use dominance to find a good position instead.
1833     auto CanBeMoved = [this](CallBase &CB) {
1834       unsigned NumArgs = CB.arg_size();
1835       if (NumArgs == 0)
1836         return true;
1837       if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1838         return false;
1839       for (unsigned U = 1; U < NumArgs; ++U)
1840         if (isa<Instruction>(CB.getArgOperand(U)))
1841           return false;
1842       return true;
1843     };
1844 
1845     if (!ReplVal) {
1846       auto *DT =
1847           OMPInfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(F);
1848       if (!DT)
1849         return false;
1850       Instruction *IP = nullptr;
1851       for (Use *U : *UV) {
1852         if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1853           if (IP)
1854             IP = DT->findNearestCommonDominator(IP, CI);
1855           else
1856             IP = CI;
1857           if (!CanBeMoved(*CI))
1858             continue;
1859           if (!ReplVal)
1860             ReplVal = CI;
1861         }
1862       }
1863       if (!ReplVal)
1864         return false;
1865       assert(IP && "Expected insertion point!");
1866       cast<Instruction>(ReplVal)->moveBefore(IP);
1867     }
1868 
1869     // If we use a call as a replacement value we need to make sure the ident is
1870     // valid at the new location. For now we just pick a global one, either
1871     // existing and used by one of the calls, or created from scratch.
1872     if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1873       if (!CI->arg_empty() &&
1874           CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1875         Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1876                                                       /* GlobalOnly */ true);
1877         CI->setArgOperand(0, Ident);
1878       }
1879     }
1880 
1881     bool Changed = false;
1882     auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1883       CallInst *CI = getCallIfRegularCall(U, &RFI);
1884       if (!CI || CI == ReplVal || &F != &Caller)
1885         return false;
1886       assert(CI->getCaller() == &F && "Unexpected call!");
1887 
1888       auto Remark = [&](OptimizationRemark OR) {
1889         return OR << "OpenMP runtime call "
1890                   << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1891       };
1892       if (CI->getDebugLoc())
1893         emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1894       else
1895         emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1896 
1897       CI->replaceAllUsesWith(ReplVal);
1898       CI->eraseFromParent();
1899       ++NumOpenMPRuntimeCallsDeduplicated;
1900       Changed = true;
1901       return true;
1902     };
1903     RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1904 
1905     return Changed;
1906   }
1907 
1908   /// Collect arguments that represent the global thread id in \p GTIdArgs.
1909   void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1910     // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1911     //       initialization. We could define an AbstractAttribute instead and
1912     //       run the Attributor here once it can be run as an SCC pass.
1913 
1914     // Helper to check the argument \p ArgNo at all call sites of \p F for
1915     // a GTId.
1916     auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1917       if (!F.hasLocalLinkage())
1918         return false;
1919       for (Use &U : F.uses()) {
1920         if (CallInst *CI = getCallIfRegularCall(U)) {
1921           Value *ArgOp = CI->getArgOperand(ArgNo);
1922           if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1923               getCallIfRegularCall(
1924                   *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1925             continue;
1926         }
1927         return false;
1928       }
1929       return true;
1930     };
1931 
1932     // Helper to identify uses of a GTId as GTId arguments.
1933     auto AddUserArgs = [&](Value &GTId) {
1934       for (Use &U : GTId.uses())
1935         if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1936           if (CI->isArgOperand(&U))
1937             if (Function *Callee = CI->getCalledFunction())
1938               if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1939                 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1940     };
1941 
1942     // The argument users of __kmpc_global_thread_num calls are GTIds.
1943     OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1944         OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1945 
1946     GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1947       if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1948         AddUserArgs(*CI);
1949       return false;
1950     });
1951 
1952     // Transitively search for more arguments by looking at the users of the
1953     // ones we know already. During the search the GTIdArgs vector is extended
1954     // so we cannot cache the size nor can we use a range based for.
1955     for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1956       AddUserArgs(*GTIdArgs[U]);
1957   }
1958 
/// Kernel (=GPU) optimizations and utility functions
///
///{{

/// Cache to remember the unique kernel for a function. An empty optional
/// means the function has not been looked up yet; a cached nullptr means no
/// unique kernel was found.
DenseMap<Function *, std::optional<Kernel>> UniqueKernelMap;

/// Find the unique kernel that will execute \p F, if any.
Kernel getUniqueKernelFor(Function &F);

/// Find the unique kernel that will execute \p I, if any. This is the unique
/// kernel of the function containing \p I.
Kernel getUniqueKernelFor(Instruction &I) {
  return getUniqueKernelFor(*I.getFunction());
}

/// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
/// the cases where we can avoid taking the address of a function.
bool rewriteDeviceCodeStateMachine();

///
///}}
1980 
1981   /// Emit a remark generically
1982   ///
1983   /// This template function can be used to generically emit a remark. The
1984   /// RemarkKind should be one of the following:
1985   ///   - OptimizationRemark to indicate a successful optimization attempt
1986   ///   - OptimizationRemarkMissed to report a failed optimization attempt
1987   ///   - OptimizationRemarkAnalysis to provide additional information about an
1988   ///     optimization attempt
1989   ///
1990   /// The remark is built using a callback function provided by the caller that
1991   /// takes a RemarkKind as input and returns a RemarkKind.
1992   template <typename RemarkKind, typename RemarkCallBack>
1993   void emitRemark(Instruction *I, StringRef RemarkName,
1994                   RemarkCallBack &&RemarkCB) const {
1995     Function *F = I->getParent()->getParent();
1996     auto &ORE = OREGetter(F);
1997 
1998     if (RemarkName.starts_with("OMP"))
1999       ORE.emit([&]() {
2000         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
2001                << " [" << RemarkName << "]";
2002       });
2003     else
2004       ORE.emit(
2005           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
2006   }
2007 
2008   /// Emit a remark on a function.
2009   template <typename RemarkKind, typename RemarkCallBack>
2010   void emitRemark(Function *F, StringRef RemarkName,
2011                   RemarkCallBack &&RemarkCB) const {
2012     auto &ORE = OREGetter(F);
2013 
2014     if (RemarkName.starts_with("OMP"))
2015       ORE.emit([&]() {
2016         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
2017                << " [" << RemarkName << "]";
2018       });
2019     else
2020       ORE.emit(
2021           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
2022   }
2023 
/// The underlying module.
Module &M;

/// The SCC we are operating on.
SmallVectorImpl<Function *> &SCC;

/// Helper to keep the call graph up to date when calls are removed or
/// replaced. NOTE(review): this is a CallGraphUpdater object rather than a
/// plain callback; its uses are outside this view, so confirm.
CallGraphUpdater &CGUpdater;

/// Callback to get an OptimizationRemarkEmitter from a Function *.
OptimizationRemarkGetter OREGetter;

/// OpenMP-specific information cache. Also used for Attributor runs.
OMPInformationCache &OMPInfoCache;

/// Attributor instance.
Attributor &A;
2042 
2043   /// Helper function to run Attributor on SCC.
2044   bool runAttributor(bool IsModulePass) {
2045     if (SCC.empty())
2046       return false;
2047 
2048     registerAAs(IsModulePass);
2049 
2050     ChangeStatus Changed = A.run();
2051 
2052     LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
2053                       << " functions, result: " << Changed << ".\n");
2054 
2055     if (Changed == ChangeStatus::CHANGED)
2056       OMPInfoCache.invalidateAnalyses();
2057 
2058     return Changed == ChangeStatus::CHANGED;
2059   }
2060 
/// Register folding opportunities for calls to the runtime function \p RF.
/// NOTE(review): defined out of line; presumably seeds Attributor AAs for
/// each such call. Confirm against the definition.
void registerFoldRuntimeCall(RuntimeFunction RF);

/// Populate the Attributor with abstract attribute opportunities in the
/// functions.
void registerAAs(bool IsModulePass);

public:
/// Callback to register AAs for live functions, including internal functions
/// marked live during the traversal.
static void registerAAsForFunction(Attributor &A, const Function &F);
};
2072 
/// Find the unique kernel that will execute \p F, if any.
///
/// Returns nullptr in these cases:
/// - \p F lies outside the non-empty CGSCC being processed.
/// - \p F is not a kernel and has non-local linkage.
/// - The uses of \p F do not all lead to the same single kernel.
/// Results for functions inside the SCC are cached in UniqueKernelMap.
Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
  // Functions outside the SCC being processed are not analyzed, and that
  // negative answer is not cached.
  if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() &&
      !OMPInfoCache.CGSCC->contains(&F))
    return nullptr;

  // Use a scope to keep the lifetime of the CachedKernel short.
  // The recursive lookups further down may insert into UniqueKernelMap,
  // which can invalidate this reference.
  {
    std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
    if (CachedKernel)
      return *CachedKernel;

    // TODO: We should use an AA to create an (optimistic and callback
    //       call-aware) call graph. For now we stick to simple patterns that
    //       are less powerful, basically the worst fixpoint.
    if (isOpenMPKernel(F)) {
      CachedKernel = Kernel(&F);
      return *CachedKernel;
    }

    // Provisionally cache "no kernel". This also makes recursive queries for
    // F (e.g. through self-calls) return nullptr instead of recursing
    // forever.
    CachedKernel = nullptr;
    if (!F.hasLocalLinkage()) {

      // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
      auto Remark = [&](OptimizationRemarkAnalysis ORA) {
        return ORA << "Potentially unknown OpenMP target region caller.";
      };
      emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);

      return nullptr;
    }
  }

  // Maps a single use of F to the unique kernel of the function containing
  // the using instruction, or nullptr for unsupported uses.
  auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
    if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
      // Allow use in equality comparisons.
      if (Cmp->isEquality())
        return getUniqueKernelFor(*Cmp);
      return nullptr;
    }
    if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
      // Allow direct calls.
      if (CB->isCallee(&U))
        return getUniqueKernelFor(*CB);

      OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
          OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
      // Allow the use in __kmpc_parallel_51 calls.
      if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
        return getUniqueKernelFor(*CB);
      return nullptr;
    }
    // Disallow every other use.
    return nullptr;
  };

  // TODO: In the future we want to track more than just a unique kernel.
  // A nullptr entry counts as its own candidate, so a single unsupported use
  // next to a known kernel makes the result nullptr.
  SmallPtrSet<Kernel, 2> PotentialKernels;
  OMPInformationCache::foreachUse(F, [&](const Use &U) {
    PotentialKernels.insert(GetUniqueKernelForUse(U));
  });

  Kernel K = nullptr;
  if (PotentialKernels.size() == 1)
    K = *PotentialKernels.begin();

  // Cache the result. Re-index the map because CachedKernel may be stale.
  UniqueKernelMap[&F] = K;

  return K;
}
2143 
2144 bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2145   OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2146       OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2147 
2148   bool Changed = false;
2149   if (!KernelParallelRFI)
2150     return Changed;
2151 
2152   // If we have disabled state machine changes, exit
2153   if (DisableOpenMPOptStateMachineRewrite)
2154     return Changed;
2155 
2156   for (Function *F : SCC) {
2157 
2158     // Check if the function is a use in a __kmpc_parallel_51 call at
2159     // all.
2160     bool UnknownUse = false;
2161     bool KernelParallelUse = false;
2162     unsigned NumDirectCalls = 0;
2163 
2164     SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2165     OMPInformationCache::foreachUse(*F, [&](Use &U) {
2166       if (auto *CB = dyn_cast<CallBase>(U.getUser()))
2167         if (CB->isCallee(&U)) {
2168           ++NumDirectCalls;
2169           return;
2170         }
2171 
2172       if (isa<ICmpInst>(U.getUser())) {
2173         ToBeReplacedStateMachineUses.push_back(&U);
2174         return;
2175       }
2176 
2177       // Find wrapper functions that represent parallel kernels.
2178       CallInst *CI =
2179           OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
2180       const unsigned int WrapperFunctionArgNo = 6;
2181       if (!KernelParallelUse && CI &&
2182           CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
2183         KernelParallelUse = true;
2184         ToBeReplacedStateMachineUses.push_back(&U);
2185         return;
2186       }
2187       UnknownUse = true;
2188     });
2189 
2190     // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2191     // use.
2192     if (!KernelParallelUse)
2193       continue;
2194 
2195     // If this ever hits, we should investigate.
2196     // TODO: Checking the number of uses is not a necessary restriction and
2197     // should be lifted.
2198     if (UnknownUse || NumDirectCalls != 1 ||
2199         ToBeReplacedStateMachineUses.size() > 2) {
2200       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2201         return ORA << "Parallel region is used in "
2202                    << (UnknownUse ? "unknown" : "unexpected")
2203                    << " ways. Will not attempt to rewrite the state machine.";
2204       };
2205       emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2206       continue;
2207     }
2208 
2209     // Even if we have __kmpc_parallel_51 calls, we (for now) give
2210     // up if the function is not called from a unique kernel.
2211     Kernel K = getUniqueKernelFor(*F);
2212     if (!K) {
2213       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2214         return ORA << "Parallel region is not called from a unique kernel. "
2215                       "Will not attempt to rewrite the state machine.";
2216       };
2217       emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2218       continue;
2219     }
2220 
2221     // We now know F is a parallel body function called only from the kernel K.
2222     // We also identified the state machine uses in which we replace the
2223     // function pointer by a new global symbol for identification purposes. This
2224     // ensures only direct calls to the function are left.
2225 
2226     Module &M = *F->getParent();
2227     Type *Int8Ty = Type::getInt8Ty(M.getContext());
2228 
2229     auto *ID = new GlobalVariable(
2230         M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2231         UndefValue::get(Int8Ty), F->getName() + ".ID");
2232 
2233     for (Use *U : ToBeReplacedStateMachineUses)
2234       U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2235           ID, U->get()->getType()));
2236 
2237     ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2238 
2239     Changed = true;
2240   }
2241 
2242   return Changed;
2243 }
2244 
2245 /// Abstract Attribute for tracking ICV values.
2246 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2247   using Base = StateWrapper<BooleanState, AbstractAttribute>;
2248   AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2249 
2250   /// Returns true if value is assumed to be tracked.
2251   bool isAssumedTracked() const { return getAssumed(); }
2252 
2253   /// Returns true if value is known to be tracked.
2254   bool isKnownTracked() const { return getAssumed(); }
2255 
2256   /// Create an abstract attribute biew for the position \p IRP.
2257   static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2258 
2259   /// Return the value with which \p I can be replaced for specific \p ICV.
2260   virtual std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2261                                                      const Instruction *I,
2262                                                      Attributor &A) const {
2263     return std::nullopt;
2264   }
2265 
2266   /// Return an assumed unique ICV value if a single candidate is found. If
2267   /// there cannot be one, return a nullptr. If it is not clear yet, return
2268   /// std::nullopt.
2269   virtual std::optional<Value *>
2270   getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2271 
2272   // Currently only nthreads is being tracked.
2273   // this array will only grow with time.
2274   InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2275 
2276   /// See AbstractAttribute::getName()
2277   const std::string getName() const override { return "AAICVTracker"; }
2278 
2279   /// See AbstractAttribute::getIdAddr()
2280   const char *getIdAddr() const override { return &ID; }
2281 
2282   /// This function should return true if the type of the \p AA is AAICVTracker
2283   static bool classof(const AbstractAttribute *AA) {
2284     return (AA->getIdAddr() == &ID);
2285   }
2286 
2287   static const char ID;
2288 };
2289 
2290 struct AAICVTrackerFunction : public AAICVTracker {
2291   AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2292       : AAICVTracker(IRP, A) {}
2293 
2294   // FIXME: come up with better string.
2295   const std::string getAsStr(Attributor *) const override {
2296     return "ICVTrackerFunction";
2297   }
2298 
2299   // FIXME: come up with some stats.
2300   void trackStatistics() const override {}
2301 
2302   /// We don't manifest anything for this AA.
2303   ChangeStatus manifest(Attributor &A) override {
2304     return ChangeStatus::UNCHANGED;
2305   }
2306 
2307   // Map of ICV to their values at specific program point.
2308   EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2309                   InternalControlVar::ICV___last>
2310       ICVReplacementValuesMap;
2311 
2312   ChangeStatus updateImpl(Attributor &A) override {
2313     ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2314 
2315     Function *F = getAnchorScope();
2316 
2317     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2318 
2319     for (InternalControlVar ICV : TrackableICVs) {
2320       auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2321 
2322       auto &ValuesMap = ICVReplacementValuesMap[ICV];
2323       auto TrackValues = [&](Use &U, Function &) {
2324         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2325         if (!CI)
2326           return false;
2327 
2328         // FIXME: handle setters with more that 1 arguments.
2329         /// Track new value.
2330         if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2331           HasChanged = ChangeStatus::CHANGED;
2332 
2333         return false;
2334       };
2335 
2336       auto CallCheck = [&](Instruction &I) {
2337         std::optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2338         if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2339           HasChanged = ChangeStatus::CHANGED;
2340 
2341         return true;
2342       };
2343 
2344       // Track all changes of an ICV.
2345       SetterRFI.foreachUse(TrackValues, F);
2346 
2347       bool UsedAssumedInformation = false;
2348       A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2349                                 UsedAssumedInformation,
2350                                 /* CheckBBLivenessOnly */ true);
2351 
2352       /// TODO: Figure out a way to avoid adding entry in
2353       /// ICVReplacementValuesMap
2354       Instruction *Entry = &F->getEntryBlock().front();
2355       if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2356         ValuesMap.insert(std::make_pair(Entry, nullptr));
2357     }
2358 
2359     return HasChanged;
2360   }
2361 
2362   /// Helper to check if \p I is a call and get the value for it if it is
2363   /// unique.
2364   std::optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2365                                          InternalControlVar &ICV) const {
2366 
2367     const auto *CB = dyn_cast<CallBase>(&I);
2368     if (!CB || CB->hasFnAttr("no_openmp") ||
2369         CB->hasFnAttr("no_openmp_routines"))
2370       return std::nullopt;
2371 
2372     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2373     auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2374     auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2375     Function *CalledFunction = CB->getCalledFunction();
2376 
2377     // Indirect call, assume ICV changes.
2378     if (CalledFunction == nullptr)
2379       return nullptr;
2380     if (CalledFunction == GetterRFI.Declaration)
2381       return std::nullopt;
2382     if (CalledFunction == SetterRFI.Declaration) {
2383       if (ICVReplacementValuesMap[ICV].count(&I))
2384         return ICVReplacementValuesMap[ICV].lookup(&I);
2385 
2386       return nullptr;
2387     }
2388 
2389     // Since we don't know, assume it changes the ICV.
2390     if (CalledFunction->isDeclaration())
2391       return nullptr;
2392 
2393     const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2394         *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2395 
2396     if (ICVTrackingAA->isAssumedTracked()) {
2397       std::optional<Value *> URV =
2398           ICVTrackingAA->getUniqueReplacementValue(ICV);
2399       if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),
2400                                                  OMPInfoCache)))
2401         return URV;
2402     }
2403 
2404     // If we don't know, assume it changes.
2405     return nullptr;
2406   }
2407 
2408   // We don't check unique value for a function, so return std::nullopt.
2409   std::optional<Value *>
2410   getUniqueReplacementValue(InternalControlVar ICV) const override {
2411     return std::nullopt;
2412   }
2413 
2414   /// Return the value with which \p I can be replaced for specific \p ICV.
2415   std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2416                                              const Instruction *I,
2417                                              Attributor &A) const override {
2418     const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2419     if (ValuesMap.count(I))
2420       return ValuesMap.lookup(I);
2421 
2422     SmallVector<const Instruction *, 16> Worklist;
2423     SmallPtrSet<const Instruction *, 16> Visited;
2424     Worklist.push_back(I);
2425 
2426     std::optional<Value *> ReplVal;
2427 
2428     while (!Worklist.empty()) {
2429       const Instruction *CurrInst = Worklist.pop_back_val();
2430       if (!Visited.insert(CurrInst).second)
2431         continue;
2432 
2433       const BasicBlock *CurrBB = CurrInst->getParent();
2434 
2435       // Go up and look for all potential setters/calls that might change the
2436       // ICV.
2437       while ((CurrInst = CurrInst->getPrevNode())) {
2438         if (ValuesMap.count(CurrInst)) {
2439           std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2440           // Unknown value, track new.
2441           if (!ReplVal) {
2442             ReplVal = NewReplVal;
2443             break;
2444           }
2445 
2446           // If we found a new value, we can't know the icv value anymore.
2447           if (NewReplVal)
2448             if (ReplVal != NewReplVal)
2449               return nullptr;
2450 
2451           break;
2452         }
2453 
2454         std::optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
2455         if (!NewReplVal)
2456           continue;
2457 
2458         // Unknown value, track new.
2459         if (!ReplVal) {
2460           ReplVal = NewReplVal;
2461           break;
2462         }
2463 
2464         // if (NewReplVal.hasValue())
2465         // We found a new value, we can't know the icv value anymore.
2466         if (ReplVal != NewReplVal)
2467           return nullptr;
2468       }
2469 
2470       // If we are in the same BB and we have a value, we are done.
2471       if (CurrBB == I->getParent() && ReplVal)
2472         return ReplVal;
2473 
2474       // Go through all predecessors and add terminators for analysis.
2475       for (const BasicBlock *Pred : predecessors(CurrBB))
2476         if (const Instruction *Terminator = Pred->getTerminator())
2477           Worklist.push_back(Terminator);
2478     }
2479 
2480     return ReplVal;
2481   }
2482 };
2483 
2484 struct AAICVTrackerFunctionReturned : AAICVTracker {
2485   AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2486       : AAICVTracker(IRP, A) {}
2487 
2488   // FIXME: come up with better string.
2489   const std::string getAsStr(Attributor *) const override {
2490     return "ICVTrackerFunctionReturned";
2491   }
2492 
2493   // FIXME: come up with some stats.
2494   void trackStatistics() const override {}
2495 
2496   /// We don't manifest anything for this AA.
2497   ChangeStatus manifest(Attributor &A) override {
2498     return ChangeStatus::UNCHANGED;
2499   }
2500 
2501   // Map of ICV to their values at specific program point.
2502   EnumeratedArray<std::optional<Value *>, InternalControlVar,
2503                   InternalControlVar::ICV___last>
2504       ICVReplacementValuesMap;
2505 
2506   /// Return the value with which \p I can be replaced for specific \p ICV.
2507   std::optional<Value *>
2508   getUniqueReplacementValue(InternalControlVar ICV) const override {
2509     return ICVReplacementValuesMap[ICV];
2510   }
2511 
2512   ChangeStatus updateImpl(Attributor &A) override {
2513     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2514     const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2515         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2516 
2517     if (!ICVTrackingAA->isAssumedTracked())
2518       return indicatePessimisticFixpoint();
2519 
2520     for (InternalControlVar ICV : TrackableICVs) {
2521       std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2522       std::optional<Value *> UniqueICVValue;
2523 
2524       auto CheckReturnInst = [&](Instruction &I) {
2525         std::optional<Value *> NewReplVal =
2526             ICVTrackingAA->getReplacementValue(ICV, &I, A);
2527 
2528         // If we found a second ICV value there is no unique returned value.
2529         if (UniqueICVValue && UniqueICVValue != NewReplVal)
2530           return false;
2531 
2532         UniqueICVValue = NewReplVal;
2533 
2534         return true;
2535       };
2536 
2537       bool UsedAssumedInformation = false;
2538       if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2539                                      UsedAssumedInformation,
2540                                      /* CheckBBLivenessOnly */ true))
2541         UniqueICVValue = nullptr;
2542 
2543       if (UniqueICVValue == ReplVal)
2544         continue;
2545 
2546       ReplVal = UniqueICVValue;
2547       Changed = ChangeStatus::CHANGED;
2548     }
2549 
2550     return Changed;
2551   }
2552 };
2553 
2554 struct AAICVTrackerCallSite : AAICVTracker {
2555   AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2556       : AAICVTracker(IRP, A) {}
2557 
2558   void initialize(Attributor &A) override {
2559     assert(getAnchorScope() && "Expected anchor function");
2560 
2561     // We only initialize this AA for getters, so we need to know which ICV it
2562     // gets.
2563     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2564     for (InternalControlVar ICV : TrackableICVs) {
2565       auto ICVInfo = OMPInfoCache.ICVs[ICV];
2566       auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2567       if (Getter.Declaration == getAssociatedFunction()) {
2568         AssociatedICV = ICVInfo.Kind;
2569         return;
2570       }
2571     }
2572 
2573     /// Unknown ICV.
2574     indicatePessimisticFixpoint();
2575   }
2576 
2577   ChangeStatus manifest(Attributor &A) override {
2578     if (!ReplVal || !*ReplVal)
2579       return ChangeStatus::UNCHANGED;
2580 
2581     A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal);
2582     A.deleteAfterManifest(*getCtxI());
2583 
2584     return ChangeStatus::CHANGED;
2585   }
2586 
2587   // FIXME: come up with better string.
2588   const std::string getAsStr(Attributor *) const override {
2589     return "ICVTrackerCallSite";
2590   }
2591 
2592   // FIXME: come up with some stats.
2593   void trackStatistics() const override {}
2594 
2595   InternalControlVar AssociatedICV;
2596   std::optional<Value *> ReplVal;
2597 
2598   ChangeStatus updateImpl(Attributor &A) override {
2599     const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2600         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2601 
2602     // We don't have any information, so we assume it changes the ICV.
2603     if (!ICVTrackingAA->isAssumedTracked())
2604       return indicatePessimisticFixpoint();
2605 
2606     std::optional<Value *> NewReplVal =
2607         ICVTrackingAA->getReplacementValue(AssociatedICV, getCtxI(), A);
2608 
2609     if (ReplVal == NewReplVal)
2610       return ChangeStatus::UNCHANGED;
2611 
2612     ReplVal = NewReplVal;
2613     return ChangeStatus::CHANGED;
2614   }
2615 
2616   // Return the value with which associated value can be replaced for specific
2617   // \p ICV.
2618   std::optional<Value *>
2619   getUniqueReplacementValue(InternalControlVar ICV) const override {
2620     return ReplVal;
2621   }
2622 };
2623 
2624 struct AAICVTrackerCallSiteReturned : AAICVTracker {
2625   AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2626       : AAICVTracker(IRP, A) {}
2627 
2628   // FIXME: come up with better string.
2629   const std::string getAsStr(Attributor *) const override {
2630     return "ICVTrackerCallSiteReturned";
2631   }
2632 
2633   // FIXME: come up with some stats.
2634   void trackStatistics() const override {}
2635 
2636   /// We don't manifest anything for this AA.
2637   ChangeStatus manifest(Attributor &A) override {
2638     return ChangeStatus::UNCHANGED;
2639   }
2640 
2641   // Map of ICV to their values at specific program point.
2642   EnumeratedArray<std::optional<Value *>, InternalControlVar,
2643                   InternalControlVar::ICV___last>
2644       ICVReplacementValuesMap;
2645 
2646   /// Return the value with which associated value can be replaced for specific
2647   /// \p ICV.
2648   std::optional<Value *>
2649   getUniqueReplacementValue(InternalControlVar ICV) const override {
2650     return ICVReplacementValuesMap[ICV];
2651   }
2652 
2653   ChangeStatus updateImpl(Attributor &A) override {
2654     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2655     const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2656         *this, IRPosition::returned(*getAssociatedFunction()),
2657         DepClassTy::REQUIRED);
2658 
2659     // We don't have any information, so we assume it changes the ICV.
2660     if (!ICVTrackingAA->isAssumedTracked())
2661       return indicatePessimisticFixpoint();
2662 
2663     for (InternalControlVar ICV : TrackableICVs) {
2664       std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2665       std::optional<Value *> NewReplVal =
2666           ICVTrackingAA->getUniqueReplacementValue(ICV);
2667 
2668       if (ReplVal == NewReplVal)
2669         continue;
2670 
2671       ReplVal = NewReplVal;
2672       Changed = ChangeStatus::CHANGED;
2673     }
2674     return Changed;
2675   }
2676 };
2677 
2678 /// Determines if \p BB exits the function unconditionally itself or reaches a
2679 /// block that does through only unique successors.
2680 static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) {
2681   if (succ_empty(BB))
2682     return true;
2683   const BasicBlock *const Successor = BB->getUniqueSuccessor();
2684   if (!Successor)
2685     return false;
2686   return hasFunctionEndAsUniqueSuccessor(Successor);
2687 }
2688 
/// Function-position implementation of AAExecutionDomain.
///
/// Tracks, per basic block (BEDMap) and per call before/after the call
/// (CEDMap), whether code is executed by the initial thread only and whether
/// it is reached from / reaching aligned barriers only. The nullptr key of
/// BEDMap holds the function-level (entry/exit) state. manifest() uses this
/// information to delete redundant aligned barriers.
struct AAExecutionDomainFunction : public AAExecutionDomain {
  AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
      : AAExecutionDomain(IRP, A) {}

  // RPOT is allocated in initialize() and owned by this AA.
  ~AAExecutionDomainFunction() { delete RPOT; }

  /// Build the reverse post-order traversal of the anchor function once; it
  /// is reused by every update.
  void initialize(Attributor &A) override {
    Function *F = getAnchorScope();
    assert(F && "Expected anchor function");
    RPOT = new ReversePostOrderTraversal<Function *>(F);
  }

  /// Summarize BEDMap. The nullptr (function-level) entry is not counted as a
  /// block.
  const std::string getAsStr(Attributor *) const override {
    unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;
    for (auto &It : BEDMap) {
      if (!It.getFirst())
        continue;
      TotalBlocks++;
      InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
      AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&
                       It.getSecond().IsReachingAlignedBarrierOnly;
    }
    return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" +
           std::to_string(AlignedBlocks) + " of " +
           std::to_string(TotalBlocks) +
           " executed by initial thread / aligned";
  }

  /// See AbstractAttribute::trackStatistics().
  void trackStatistics() const override {}

  /// Delete aligned barriers (and the llvm.assume calls collected with them)
  /// that are made redundant by a preceding aligned barrier. Disabled by
  /// DisableOpenMPOptBarrierElimination.
  ChangeStatus manifest(Attributor &A) override {
    LLVM_DEBUG({
      for (const BasicBlock &BB : *getAnchorScope()) {
        if (!isExecutedByInitialThreadOnly(BB))
          continue;
        dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
               << BB.getName() << " is executed by a single thread.\n";
      }
    });

    ChangeStatus Changed = ChangeStatus::UNCHANGED;

    if (DisableOpenMPOptBarrierElimination)
      return Changed;

    // Barriers scheduled for deletion in this manifest.
    SmallPtrSet<CallBase *, 16> DeletedBarriers;
    // \p CB is an aligned barrier, or nullptr for the implicit barrier at the
    // kernel end (whose state lives in BEDMap[nullptr]).
    auto HandleAlignedBarrier = [&](CallBase *CB) {
      const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[nullptr];
      // Only removable if every path into it comes from an aligned barrier
      // with no non-local side effect in between.
      if (!ED.IsReachedFromAlignedBarrierOnly ||
          ED.EncounteredNonLocalSideEffect)
        return;
      // Removing the assumes (done below) is only allowed in a module pass.
      if (!ED.EncounteredAssumes.empty() && !A.isModulePass())
        return;

      // We can remove this barrier, if it is one, or aligned barriers reaching
      // the kernel end (if CB is nullptr). Aligned barriers reaching the kernel
      // end should only be removed if the kernel end is their unique successor;
      // otherwise, they may have side-effects that aren't accounted for in the
      // kernel end in their other successors. If those barriers have other
      // barriers reaching them, those can be transitively removed as well as
      // long as the kernel end is also their unique successor.
      if (CB) {
        DeletedBarriers.insert(CB);
        A.deleteAfterManifest(*CB);
        ++NumBarriersEliminated;
        Changed = ChangeStatus::CHANGED;
      } else if (!ED.AlignedBarriers.empty()) {
        Changed = ChangeStatus::CHANGED;
        SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(),
                                         ED.AlignedBarriers.end());
        SmallSetVector<CallBase *, 16> Visited;
        while (!Worklist.empty()) {
          CallBase *LastCB = Worklist.pop_back_val();
          if (!Visited.insert(LastCB))
            continue;
          if (LastCB->getFunction() != getAnchorScope())
            continue;
          if (!hasFunctionEndAsUniqueSuccessor(LastCB->getParent()))
            continue;
          if (!DeletedBarriers.count(LastCB)) {
            ++NumBarriersEliminated;
            A.deleteAfterManifest(*LastCB);
            continue;
          }
          // The final aligned barrier (LastCB) reaching the kernel end was
          // removed already. This means we can go one step further and remove
          // the barriers encountered last before (LastCB).
          const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];
          Worklist.append(LastED.AlignedBarriers.begin(),
                          LastED.AlignedBarriers.end());
        }
      }

      // If we actually eliminated a barrier we need to eliminate the associated
      // llvm.assumes as well to avoid creating UB.
      if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
        for (auto *AssumeCB : ED.EncounteredAssumes)
          A.deleteAfterManifest(*AssumeCB);
    };

    for (auto *CB : AlignedBarriers)
      HandleAlignedBarrier(CB);

    // Handle the "kernel end barrier" for kernels too.
    if (omp::isOpenMPKernel(*getAnchorScope()))
      HandleAlignedBarrier(nullptr);

    return Changed;
  }

  /// A fence is a no-op unless it was recorded in NonNoOpFences (requires a
  /// valid state).
  bool isNoOpFence(const FenceInst &FI) const override {
    return getState().isValidState() && !NonNoOpFences.count(&FI);
  }

  /// Merge barrier and assumption information from \p PredED into the successor
  /// \p ED.
  void
  mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED,
                                           const ExecutionDomainTy &PredED);

  /// Merge all information from \p PredED into the successor \p ED. If
  /// \p InitialEdgeOnly is set, only the initial edge will enter the block
  /// represented by \p ED from this predecessor.
  bool mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
                          const ExecutionDomainTy &PredED,
                          bool InitialEdgeOnly = false);

  /// Accumulate information for the entry block in \p EntryBBED.
  bool handleCallees(Attributor &A, ExecutionDomainTy &EntryBBED);

  /// See AbstractAttribute::updateImpl.
  ChangeStatus updateImpl(Attributor &A) override;

  /// Query interface, see AAExecutionDomain
  ///{
  bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
    if (!isValidState())
      return false;
    assert(BB.getParent() == getAnchorScope() && "Block is out of scope!");
    return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;
  }

  /// Returns true if \p I is only reached from, and only reaches, aligned
  /// barriers. The scan looks forward and backward within the block until the
  /// first call with recorded state (or the block boundary), falling back to
  /// block-level state at the boundary.
  bool isExecutedInAlignedRegion(Attributor &A,
                                 const Instruction &I) const override {
    assert(I.getFunction() == getAnchorScope() &&
           "Instruction is out of scope!");
    if (!isValidState())
      return false;

    bool ForwardIsOk = true;
    const Instruction *CurI;

    // Check forward until a call or the block end is reached.
    CurI = &I;
    do {
      auto *CB = dyn_cast<CallBase>(CurI);
      if (!CB)
        continue;
      if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
        return true;
      const auto &It = CEDMap.find({CB, PRE});
      if (It == CEDMap.end())
        continue;
      if (!It->getSecond().IsReachingAlignedBarrierOnly)
        ForwardIsOk = false;
      break;
    } while ((CurI = CurI->getNextNonDebugInstruction()));

    // Fell off the block end: use the block-level "reaching" state.
    if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)
      ForwardIsOk = false;

    // Check backward until a call or the block beginning is reached.
    CurI = &I;
    do {
      auto *CB = dyn_cast<CallBase>(CurI);
      if (!CB)
        continue;
      if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
        return true;
      const auto &It = CEDMap.find({CB, POST});
      if (It == CEDMap.end())
        continue;
      if (It->getSecond().IsReachedFromAlignedBarrierOnly)
        break;
      return false;
    } while ((CurI = CurI->getPrevNonDebugInstruction()));

    // Delayed decision on the forward pass to allow aligned barrier detection
    // in the backwards traversal.
    if (!ForwardIsOk)
      return false;

    // Fell off the block beginning: consult the function-level state for the
    // entry block, otherwise require every predecessor to qualify.
    if (!CurI) {
      const BasicBlock *BB = I.getParent();
      if (BB == &BB->getParent()->getEntryBlock())
        return BEDMap.lookup(nullptr).IsReachedFromAlignedBarrierOnly;
      if (!llvm::all_of(predecessors(BB), [&](const BasicBlock *PredBB) {
            return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
          })) {
        return false;
      }
    }

    // In neither traversal did we find anything but aligned barriers.
    return true;
  }

  ExecutionDomainTy getExecutionDomain(const BasicBlock &BB) const override {
    assert(isValidState() &&
           "No request should be made against an invalid state!");
    return BEDMap.lookup(&BB);
  }
  /// Returns the execution domains right before and right after \p CB.
  std::pair<ExecutionDomainTy, ExecutionDomainTy>
  getExecutionDomain(const CallBase &CB) const override {
    assert(isValidState() &&
           "No request should be made against an invalid state!");
    return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})};
  }
  ExecutionDomainTy getFunctionExecutionDomain() const override {
    assert(isValidState() &&
           "No request should be made against an invalid state!");
    return InterProceduralED;
  }
  ///}

  // Check if the edge into the successor block contains a condition that only
  // lets the main thread execute it. Recognized patterns (true successor
  // only): `-1 == __kmpc_target_init(...)` for generic-mode kernels, and
  // `0 == nvvm tid.x` / `0 == amdgcn workitem.id.x`.
  static bool isInitialThreadOnlyEdge(Attributor &A, BranchInst *Edge,
                                      BasicBlock &SuccessorBB) {
    if (!Edge || !Edge->isConditional())
      return false;
    if (Edge->getSuccessor(0) != &SuccessorBB)
      return false;

    auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
    if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
      return false;

    ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
    if (!C)
      return false;

    // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
    if (C->isAllOnesValue()) {
      auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
      auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
      CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
      if (!CB)
        return false;
      ConstantStruct *KernelEnvC =
          KernelInfo::getKernelEnvironementFromKernelInitCB(CB);
      ConstantInt *ExecModeC =
          KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
      return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC;
    }

    if (C->isZero()) {
      // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
      if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
        if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
          return true;

      // Match: 0 == llvm.amdgcn.workitem.id.x()
      if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
        if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
          return true;
    }

    return false;
  };

  /// Mapping containing information about the function for other AAs.
  ExecutionDomainTy InterProceduralED;

  /// Selects the state right before (PRE) or right after (POST) a call.
  enum Direction { PRE = 0, POST = 1 };
  /// Mapping containing information per block.
  DenseMap<const BasicBlock *, ExecutionDomainTy> BEDMap;
  /// Mapping containing information per call, before (PRE) and after (POST).
  DenseMap<PointerIntPair<const CallBase *, 1, Direction>, ExecutionDomainTy>
      CEDMap;
  /// Aligned barriers encountered in the anchor function.
  SmallSetVector<CallBase *, 16> AlignedBarriers;

  /// Traversal order of the anchor function; owned, see the destructor.
  ReversePostOrderTraversal<Function *> *RPOT = nullptr;

  /// Set \p R to \p V and report true if that changed \p R.
  static bool setAndRecord(bool &R, bool V) {
    bool Eq = (R == V);
    R = V;
    return !Eq;
  }

  /// Collection of fences known to be non-no-op. All fences not in this set
  /// can be assumed no-op.
  SmallPtrSet<const FenceInst *, 8> NonNoOpFences;
};
2985 
2986 void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
2987     Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) {
2988   for (auto *EA : PredED.EncounteredAssumes)
2989     ED.addAssumeInst(A, *EA);
2990 
2991   for (auto *AB : PredED.AlignedBarriers)
2992     ED.addAlignedBarrier(A, *AB);
2993 }
2994 
2995 bool AAExecutionDomainFunction::mergeInPredecessor(
2996     Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,
2997     bool InitialEdgeOnly) {
2998 
2999   bool Changed = false;
3000   Changed |=
3001       setAndRecord(ED.IsExecutedByInitialThreadOnly,
3002                    InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
3003                                        ED.IsExecutedByInitialThreadOnly));
3004 
3005   Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly,
3006                           ED.IsReachedFromAlignedBarrierOnly &&
3007                               PredED.IsReachedFromAlignedBarrierOnly);
3008   Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect,
3009                           ED.EncounteredNonLocalSideEffect |
3010                               PredED.EncounteredNonLocalSideEffect);
3011   // Do not track assumptions and barriers as part of Changed.
3012   if (ED.IsReachedFromAlignedBarrierOnly)
3013     mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);
3014   else
3015     ED.clearAssumeInstAndAlignedBarriers();
3016   return Changed;
3017 }
3018 
3019 bool AAExecutionDomainFunction::handleCallees(Attributor &A,
3020                                               ExecutionDomainTy &EntryBBED) {
3021   SmallVector<std::pair<ExecutionDomainTy, ExecutionDomainTy>, 4> CallSiteEDs;
3022   auto PredForCallSite = [&](AbstractCallSite ACS) {
3023     const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3024         *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
3025         DepClassTy::OPTIONAL);
3026     if (!EDAA || !EDAA->getState().isValidState())
3027       return false;
3028     CallSiteEDs.emplace_back(
3029         EDAA->getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
3030     return true;
3031   };
3032 
3033   ExecutionDomainTy ExitED;
3034   bool AllCallSitesKnown;
3035   if (A.checkForAllCallSites(PredForCallSite, *this,
3036                              /* RequiresAllCallSites */ true,
3037                              AllCallSitesKnown)) {
3038     for (const auto &[CSInED, CSOutED] : CallSiteEDs) {
3039       mergeInPredecessor(A, EntryBBED, CSInED);
3040       ExitED.IsReachingAlignedBarrierOnly &=
3041           CSOutED.IsReachingAlignedBarrierOnly;
3042     }
3043 
3044   } else {
3045     // We could not find all predecessors, so this is either a kernel or a
3046     // function with external linkage (or with some other weird uses).
3047     if (omp::isOpenMPKernel(*getAnchorScope())) {
3048       EntryBBED.IsExecutedByInitialThreadOnly = false;
3049       EntryBBED.IsReachedFromAlignedBarrierOnly = true;
3050       EntryBBED.EncounteredNonLocalSideEffect = false;
3051       ExitED.IsReachingAlignedBarrierOnly = false;
3052     } else {
3053       EntryBBED.IsExecutedByInitialThreadOnly = false;
3054       EntryBBED.IsReachedFromAlignedBarrierOnly = false;
3055       EntryBBED.EncounteredNonLocalSideEffect = true;
3056       ExitED.IsReachingAlignedBarrierOnly = false;
3057     }
3058   }
3059 
3060   bool Changed = false;
3061   auto &FnED = BEDMap[nullptr];
3062   Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly,
3063                           FnED.IsReachedFromAlignedBarrierOnly &
3064                               EntryBBED.IsReachedFromAlignedBarrierOnly);
3065   Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly,
3066                           FnED.IsReachingAlignedBarrierOnly &
3067                               ExitED.IsReachingAlignedBarrierOnly);
3068   Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly,
3069                           EntryBBED.IsExecutedByInitialThreadOnly);
3070   return Changed;
3071 }
3072 
/// Fixpoint update of the execution domain information for this function.
///
/// A forward traversal visits the basic blocks in reverse post-order (skipping
/// blocks and edges the liveness AA assumes dead). It computes an
/// ExecutionDomainTy per block (stored in BEDMap) and records the domain
/// before (PRE) and after (POST) calls in CEDMap. Two kinds of points are
/// collected in SyncInstWorklist:
///  - instructions that may synchronize but are not aligned barriers, and
///  - terminators of function exits at which the function-level domain is not
///    known to reach only aligned barriers.
/// A backward traversal then clears IsReachingAlignedBarrierOnly on the calls
/// and predecessor blocks leading to those points. Within a block, the walk
/// stops at an aligned barrier.
///
/// \returns ChangeStatus::CHANGED if any tracked state was modified,
///          ChangeStatus::UNCHANGED otherwise.
ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {

  bool Changed = false;

  // Helper to deal with an aligned barrier encountered during the forward
  // traversal. \p CB is the aligned barrier, \p ED is the execution domain when
  // it was encountered.
  auto HandleAlignedBarrier = [&](CallBase &CB, ExecutionDomainTy &ED) {
    Changed |= AlignedBarriers.insert(&CB);
    // First, update the barrier ED kept in the separate CEDMap.
    auto &CallInED = CEDMap[{&CB, PRE}];
    Changed |= mergeInPredecessor(A, CallInED, ED);
    CallInED.IsReachingAlignedBarrierOnly = true;
    // Next adjust the ED we use for the traversal.
    ED.EncounteredNonLocalSideEffect = false;
    ED.IsReachedFromAlignedBarrierOnly = true;
    // Aligned barrier collection has to come last.
    ED.clearAssumeInstAndAlignedBarriers();
    ED.addAlignedBarrier(A, CB);
    auto &CallOutED = CEDMap[{&CB, POST}];
    Changed |= mergeInPredecessor(A, CallOutED, ED);
  };

  // Liveness information is optional. When it is null, the explicit block and
  // edge liveness checks below are skipped.
  auto *LivenessAA =
      A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);

  Function *F = getAnchorScope();
  BasicBlock &EntryBB = F->getEntryBlock();
  bool IsKernel = omp::isOpenMPKernel(*F);

  // Starting points for the backward traversal below: non-aligned sync
  // instructions and exit terminators.
  SmallVector<Instruction *> SyncInstWorklist;
  for (auto &RIt : *RPOT) {
    BasicBlock &BB = *RIt;

    bool IsEntryBB = &BB == &EntryBB;
    // TODO: We use local reasoning since we don't have a divergence analysis
    //       running as well. We could basically allow uniform branches here.
    // Both flags start out true only in the entry block of a kernel. The
    // `&= IsNoSync` updates below clear them at instructions that may
    // synchronize, and an aligned barrier sets them again.
    bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
    bool IsExplicitlyAligned = IsEntryBB && IsKernel;
    ExecutionDomainTy ED;
    // Propagate "incoming edges" into information about this block.
    if (IsEntryBB) {
      // The entry block state is derived by handleCallees.
      Changed |= handleCallees(A, ED);
    } else {
      // For live non-entry blocks we only propagate
      // information via live edges.
      if (LivenessAA && LivenessAA->isAssumedDead(&BB))
        continue;

      for (auto *PredBB : predecessors(&BB)) {
        if (LivenessAA && LivenessAA->isEdgeDead(PredBB, &BB))
          continue;
        bool InitialEdgeOnly = isInitialThreadOnlyEdge(
            A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB);
        mergeInPredecessor(A, ED, BEDMap[PredBB], InitialEdgeOnly);
      }
    }

    // Now we traverse the block, accumulate effects in ED and attach
    // information to calls.
    for (Instruction &I : BB) {
      bool UsedAssumedInformation;
      if (A.isAssumedDead(I, *this, LivenessAA, UsedAssumedInformation,
                          /* CheckBBLivenessOnly */ false, DepClassTy::OPTIONAL,
                          /* CheckForDeadStore */ true))
        continue;

      // Assumes and "assume-like" intrinsics (dbg, lifetime, ...) are handled
      // first: the former are collected, the latter are ignored.
      if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
        if (auto *AI = dyn_cast_or_null<AssumeInst>(II)) {
          ED.addAssumeInst(A, *AI);
          continue;
        }
        // TODO: Should we also collect and delete lifetime markers?
        if (II->isAssumeLikeIntrinsic())
          continue;
      }

      if (auto *FI = dyn_cast<FenceInst>(&I)) {
        if (!ED.EncounteredNonLocalSideEffect) {
          // An aligned fence without non-local side-effects is a no-op.
          if (ED.IsReachedFromAlignedBarrierOnly)
            continue;
          // A non-aligned fence without non-local side-effects is a no-op
          // if the ordering only publishes non-local side-effects (or less).
          // `continue` skips the fence; `break` falls through to record it.
          switch (FI->getOrdering()) {
          case AtomicOrdering::NotAtomic:
            continue;
          case AtomicOrdering::Unordered:
            continue;
          case AtomicOrdering::Monotonic:
            continue;
          case AtomicOrdering::Acquire:
            break;
          case AtomicOrdering::Release:
            continue;
          case AtomicOrdering::AcquireRelease:
            break;
          case AtomicOrdering::SequentiallyConsistent:
            break;
          };
        }
        // Fences not shown to be no-ops above are recorded.
        NonNoOpFences.insert(FI);
      }

      auto *CB = dyn_cast<CallBase>(&I);
      bool IsNoSync = AA::isNoSyncInst(A, I, *this);
      bool IsAlignedBarrier =
          !IsNoSync && CB &&
          AANoSync::isAlignedBarrier(*CB, AlignedBarrierLastInBlock);

      AlignedBarrierLastInBlock &= IsNoSync;
      IsExplicitlyAligned &= IsNoSync;

      // Next we check for calls. Aligned barriers are handled
      // explicitly, everything else is kept for the backward traversal and will
      // also affect our state.
      if (CB) {
        if (IsAlignedBarrier) {
          HandleAlignedBarrier(*CB, ED);
          AlignedBarrierLastInBlock = true;
          IsExplicitlyAligned = true;
          continue;
        }

        // Check the pointer(s) of a memory intrinsic explicitly.
        if (isa<MemIntrinsic>(&I)) {
          if (!ED.EncounteredNonLocalSideEffect &&
              AA::isPotentiallyAffectedByBarrier(A, I, *this))
            ED.EncounteredNonLocalSideEffect = true;
          if (!IsNoSync) {
            ED.IsReachedFromAlignedBarrierOnly = false;
            SyncInstWorklist.push_back(&I);
          }
          continue;
        }

        // Record how we entered the call, then accumulate the effect of the
        // call in ED for potential use by the callee.
        auto &CallInED = CEDMap[{CB, PRE}];
        Changed |= mergeInPredecessor(A, CallInED, ED);

        // If we have a sync-definition we can check if it starts/ends in an
        // aligned barrier. If we are unsure we assume any sync breaks
        // alignment.
        Function *Callee = CB->getCalledFunction();
        if (!IsNoSync && Callee && !Callee->isDeclaration()) {
          const auto *EDAA = A.getAAFor<AAExecutionDomain>(
              *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);
          if (EDAA && EDAA->getState().isValidState()) {
            const auto &CalleeED = EDAA->getFunctionExecutionDomain();
            ED.IsReachedFromAlignedBarrierOnly =
                CalleeED.IsReachedFromAlignedBarrierOnly;
            AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
            // NOTE(review): IsNoSync is always false here (the enclosing
            // condition requires !IsNoSync), so only the second operand
            // decides between accumulating and overwriting.
            if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
              ED.EncounteredNonLocalSideEffect |=
                  CalleeED.EncounteredNonLocalSideEffect;
            else
              ED.EncounteredNonLocalSideEffect =
                  CalleeED.EncounteredNonLocalSideEffect;
            if (!CalleeED.IsReachingAlignedBarrierOnly) {
              Changed |=
                  setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
              SyncInstWorklist.push_back(&I);
            }
            if (CalleeED.IsReachedFromAlignedBarrierOnly)
              mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED);
            auto &CallOutED = CEDMap[{CB, POST}];
            Changed |= mergeInPredecessor(A, CallOutED, ED);
            continue;
          }
        }
        // Unknown (or unanalyzable) synchronizing call: it breaks both the
        // forward and the backward "aligned barrier only" properties.
        if (!IsNoSync) {
          ED.IsReachedFromAlignedBarrierOnly = false;
          Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
          SyncInstWorklist.push_back(&I);
        }
        AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
        ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory();
        auto &CallOutED = CEDMap[{CB, POST}];
        Changed |= mergeInPredecessor(A, CallOutED, ED);
      }

      if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())
        continue;

      // If we have a callee we try to use fine-grained information to
      // determine local side-effects.
      if (CB) {
        const auto *MemAA = A.getAAFor<AAMemoryLocation>(
            *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);

        auto AccessPred = [&](const Instruction *I, const Value *Ptr,
                              AAMemoryLocation::AccessKind,
                              AAMemoryLocation::MemoryLocationsKind) {
          return !AA::isPotentiallyAffectedByBarrier(A, {Ptr}, *this, I);
        };
        if (MemAA && MemAA->getState().isValidState() &&
            MemAA->checkForAllAccessesToMemoryKind(
                AccessPred, AAMemoryLocation::ALL_LOCATIONS))
          continue;
      }

      auto &InfoCache = A.getInfoCache();
      if (!I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(I))
        continue;

      if (auto *LI = dyn_cast<LoadInst>(&I))
        if (LI->hasMetadata(LLVMContext::MD_invariant_load))
          continue;

      if (!ED.EncounteredNonLocalSideEffect &&
          AA::isPotentiallyAffectedByBarrier(A, I, *this))
        ED.EncounteredNonLocalSideEffect = true;
    }

    // A block with no successors that does not end in `unreachable` is a
    // function exit. Its state flows into InterProceduralED and into the
    // function-level domain kept in BEDMap[nullptr].
    bool IsEndAndNotReachingAlignedBarriersOnly = false;
    if (!isa<UnreachableInst>(BB.getTerminator()) &&
        !BB.getTerminator()->getNumSuccessors()) {

      Changed |= mergeInPredecessor(A, InterProceduralED, ED);

      auto &FnED = BEDMap[nullptr];
      if (IsKernel && !IsExplicitlyAligned)
        FnED.IsReachingAlignedBarrierOnly = false;
      Changed |= mergeInPredecessor(A, FnED, ED);

      // The exit becomes a start point for the backward traversal if the
      // function-level domain does not reach aligned barriers only.
      if (!FnED.IsReachingAlignedBarrierOnly) {
        IsEndAndNotReachingAlignedBarriersOnly = true;
        SyncInstWorklist.push_back(BB.getTerminator());
        auto &BBED = BEDMap[&BB];
        Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly, false);
      }
    }

    // IsReachingAlignedBarrierOnly is a backward property. It is carried over
    // from the stored state rather than recomputed by this forward pass.
    ExecutionDomainTy &StoredED = BEDMap[&BB];
    ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &
                                      !IsEndAndNotReachingAlignedBarriersOnly;

    // Check if we computed anything different as part of the forward
    // traversal. We do not take assumptions and aligned barriers into account
    // as they do not influence the state we iterate. Backward traversal values
    // are handled later on.
    if (ED.IsExecutedByInitialThreadOnly !=
            StoredED.IsExecutedByInitialThreadOnly ||
        ED.IsReachedFromAlignedBarrierOnly !=
            StoredED.IsReachedFromAlignedBarrierOnly ||
        ED.EncounteredNonLocalSideEffect !=
            StoredED.EncounteredNonLocalSideEffect)
      Changed = true;

    // Update the state with the new value.
    StoredED = std::move(ED);
  }

  // Propagate (non-aligned) sync instruction effects backwards until the
  // entry is hit or an aligned barrier.
  // Visited is shared across all worklist items, so each predecessor block is
  // examined at most once in this update.
  SmallSetVector<BasicBlock *, 16> Visited;
  while (!SyncInstWorklist.empty()) {
    Instruction *SyncInst = SyncInstWorklist.pop_back_val();
    Instruction *CurInst = SyncInst;
    bool HitAlignedBarrierOrKnownEnd = false;
    // Walk backwards through the calls in the same block. Stop at an aligned
    // barrier, or at a call whose PRE state is already marked as not reaching
    // aligned barriers only.
    while ((CurInst = CurInst->getPrevNode())) {
      auto *CB = dyn_cast<CallBase>(CurInst);
      if (!CB)
        continue;
      auto &CallOutED = CEDMap[{CB, POST}];
      Changed |= setAndRecord(CallOutED.IsReachingAlignedBarrierOnly, false);
      auto &CallInED = CEDMap[{CB, PRE}];
      HitAlignedBarrierOrKnownEnd =
          AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly;
      if (HitAlignedBarrierOrKnownEnd)
        break;
      Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
    }
    if (HitAlignedBarrierOrKnownEnd)
      continue;
    // The block start was reached: continue into live predecessors. A
    // predecessor's terminator is only queued if its flag actually flipped.
    BasicBlock *SyncBB = SyncInst->getParent();
    for (auto *PredBB : predecessors(SyncBB)) {
      if (LivenessAA && LivenessAA->isEdgeDead(PredBB, SyncBB))
        continue;
      if (!Visited.insert(PredBB))
        continue;
      auto &PredED = BEDMap[PredBB];
      if (setAndRecord(PredED.IsReachingAlignedBarrierOnly, false)) {
        Changed = true;
        SyncInstWorklist.push_back(PredBB->getTerminator());
      }
    }
    if (SyncBB != &EntryBB)
      continue;
    // Reaching the entry block means the function entry itself is not
    // known to reach aligned barriers only.
    Changed |=
        setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly, false);
  }

  return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
}
3371 
3372 /// Try to replace memory allocation calls called by a single thread with a
3373 /// static buffer of shared memory.
3374 struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
3375   using Base = StateWrapper<BooleanState, AbstractAttribute>;
3376   AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3377 
3378   /// Create an abstract attribute view for the position \p IRP.
3379   static AAHeapToShared &createForPosition(const IRPosition &IRP,
3380                                            Attributor &A);
3381 
3382   /// Returns true if HeapToShared conversion is assumed to be possible.
3383   virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
3384 
3385   /// Returns true if HeapToShared conversion is assumed and the CB is a
3386   /// callsite to a free operation to be removed.
3387   virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
3388 
3389   /// See AbstractAttribute::getName().
3390   const std::string getName() const override { return "AAHeapToShared"; }
3391 
3392   /// See AbstractAttribute::getIdAddr().
3393   const char *getIdAddr() const override { return &ID; }
3394 
3395   /// This function should return true if the type of the \p AA is
3396   /// AAHeapToShared.
3397   static bool classof(const AbstractAttribute *AA) {
3398     return (AA->getIdAddr() == &ID);
3399   }
3400 
3401   /// Unique ID (due to the unique address)
3402   static const char ID;
3403 };
3404 
3405 struct AAHeapToSharedFunction : public AAHeapToShared {
3406   AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
3407       : AAHeapToShared(IRP, A) {}
3408 
3409   const std::string getAsStr(Attributor *) const override {
3410     return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
3411            " malloc calls eligible.";
3412   }
3413 
3414   /// See AbstractAttribute::trackStatistics().
3415   void trackStatistics() const override {}
3416 
3417   /// This functions finds free calls that will be removed by the
3418   /// HeapToShared transformation.
3419   void findPotentialRemovedFreeCalls(Attributor &A) {
3420     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3421     auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3422 
3423     PotentialRemovedFreeCalls.clear();
3424     // Update free call users of found malloc calls.
3425     for (CallBase *CB : MallocCalls) {
3426       SmallVector<CallBase *, 4> FreeCalls;
3427       for (auto *U : CB->users()) {
3428         CallBase *C = dyn_cast<CallBase>(U);
3429         if (C && C->getCalledFunction() == FreeRFI.Declaration)
3430           FreeCalls.push_back(C);
3431       }
3432 
3433       if (FreeCalls.size() != 1)
3434         continue;
3435 
3436       PotentialRemovedFreeCalls.insert(FreeCalls.front());
3437     }
3438   }
3439 
3440   void initialize(Attributor &A) override {
3441     if (DisableOpenMPOptDeglobalization) {
3442       indicatePessimisticFixpoint();
3443       return;
3444     }
3445 
3446     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3447     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3448     if (!RFI.Declaration)
3449       return;
3450 
3451     Attributor::SimplifictionCallbackTy SCB =
3452         [](const IRPosition &, const AbstractAttribute *,
3453            bool &) -> std::optional<Value *> { return nullptr; };
3454 
3455     Function *F = getAnchorScope();
3456     for (User *U : RFI.Declaration->users())
3457       if (CallBase *CB = dyn_cast<CallBase>(U)) {
3458         if (CB->getFunction() != F)
3459           continue;
3460         MallocCalls.insert(CB);
3461         A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),
3462                                          SCB);
3463       }
3464 
3465     findPotentialRemovedFreeCalls(A);
3466   }
3467 
3468   bool isAssumedHeapToShared(CallBase &CB) const override {
3469     return isValidState() && MallocCalls.count(&CB);
3470   }
3471 
3472   bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
3473     return isValidState() && PotentialRemovedFreeCalls.count(&CB);
3474   }
3475 
3476   ChangeStatus manifest(Attributor &A) override {
3477     if (MallocCalls.empty())
3478       return ChangeStatus::UNCHANGED;
3479 
3480     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3481     auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3482 
3483     Function *F = getAnchorScope();
3484     auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
3485                                             DepClassTy::OPTIONAL);
3486 
3487     ChangeStatus Changed = ChangeStatus::UNCHANGED;
3488     for (CallBase *CB : MallocCalls) {
3489       // Skip replacing this if HeapToStack has already claimed it.
3490       if (HS && HS->isAssumedHeapToStack(*CB))
3491         continue;
3492 
3493       // Find the unique free call to remove it.
3494       SmallVector<CallBase *, 4> FreeCalls;
3495       for (auto *U : CB->users()) {
3496         CallBase *C = dyn_cast<CallBase>(U);
3497         if (C && C->getCalledFunction() == FreeCall.Declaration)
3498           FreeCalls.push_back(C);
3499       }
3500       if (FreeCalls.size() != 1)
3501         continue;
3502 
3503       auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));
3504 
3505       if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
3506         LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
3507                           << " with shared memory."
3508                           << " Shared memory usage is limited to "
3509                           << SharedMemoryLimit << " bytes\n");
3510         continue;
3511       }
3512 
3513       LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
3514                         << " with " << AllocSize->getZExtValue()
3515                         << " bytes of shared memory\n");
3516 
3517       // Create a new shared memory buffer of the same size as the allocation
3518       // and replace all the uses of the original allocation with it.
3519       Module *M = CB->getModule();
3520       Type *Int8Ty = Type::getInt8Ty(M->getContext());
3521       Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
3522       auto *SharedMem = new GlobalVariable(
3523           *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
3524           PoisonValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
3525           GlobalValue::NotThreadLocal,
3526           static_cast<unsigned>(AddressSpace::Shared));
3527       auto *NewBuffer =
3528           ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
3529 
3530       auto Remark = [&](OptimizationRemark OR) {
3531         return OR << "Replaced globalized variable with "
3532                   << ore::NV("SharedMemory", AllocSize->getZExtValue())
3533                   << (AllocSize->isOne() ? " byte " : " bytes ")
3534                   << "of shared memory.";
3535       };
3536       A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
3537 
3538       MaybeAlign Alignment = CB->getRetAlign();
3539       assert(Alignment &&
3540              "HeapToShared on allocation without alignment attribute");
3541       SharedMem->setAlignment(*Alignment);
3542 
3543       A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);
3544       A.deleteAfterManifest(*CB);
3545       A.deleteAfterManifest(*FreeCalls.front());
3546 
3547       SharedMemoryUsed += AllocSize->getZExtValue();
3548       NumBytesMovedToSharedMemory = SharedMemoryUsed;
3549       Changed = ChangeStatus::CHANGED;
3550     }
3551 
3552     return Changed;
3553   }
3554 
3555   ChangeStatus updateImpl(Attributor &A) override {
3556     if (MallocCalls.empty())
3557       return indicatePessimisticFixpoint();
3558     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3559     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3560     if (!RFI.Declaration)
3561       return ChangeStatus::UNCHANGED;
3562 
3563     Function *F = getAnchorScope();
3564 
3565     auto NumMallocCalls = MallocCalls.size();
3566 
3567     // Only consider malloc calls executed by a single thread with a constant.
3568     for (User *U : RFI.Declaration->users()) {
3569       if (CallBase *CB = dyn_cast<CallBase>(U)) {
3570         if (CB->getCaller() != F)
3571           continue;
3572         if (!MallocCalls.count(CB))
3573           continue;
3574         if (!isa<ConstantInt>(CB->getArgOperand(0))) {
3575           MallocCalls.remove(CB);
3576           continue;
3577         }
3578         const auto *ED = A.getAAFor<AAExecutionDomain>(
3579             *this, IRPosition::function(*F), DepClassTy::REQUIRED);
3580         if (!ED || !ED->isExecutedByInitialThreadOnly(*CB))
3581           MallocCalls.remove(CB);
3582       }
3583     }
3584 
3585     findPotentialRemovedFreeCalls(A);
3586 
3587     if (NumMallocCalls != MallocCalls.size())
3588       return ChangeStatus::CHANGED;
3589 
3590     return ChangeStatus::UNCHANGED;
3591   }
3592 
3593   /// Collection of all malloc calls in a function.
3594   SmallSetVector<CallBase *, 4> MallocCalls;
3595   /// Collection of potentially removed free calls in a function.
3596   SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
3597   /// The total amount of shared memory that has been used for HeapToShared.
3598   unsigned SharedMemoryUsed = 0;
3599 };
3600 
3601 struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
3602   using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
3603   AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3604 
3605   /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
3606   /// unknown callees.
3607   static bool requiresCalleeForCallBase() { return false; }
3608 
3609   /// Statistics are tracked as part of manifest for now.
3610   void trackStatistics() const override {}
3611 
3612   /// See AbstractAttribute::getAsStr()
3613   const std::string getAsStr(Attributor *) const override {
3614     if (!isValidState())
3615       return "<invalid>";
3616     return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
3617                                                             : "generic") +
3618            std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
3619                                                                : "") +
3620            std::string(" #PRs: ") +
3621            (ReachedKnownParallelRegions.isValidState()
3622                 ? std::to_string(ReachedKnownParallelRegions.size())
3623                 : "<invalid>") +
3624            ", #Unknown PRs: " +
3625            (ReachedUnknownParallelRegions.isValidState()
3626                 ? std::to_string(ReachedUnknownParallelRegions.size())
3627                 : "<invalid>") +
3628            ", #Reaching Kernels: " +
3629            (ReachingKernelEntries.isValidState()
3630                 ? std::to_string(ReachingKernelEntries.size())
3631                 : "<invalid>") +
3632            ", #ParLevels: " +
3633            (ParallelLevels.isValidState()
3634                 ? std::to_string(ParallelLevels.size())
3635                 : "<invalid>") +
3636            ", NestedPar: " + (NestedParallelism ? "yes" : "no");
3637   }
3638 
3639   /// Create an abstract attribute biew for the position \p IRP.
3640   static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
3641 
3642   /// See AbstractAttribute::getName()
3643   const std::string getName() const override { return "AAKernelInfo"; }
3644 
3645   /// See AbstractAttribute::getIdAddr()
3646   const char *getIdAddr() const override { return &ID; }
3647 
3648   /// This function should return true if the type of the \p AA is AAKernelInfo
3649   static bool classof(const AbstractAttribute *AA) {
3650     return (AA->getIdAddr() == &ID);
3651   }
3652 
3653   static const char ID;
3654 };
3655 
3656 /// The function kernel info abstract attribute, basically, what can we say
3657 /// about a function with regards to the KernelInfoState.
3658 struct AAKernelInfoFunction : AAKernelInfo {
3659   AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
3660       : AAKernelInfo(IRP, A) {}
3661 
3662   SmallPtrSet<Instruction *, 4> GuardedInstructions;
3663 
3664   SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
3665     return GuardedInstructions;
3666   }
3667 
3668   void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) {
3669     Constant *NewKernelEnvC = ConstantFoldInsertValueInstruction(
3670         KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx});
3671     assert(NewKernelEnvC && "Failed to create new kernel environment");
3672     KernelEnvC = cast<ConstantStruct>(NewKernelEnvC);
3673   }
3674 
3675 #define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER)                        \
3676   void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) {                 \
3677     ConstantStruct *ConfigC =                                                  \
3678         KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC);         \
3679     Constant *NewConfigC = ConstantFoldInsertValueInstruction(                 \
3680         ConfigC, NewVal, {KernelInfo::MEMBER##Idx});                           \
3681     assert(NewConfigC && "Failed to create new configuration environment");    \
3682     setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC));     \
3683   }
3684 
3685   KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(UseGenericStateMachine)
3686   KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MayUseNestedParallelism)
3687   KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode)
3688   KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinThreads)
3689   KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxThreads)
3690   KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinTeams)
3691   KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxTeams)
3692 
3693 #undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER
3694 
3695   /// See AbstractAttribute::initialize(...).
3696   void initialize(Attributor &A) override {
3697     // This is a high-level transform that might change the constant arguments
3698     // of the init and dinit calls. We need to tell the Attributor about this
3699     // to avoid other parts using the current constant value for simpliication.
3700     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3701 
3702     Function *Fn = getAnchorScope();
3703 
3704     OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3705         OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3706     OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3707         OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3708 
3709     // For kernels we perform more initialization work, first we find the init
3710     // and deinit calls.
3711     auto StoreCallBase = [](Use &U,
3712                             OMPInformationCache::RuntimeFunctionInfo &RFI,
3713                             CallBase *&Storage) {
3714       CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3715       assert(CB &&
3716              "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3717       assert(!Storage &&
3718              "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3719       Storage = CB;
3720       return false;
3721     };
3722     InitRFI.foreachUse(
3723         [&](Use &U, Function &) {
3724           StoreCallBase(U, InitRFI, KernelInitCB);
3725           return false;
3726         },
3727         Fn);
3728     DeinitRFI.foreachUse(
3729         [&](Use &U, Function &) {
3730           StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3731           return false;
3732         },
3733         Fn);
3734 
3735     // Ignore kernels without initializers such as global constructors.
3736     if (!KernelInitCB || !KernelDeinitCB)
3737       return;
3738 
3739     // Add itself to the reaching kernel and set IsKernelEntry.
3740     ReachingKernelEntries.insert(Fn);
3741     IsKernelEntry = true;
3742 
3743     KernelEnvC =
3744         KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3745     GlobalVariable *KernelEnvGV =
3746         KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3747 
3748     Attributor::GlobalVariableSimplifictionCallbackTy
3749         KernelConfigurationSimplifyCB =
3750             [&](const GlobalVariable &GV, const AbstractAttribute *AA,
3751                 bool &UsedAssumedInformation) -> std::optional<Constant *> {
3752       if (!isAtFixpoint()) {
3753         if (!AA)
3754           return nullptr;
3755         UsedAssumedInformation = true;
3756         A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3757       }
3758       return KernelEnvC;
3759     };
3760 
3761     A.registerGlobalVariableSimplificationCallback(
3762         *KernelEnvGV, KernelConfigurationSimplifyCB);
3763 
3764     // Check if we know we are in SPMD-mode already.
3765     ConstantInt *ExecModeC =
3766         KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
3767     ConstantInt *AssumedExecModeC = ConstantInt::get(
3768         ExecModeC->getIntegerType(),
3769         ExecModeC->getSExtValue() | OMP_TGT_EXEC_MODE_GENERIC_SPMD);
3770     if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
3771       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3772     else if (DisableOpenMPOptSPMDization)
3773       // This is a generic region but SPMDization is disabled so stop
3774       // tracking.
3775       SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3776     else
3777       setExecModeOfKernelEnvironment(AssumedExecModeC);
3778 
3779     const Triple T(Fn->getParent()->getTargetTriple());
3780     auto *Int32Ty = Type::getInt32Ty(Fn->getContext());
3781     auto [MinThreads, MaxThreads] =
3782         OpenMPIRBuilder::readThreadBoundsForKernel(T, *Fn);
3783     if (MinThreads)
3784       setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads));
3785     if (MaxThreads)
3786       setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads));
3787     auto [MinTeams, MaxTeams] =
3788         OpenMPIRBuilder::readTeamBoundsForKernel(T, *Fn);
3789     if (MinTeams)
3790       setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams));
3791     if (MaxTeams)
3792       setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams));
3793 
3794     ConstantInt *MayUseNestedParallelismC =
3795         KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
3796     ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
3797         MayUseNestedParallelismC->getIntegerType(), NestedParallelism);
3798     setMayUseNestedParallelismOfKernelEnvironment(
3799         AssumedMayUseNestedParallelismC);
3800 
3801     if (!DisableOpenMPOptStateMachineRewrite) {
3802       ConstantInt *UseGenericStateMachineC =
3803           KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3804               KernelEnvC);
3805       ConstantInt *AssumedUseGenericStateMachineC =
3806           ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);
3807       setUseGenericStateMachineOfKernelEnvironment(
3808           AssumedUseGenericStateMachineC);
3809     }
3810 
3811     // Register virtual uses of functions we might need to preserve.
3812     auto RegisterVirtualUse = [&](RuntimeFunction RFKind,
3813                                   Attributor::VirtualUseCallbackTy &CB) {
3814       if (!OMPInfoCache.RFIs[RFKind].Declaration)
3815         return;
3816       A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
3817     };
3818 
3819     // Add a dependence to ensure updates if the state changes.
3820     auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
3821                             const AbstractAttribute *QueryingAA) {
3822       if (QueryingAA) {
3823         A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
3824       }
3825       return true;
3826     };
3827 
3828     Attributor::VirtualUseCallbackTy CustomStateMachineUseCB =
3829         [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3830           // Whenever we create a custom state machine we will insert calls to
3831           // __kmpc_get_hardware_num_threads_in_block,
3832           // __kmpc_get_warp_size,
3833           // __kmpc_barrier_simple_generic,
3834           // __kmpc_kernel_parallel, and
3835           // __kmpc_kernel_end_parallel.
3836           // Not needed if we are on track for SPMDzation.
3837           if (SPMDCompatibilityTracker.isValidState())
3838             return AddDependence(A, this, QueryingAA);
3839           // Not needed if we can't rewrite due to an invalid state.
3840           if (!ReachedKnownParallelRegions.isValidState())
3841             return AddDependence(A, this, QueryingAA);
3842           return false;
3843         };
3844 
3845     // Not needed if we are pre-runtime merge.
3846     if (!KernelInitCB->getCalledFunction()->isDeclaration()) {
3847       RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
3848                          CustomStateMachineUseCB);
3849       RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
3850       RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
3851                          CustomStateMachineUseCB);
3852       RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
3853                          CustomStateMachineUseCB);
3854       RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
3855                          CustomStateMachineUseCB);
3856     }
3857 
3858     // If we do not perform SPMDzation we do not need the virtual uses below.
3859     if (SPMDCompatibilityTracker.isAtFixpoint())
3860       return;
3861 
3862     Attributor::VirtualUseCallbackTy HWThreadIdUseCB =
3863         [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3864           // Whenever we perform SPMDzation we will insert
3865           // __kmpc_get_hardware_thread_id_in_block calls.
3866           if (!SPMDCompatibilityTracker.isValidState())
3867             return AddDependence(A, this, QueryingAA);
3868           return false;
3869         };
3870     RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
3871                        HWThreadIdUseCB);
3872 
3873     Attributor::VirtualUseCallbackTy SPMDBarrierUseCB =
3874         [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3875           // Whenever we perform SPMDzation with guarding we will insert
3876           // __kmpc_simple_barrier_spmd calls. If SPMDzation failed, there is
3877           // nothing to guard, or there are no parallel regions, we don't need
3878           // the calls.
3879           if (!SPMDCompatibilityTracker.isValidState())
3880             return AddDependence(A, this, QueryingAA);
3881           if (SPMDCompatibilityTracker.empty())
3882             return AddDependence(A, this, QueryingAA);
3883           if (!mayContainParallelRegion())
3884             return AddDependence(A, this, QueryingAA);
3885           return false;
3886         };
3887     RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
3888   }
3889 
3890   /// Sanitize the string \p S such that it is a suitable global symbol name.
3891   static std::string sanitizeForGlobalName(std::string S) {
3892     std::replace_if(
3893         S.begin(), S.end(),
3894         [](const char C) {
3895           return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3896                    (C >= '0' && C <= '9') || C == '_');
3897         },
3898         '.');
3899     return S;
3900   }
3901 
3902   /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3903   /// finished now.
3904   ChangeStatus manifest(Attributor &A) override {
3905     // If we are not looking at a kernel with __kmpc_target_init and
3906     // __kmpc_target_deinit call we cannot actually manifest the information.
3907     if (!KernelInitCB || !KernelDeinitCB)
3908       return ChangeStatus::UNCHANGED;
3909 
3910     ChangeStatus Changed = ChangeStatus::UNCHANGED;
3911 
3912     bool HasBuiltStateMachine = true;
3913     if (!changeToSPMDMode(A, Changed)) {
3914       if (!KernelInitCB->getCalledFunction()->isDeclaration())
3915         HasBuiltStateMachine = buildCustomStateMachine(A, Changed);
3916       else
3917         HasBuiltStateMachine = false;
3918     }
3919 
3920     // We need to reset KernelEnvC if specific rewriting is not done.
3921     ConstantStruct *ExistingKernelEnvC =
3922         KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3923     ConstantInt *OldUseGenericStateMachineVal =
3924         KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3925             ExistingKernelEnvC);
3926     if (!HasBuiltStateMachine)
3927       setUseGenericStateMachineOfKernelEnvironment(
3928           OldUseGenericStateMachineVal);
3929 
3930     // At last, update the KernelEnvc
3931     GlobalVariable *KernelEnvGV =
3932         KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3933     if (KernelEnvGV->getInitializer() != KernelEnvC) {
3934       KernelEnvGV->setInitializer(KernelEnvC);
3935       Changed = ChangeStatus::CHANGED;
3936     }
3937 
3938     return Changed;
3939   }
3940 
3941   void insertInstructionGuardsHelper(Attributor &A) {
3942     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3943 
3944     auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3945                                    Instruction *RegionEndI) {
3946       LoopInfo *LI = nullptr;
3947       DominatorTree *DT = nullptr;
3948       MemorySSAUpdater *MSU = nullptr;
3949       using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3950 
3951       BasicBlock *ParentBB = RegionStartI->getParent();
3952       Function *Fn = ParentBB->getParent();
3953       Module &M = *Fn->getParent();
3954 
3955       // Create all the blocks and logic.
3956       // ParentBB:
3957       //    goto RegionCheckTidBB
3958       // RegionCheckTidBB:
3959       //    Tid = __kmpc_hardware_thread_id()
3960       //    if (Tid != 0)
3961       //        goto RegionBarrierBB
3962       // RegionStartBB:
3963       //    <execute instructions guarded>
3964       //    goto RegionEndBB
3965       // RegionEndBB:
3966       //    <store escaping values to shared mem>
3967       //    goto RegionBarrierBB
3968       //  RegionBarrierBB:
3969       //    __kmpc_simple_barrier_spmd()
3970       //    // second barrier is omitted if lacking escaping values.
3971       //    <load escaping values from shared mem>
3972       //    __kmpc_simple_barrier_spmd()
3973       //    goto RegionExitBB
3974       // RegionExitBB:
3975       //    <execute rest of instructions>
3976 
3977       BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3978                                            DT, LI, MSU, "region.guarded.end");
3979       BasicBlock *RegionBarrierBB =
3980           SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3981                      MSU, "region.barrier");
3982       BasicBlock *RegionExitBB =
3983           SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3984                      DT, LI, MSU, "region.exit");
3985       BasicBlock *RegionStartBB =
3986           SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3987 
3988       assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3989              "Expected a different CFG");
3990 
3991       BasicBlock *RegionCheckTidBB = SplitBlock(
3992           ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3993 
3994       // Register basic blocks with the Attributor.
3995       A.registerManifestAddedBasicBlock(*RegionEndBB);
3996       A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3997       A.registerManifestAddedBasicBlock(*RegionExitBB);
3998       A.registerManifestAddedBasicBlock(*RegionStartBB);
3999       A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
4000 
4001       bool HasBroadcastValues = false;
4002       // Find escaping outputs from the guarded region to outside users and
4003       // broadcast their values to them.
4004       for (Instruction &I : *RegionStartBB) {
4005         SmallVector<Use *, 4> OutsideUses;
4006         for (Use &U : I.uses()) {
4007           Instruction &UsrI = *cast<Instruction>(U.getUser());
4008           if (UsrI.getParent() != RegionStartBB)
4009             OutsideUses.push_back(&U);
4010         }
4011 
4012         if (OutsideUses.empty())
4013           continue;
4014 
4015         HasBroadcastValues = true;
4016 
4017         // Emit a global variable in shared memory to store the broadcasted
4018         // value.
4019         auto *SharedMem = new GlobalVariable(
4020             M, I.getType(), /* IsConstant */ false,
4021             GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
4022             sanitizeForGlobalName(
4023                 (I.getName() + ".guarded.output.alloc").str()),
4024             nullptr, GlobalValue::NotThreadLocal,
4025             static_cast<unsigned>(AddressSpace::Shared));
4026 
4027         // Emit a store instruction to update the value.
4028         new StoreInst(&I, SharedMem,
4029                       RegionEndBB->getTerminator()->getIterator());
4030 
4031         LoadInst *LoadI = new LoadInst(
4032             I.getType(), SharedMem, I.getName() + ".guarded.output.load",
4033             RegionBarrierBB->getTerminator()->getIterator());
4034 
4035         // Emit a load instruction and replace uses of the output value.
4036         for (Use *U : OutsideUses)
4037           A.changeUseAfterManifest(*U, *LoadI);
4038       }
4039 
4040       auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4041 
4042       // Go to tid check BB in ParentBB.
4043       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
4044       ParentBB->getTerminator()->eraseFromParent();
4045       OpenMPIRBuilder::LocationDescription Loc(
4046           InsertPointTy(ParentBB, ParentBB->end()), DL);
4047       OMPInfoCache.OMPBuilder.updateToLocation(Loc);
4048       uint32_t SrcLocStrSize;
4049       auto *SrcLocStr =
4050           OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4051       Value *Ident =
4052           OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4053       BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
4054 
4055       // Add check for Tid in RegionCheckTidBB
4056       RegionCheckTidBB->getTerminator()->eraseFromParent();
4057       OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
4058           InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
4059       OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
4060       FunctionCallee HardwareTidFn =
4061           OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4062               M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4063       CallInst *Tid =
4064           OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
4065       Tid->setDebugLoc(DL);
4066       OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
4067       Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
4068       OMPInfoCache.OMPBuilder.Builder
4069           .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
4070           ->setDebugLoc(DL);
4071 
4072       // First barrier for synchronization, ensures main thread has updated
4073       // values.
4074       FunctionCallee BarrierFn =
4075           OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4076               M, OMPRTL___kmpc_barrier_simple_spmd);
4077       OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
4078           RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
4079       CallInst *Barrier =
4080           OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
4081       Barrier->setDebugLoc(DL);
4082       OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4083 
4084       // Second barrier ensures workers have read broadcast values.
4085       if (HasBroadcastValues) {
4086         CallInst *Barrier =
4087             CallInst::Create(BarrierFn, {Ident, Tid}, "",
4088                              RegionBarrierBB->getTerminator()->getIterator());
4089         Barrier->setDebugLoc(DL);
4090         OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4091       }
4092     };
4093 
4094     auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4095     SmallPtrSet<BasicBlock *, 8> Visited;
4096     for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4097       BasicBlock *BB = GuardedI->getParent();
4098       if (!Visited.insert(BB).second)
4099         continue;
4100 
4101       SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
4102       Instruction *LastEffect = nullptr;
4103       BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
4104       while (++IP != IPEnd) {
4105         if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
4106           continue;
4107         Instruction *I = &*IP;
4108         if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
4109           continue;
4110         if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
4111           LastEffect = nullptr;
4112           continue;
4113         }
4114         if (LastEffect)
4115           Reorders.push_back({I, LastEffect});
4116         LastEffect = &*IP;
4117       }
4118       for (auto &Reorder : Reorders)
4119         Reorder.first->moveBefore(Reorder.second);
4120     }
4121 
4122     SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
4123 
4124     for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4125       BasicBlock *BB = GuardedI->getParent();
4126       auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
4127           IRPosition::function(*GuardedI->getFunction()), nullptr,
4128           DepClassTy::NONE);
4129       assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
4130       auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
4131       // Continue if instruction is already guarded.
4132       if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
4133         continue;
4134 
4135       Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
4136       for (Instruction &I : *BB) {
4137         // If instruction I needs to be guarded update the guarded region
4138         // bounds.
4139         if (SPMDCompatibilityTracker.contains(&I)) {
4140           CalleeAAFunction.getGuardedInstructions().insert(&I);
4141           if (GuardedRegionStart)
4142             GuardedRegionEnd = &I;
4143           else
4144             GuardedRegionStart = GuardedRegionEnd = &I;
4145 
4146           continue;
4147         }
4148 
4149         // Instruction I does not need guarding, store
4150         // any region found and reset bounds.
4151         if (GuardedRegionStart) {
4152           GuardedRegions.push_back(
4153               std::make_pair(GuardedRegionStart, GuardedRegionEnd));
4154           GuardedRegionStart = nullptr;
4155           GuardedRegionEnd = nullptr;
4156         }
4157       }
4158     }
4159 
4160     for (auto &GR : GuardedRegions)
4161       CreateGuardedRegion(GR.first, GR.second);
4162   }
4163 
4164   void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
4165     // Only allow 1 thread per workgroup to continue executing the user code.
4166     //
4167     //     InitCB = __kmpc_target_init(...)
4168     //     ThreadIdInBlock = __kmpc_get_hardware_thread_id_in_block();
4169     //     if (ThreadIdInBlock != 0) return;
4170     // UserCode:
4171     //     // user code
4172     //
4173     auto &Ctx = getAnchorValue().getContext();
4174     Function *Kernel = getAssociatedFunction();
4175     assert(Kernel && "Expected an associated function!");
4176 
4177     // Create block for user code to branch to from initial block.
4178     BasicBlock *InitBB = KernelInitCB->getParent();
4179     BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
4180         KernelInitCB->getNextNode(), "main.thread.user_code");
4181     BasicBlock *ReturnBB =
4182         BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB);
4183 
4184     // Register blocks with attributor:
4185     A.registerManifestAddedBasicBlock(*InitBB);
4186     A.registerManifestAddedBasicBlock(*UserCodeBB);
4187     A.registerManifestAddedBasicBlock(*ReturnBB);
4188 
4189     // Debug location:
4190     const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4191     ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc);
4192     InitBB->getTerminator()->eraseFromParent();
4193 
4194     // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block.
4195     Module &M = *Kernel->getParent();
4196     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4197     FunctionCallee ThreadIdInBlockFn =
4198         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4199             M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4200 
4201     // Get thread ID in block.
4202     CallInst *ThreadIdInBlock =
4203         CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB);
4204     OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
4205     ThreadIdInBlock->setDebugLoc(DLoc);
4206 
4207     // Eliminate all threads in the block with ID not equal to 0:
4208     Instruction *IsMainThread =
4209         ICmpInst::Create(ICmpInst::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock,
4210                          ConstantInt::get(ThreadIdInBlock->getType(), 0),
4211                          "thread.is_main", InitBB);
4212     IsMainThread->setDebugLoc(DLoc);
4213     BranchInst::Create(ReturnBB, UserCodeBB, IsMainThread, InitBB);
4214   }
4215 
4216   bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
4217     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4218 
4219     // We cannot change to SPMD mode if the runtime functions aren't availible.
4220     if (!OMPInfoCache.runtimeFnsAvailable(
4221             {OMPRTL___kmpc_get_hardware_thread_id_in_block,
4222              OMPRTL___kmpc_barrier_simple_spmd}))
4223       return false;
4224 
4225     if (!SPMDCompatibilityTracker.isAssumed()) {
4226       for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
4227         if (!NonCompatibleI)
4228           continue;
4229 
4230         // Skip diagnostics on calls to known OpenMP runtime functions for now.
4231         if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
4232           if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
4233             continue;
4234 
4235         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4236           ORA << "Value has potential side effects preventing SPMD-mode "
4237                  "execution";
4238           if (isa<CallBase>(NonCompatibleI)) {
4239             ORA << ". Add `[[omp::assume(\"ompx_spmd_amenable\")]]` to "
4240                    "the called function to override";
4241           }
4242           return ORA << ".";
4243         };
4244         A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
4245                                                  Remark);
4246 
4247         LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
4248                           << *NonCompatibleI << "\n");
4249       }
4250 
4251       return false;
4252     }
4253 
4254     // Get the actual kernel, could be the caller of the anchor scope if we have
4255     // a debug wrapper.
4256     Function *Kernel = getAnchorScope();
4257     if (Kernel->hasLocalLinkage()) {
4258       assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
4259       auto *CB = cast<CallBase>(Kernel->user_back());
4260       Kernel = CB->getCaller();
4261     }
4262     assert(omp::isOpenMPKernel(*Kernel) && "Expected kernel function!");
4263 
4264     // Check if the kernel is already in SPMD mode, if so, return success.
4265     ConstantStruct *ExistingKernelEnvC =
4266         KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4267     auto *ExecModeC =
4268         KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4269     const int8_t ExecModeVal = ExecModeC->getSExtValue();
4270     if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
4271       return true;
4272 
4273     // We will now unconditionally modify the IR, indicate a change.
4274     Changed = ChangeStatus::CHANGED;
4275 
4276     // Do not use instruction guards when no parallel is present inside
4277     // the target region.
4278     if (mayContainParallelRegion())
4279       insertInstructionGuardsHelper(A);
4280     else
4281       forceSingleThreadPerWorkgroupHelper(A);
4282 
4283     // Adjust the global exec mode flag that tells the runtime what mode this
4284     // kernel is executed in.
4285     assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
4286            "Initially non-SPMD kernel has SPMD exec mode!");
4287     setExecModeOfKernelEnvironment(
4288         ConstantInt::get(ExecModeC->getIntegerType(),
4289                          ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
4290 
4291     ++NumOpenMPTargetRegionKernelsSPMD;
4292 
4293     auto Remark = [&](OptimizationRemark OR) {
4294       return OR << "Transformed generic-mode kernel to SPMD-mode.";
4295     };
4296     A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
4297     return true;
4298   };
4299 
4300   bool buildCustomStateMachine(Attributor &A, ChangeStatus &Changed) {
4301     // If we have disabled state machine rewrites, don't make a custom one
4302     if (DisableOpenMPOptStateMachineRewrite)
4303       return false;
4304 
4305     // Don't rewrite the state machine if we are not in a valid state.
4306     if (!ReachedKnownParallelRegions.isValidState())
4307       return false;
4308 
4309     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4310     if (!OMPInfoCache.runtimeFnsAvailable(
4311             {OMPRTL___kmpc_get_hardware_num_threads_in_block,
4312              OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
4313              OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
4314       return false;
4315 
4316     ConstantStruct *ExistingKernelEnvC =
4317         KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4318 
4319     // Check if the current configuration is non-SPMD and generic state machine.
4320     // If we already have SPMD mode or a custom state machine we do not need to
4321     // go any further. If it is anything but a constant something is weird and
4322     // we give up.
4323     ConstantInt *UseStateMachineC =
4324         KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4325             ExistingKernelEnvC);
4326     ConstantInt *ModeC =
4327         KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4328 
4329     // If we are stuck with generic mode, try to create a custom device (=GPU)
4330     // state machine which is specialized for the parallel regions that are
4331     // reachable by the kernel.
4332     if (UseStateMachineC->isZero() ||
4333         (ModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
4334       return false;
4335 
4336     Changed = ChangeStatus::CHANGED;
4337 
4338     // If not SPMD mode, indicate we use a custom state machine now.
4339     setUseGenericStateMachineOfKernelEnvironment(
4340         ConstantInt::get(UseStateMachineC->getIntegerType(), false));
4341 
4342     // If we don't actually need a state machine we are done here. This can
4343     // happen if there simply are no parallel regions. In the resulting kernel
4344     // all worker threads will simply exit right away, leaving the main thread
4345     // to do the work alone.
4346     if (!mayContainParallelRegion()) {
4347       ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
4348 
4349       auto Remark = [&](OptimizationRemark OR) {
4350         return OR << "Removing unused state machine from generic-mode kernel.";
4351       };
4352       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
4353 
4354       return true;
4355     }
4356 
4357     // Keep track in the statistics of our new shiny custom state machine.
4358     if (ReachedUnknownParallelRegions.empty()) {
4359       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
4360 
4361       auto Remark = [&](OptimizationRemark OR) {
4362         return OR << "Rewriting generic-mode kernel with a customized state "
4363                      "machine.";
4364       };
4365       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
4366     } else {
4367       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
4368 
4369       auto Remark = [&](OptimizationRemarkAnalysis OR) {
4370         return OR << "Generic-mode kernel is executed with a customized state "
4371                      "machine that requires a fallback.";
4372       };
4373       A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
4374 
4375       // Tell the user why we ended up with a fallback.
4376       for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
4377         if (!UnknownParallelRegionCB)
4378           continue;
4379         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4380           return ORA << "Call may contain unknown parallel regions. Use "
4381                      << "`[[omp::assume(\"omp_no_parallelism\")]]` to "
4382                         "override.";
4383         };
4384         A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
4385                                                  "OMP133", Remark);
4386       }
4387     }
4388 
4389     // Create all the blocks:
4390     //
4391     //                       InitCB = __kmpc_target_init(...)
4392     //                       BlockHwSize =
4393     //                         __kmpc_get_hardware_num_threads_in_block();
4394     //                       WarpSize = __kmpc_get_warp_size();
4395     //                       BlockSize = BlockHwSize - WarpSize;
4396     // IsWorkerCheckBB:      bool IsWorker = InitCB != -1;
4397     //                       if (IsWorker) {
4398     //                         if (InitCB >= BlockSize) return;
4399     // SMBeginBB:               __kmpc_barrier_simple_generic(...);
4400     //                         void *WorkFn;
4401     //                         bool Active = __kmpc_kernel_parallel(&WorkFn);
4402     //                         if (!WorkFn) return;
4403     // SMIsActiveCheckBB:       if (Active) {
4404     // SMIfCascadeCurrentBB:      if      (WorkFn == <ParFn0>)
4405     //                              ParFn0(...);
4406     // SMIfCascadeCurrentBB:      else if (WorkFn == <ParFn1>)
4407     //                              ParFn1(...);
4408     //                            ...
4409     // SMIfCascadeCurrentBB:      else
4410     //                              ((WorkFnTy*)WorkFn)(...);
4411     // SMEndParallelBB:           __kmpc_kernel_end_parallel(...);
4412     //                          }
4413     // SMDoneBB:                __kmpc_barrier_simple_generic(...);
4414     //                          goto SMBeginBB;
4415     //                       }
4416     // UserCodeEntryBB:      // user code
4417     //                       __kmpc_target_deinit(...)
4418     //
4419     auto &Ctx = getAnchorValue().getContext();
4420     Function *Kernel = getAssociatedFunction();
4421     assert(Kernel && "Expected an associated function!");
4422 
4423     BasicBlock *InitBB = KernelInitCB->getParent();
4424     BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
4425         KernelInitCB->getNextNode(), "thread.user_code.check");
4426     BasicBlock *IsWorkerCheckBB =
4427         BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
4428     BasicBlock *StateMachineBeginBB = BasicBlock::Create(
4429         Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
4430     BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
4431         Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
4432     BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
4433         Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
4434     BasicBlock *StateMachineIfCascadeCurrentBB =
4435         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4436                            Kernel, UserCodeEntryBB);
4437     BasicBlock *StateMachineEndParallelBB =
4438         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
4439                            Kernel, UserCodeEntryBB);
4440     BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
4441         Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
4442     A.registerManifestAddedBasicBlock(*InitBB);
4443     A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
4444     A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
4445     A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
4446     A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
4447     A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
4448     A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
4449     A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
4450     A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
4451 
4452     const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4453     ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
4454     InitBB->getTerminator()->eraseFromParent();
4455 
4456     Instruction *IsWorker =
4457         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
4458                          ConstantInt::get(KernelInitCB->getType(), -1),
4459                          "thread.is_worker", InitBB);
4460     IsWorker->setDebugLoc(DLoc);
4461     BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
4462 
4463     Module &M = *Kernel->getParent();
4464     FunctionCallee BlockHwSizeFn =
4465         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4466             M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
4467     FunctionCallee WarpSizeFn =
4468         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4469             M, OMPRTL___kmpc_get_warp_size);
4470     CallInst *BlockHwSize =
4471         CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
4472     OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
4473     BlockHwSize->setDebugLoc(DLoc);
4474     CallInst *WarpSize =
4475         CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
4476     OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
4477     WarpSize->setDebugLoc(DLoc);
4478     Instruction *BlockSize = BinaryOperator::CreateSub(
4479         BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
4480     BlockSize->setDebugLoc(DLoc);
4481     Instruction *IsMainOrWorker = ICmpInst::Create(
4482         ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
4483         "thread.is_main_or_worker", IsWorkerCheckBB);
4484     IsMainOrWorker->setDebugLoc(DLoc);
4485     BranchInst::Create(StateMachineBeginBB, StateMachineFinishedBB,
4486                        IsMainOrWorker, IsWorkerCheckBB);
4487 
4488     // Create local storage for the work function pointer.
4489     const DataLayout &DL = M.getDataLayout();
4490     Type *VoidPtrTy = PointerType::getUnqual(Ctx);
4491     Instruction *WorkFnAI =
4492         new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
4493                        "worker.work_fn.addr", Kernel->getEntryBlock().begin());
4494     WorkFnAI->setDebugLoc(DLoc);
4495 
4496     OMPInfoCache.OMPBuilder.updateToLocation(
4497         OpenMPIRBuilder::LocationDescription(
4498             IRBuilder<>::InsertPoint(StateMachineBeginBB,
4499                                      StateMachineBeginBB->end()),
4500             DLoc));
4501 
4502     Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC);
4503     Value *GTid = KernelInitCB;
4504 
4505     FunctionCallee BarrierFn =
4506         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4507             M, OMPRTL___kmpc_barrier_simple_generic);
4508     CallInst *Barrier =
4509         CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
4510     OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4511     Barrier->setDebugLoc(DLoc);
4512 
4513     if (WorkFnAI->getType()->getPointerAddressSpace() !=
4514         (unsigned int)AddressSpace::Generic) {
4515       WorkFnAI = new AddrSpaceCastInst(
4516           WorkFnAI, PointerType::get(Ctx, (unsigned int)AddressSpace::Generic),
4517           WorkFnAI->getName() + ".generic", StateMachineBeginBB);
4518       WorkFnAI->setDebugLoc(DLoc);
4519     }
4520 
4521     FunctionCallee KernelParallelFn =
4522         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4523             M, OMPRTL___kmpc_kernel_parallel);
4524     CallInst *IsActiveWorker = CallInst::Create(
4525         KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
4526     OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
4527     IsActiveWorker->setDebugLoc(DLoc);
4528     Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
4529                                        StateMachineBeginBB);
4530     WorkFn->setDebugLoc(DLoc);
4531 
4532     FunctionType *ParallelRegionFnTy = FunctionType::get(
4533         Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
4534         false);
4535 
4536     Instruction *IsDone =
4537         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
4538                          Constant::getNullValue(VoidPtrTy), "worker.is_done",
4539                          StateMachineBeginBB);
4540     IsDone->setDebugLoc(DLoc);
4541     BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
4542                        IsDone, StateMachineBeginBB)
4543         ->setDebugLoc(DLoc);
4544 
4545     BranchInst::Create(StateMachineIfCascadeCurrentBB,
4546                        StateMachineDoneBarrierBB, IsActiveWorker,
4547                        StateMachineIsActiveCheckBB)
4548         ->setDebugLoc(DLoc);
4549 
4550     Value *ZeroArg =
4551         Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
4552 
4553     const unsigned int WrapperFunctionArgNo = 6;
4554 
4555     // Now that we have most of the CFG skeleton it is time for the if-cascade
4556     // that checks the function pointer we got from the runtime against the
4557     // parallel regions we expect, if there are any.
4558     for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
4559       auto *CB = ReachedKnownParallelRegions[I];
4560       auto *ParallelRegion = dyn_cast<Function>(
4561           CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());
4562       BasicBlock *PRExecuteBB = BasicBlock::Create(
4563           Ctx, "worker_state_machine.parallel_region.execute", Kernel,
4564           StateMachineEndParallelBB);
4565       CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
4566           ->setDebugLoc(DLoc);
4567       BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
4568           ->setDebugLoc(DLoc);
4569 
4570       BasicBlock *PRNextBB =
4571           BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4572                              Kernel, StateMachineEndParallelBB);
4573       A.registerManifestAddedBasicBlock(*PRExecuteBB);
4574       A.registerManifestAddedBasicBlock(*PRNextBB);
4575 
4576       // Check if we need to compare the pointer at all or if we can just
4577       // call the parallel region function.
4578       Value *IsPR;
4579       if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
4580         Instruction *CmpI = ICmpInst::Create(
4581             ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, ParallelRegion,
4582             "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
4583         CmpI->setDebugLoc(DLoc);
4584         IsPR = CmpI;
4585       } else {
4586         IsPR = ConstantInt::getTrue(Ctx);
4587       }
4588 
4589       BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
4590                          StateMachineIfCascadeCurrentBB)
4591           ->setDebugLoc(DLoc);
4592       StateMachineIfCascadeCurrentBB = PRNextBB;
4593     }
4594 
4595     // At the end of the if-cascade we place the indirect function pointer call
4596     // in case we might need it, that is if there can be parallel regions we
4597     // have not handled in the if-cascade above.
4598     if (!ReachedUnknownParallelRegions.empty()) {
4599       StateMachineIfCascadeCurrentBB->setName(
4600           "worker_state_machine.parallel_region.fallback.execute");
4601       CallInst::Create(ParallelRegionFnTy, WorkFn, {ZeroArg, GTid}, "",
4602                        StateMachineIfCascadeCurrentBB)
4603           ->setDebugLoc(DLoc);
4604     }
4605     BranchInst::Create(StateMachineEndParallelBB,
4606                        StateMachineIfCascadeCurrentBB)
4607         ->setDebugLoc(DLoc);
4608 
4609     FunctionCallee EndParallelFn =
4610         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4611             M, OMPRTL___kmpc_kernel_end_parallel);
4612     CallInst *EndParallel =
4613         CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
4614     OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4615     EndParallel->setDebugLoc(DLoc);
4616     BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
4617         ->setDebugLoc(DLoc);
4618 
4619     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
4620         ->setDebugLoc(DLoc);
4621     BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
4622         ->setDebugLoc(DLoc);
4623 
4624     return true;
4625   }
4626 
4627   /// Fixpoint iteration update function. Will be called every time a dependence
4628   /// changed its state (and in the beginning).
4629   ChangeStatus updateImpl(Attributor &A) override {
4630     KernelInfoState StateBefore = getState();
4631 
4632     // When we leave this function this RAII will make sure the member
4633     // KernelEnvC is updated properly depending on the state. That member is
4634     // used for simplification of values and needs to be up to date at all
4635     // times.
4636     struct UpdateKernelEnvCRAII {
4637       AAKernelInfoFunction &AA;
4638 
4639       UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {}
4640 
4641       ~UpdateKernelEnvCRAII() {
4642         if (!AA.KernelEnvC)
4643           return;
4644 
4645         ConstantStruct *ExistingKernelEnvC =
4646             KernelInfo::getKernelEnvironementFromKernelInitCB(AA.KernelInitCB);
4647 
4648         if (!AA.isValidState()) {
4649           AA.KernelEnvC = ExistingKernelEnvC;
4650           return;
4651         }
4652 
4653         if (!AA.ReachedKnownParallelRegions.isValidState())
4654           AA.setUseGenericStateMachineOfKernelEnvironment(
4655               KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4656                   ExistingKernelEnvC));
4657 
4658         if (!AA.SPMDCompatibilityTracker.isValidState())
4659           AA.setExecModeOfKernelEnvironment(
4660               KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC));
4661 
4662         ConstantInt *MayUseNestedParallelismC =
4663             KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
4664                 AA.KernelEnvC);
4665         ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
4666             MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);
4667         AA.setMayUseNestedParallelismOfKernelEnvironment(
4668             NewMayUseNestedParallelismC);
4669       }
4670     } RAII(*this);
4671 
4672     // Callback to check a read/write instruction.
4673     auto CheckRWInst = [&](Instruction &I) {
4674       // We handle calls later.
4675       if (isa<CallBase>(I))
4676         return true;
4677       // We only care about write effects.
4678       if (!I.mayWriteToMemory())
4679         return true;
4680       if (auto *SI = dyn_cast<StoreInst>(&I)) {
4681         const auto *UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
4682             *this, IRPosition::value(*SI->getPointerOperand()),
4683             DepClassTy::OPTIONAL);
4684         auto *HS = A.getAAFor<AAHeapToStack>(
4685             *this, IRPosition::function(*I.getFunction()),
4686             DepClassTy::OPTIONAL);
4687         if (UnderlyingObjsAA &&
4688             UnderlyingObjsAA->forallUnderlyingObjects([&](Value &Obj) {
4689               if (AA::isAssumedThreadLocalObject(A, Obj, *this))
4690                 return true;
4691               // Check for AAHeapToStack moved objects which must not be
4692               // guarded.
4693               auto *CB = dyn_cast<CallBase>(&Obj);
4694               return CB && HS && HS->isAssumedHeapToStack(*CB);
4695             }))
4696           return true;
4697       }
4698 
4699       // Insert instruction that needs guarding.
4700       SPMDCompatibilityTracker.insert(&I);
4701       return true;
4702     };
4703 
4704     bool UsedAssumedInformationInCheckRWInst = false;
4705     if (!SPMDCompatibilityTracker.isAtFixpoint())
4706       if (!A.checkForAllReadWriteInstructions(
4707               CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4708         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4709 
4710     bool UsedAssumedInformationFromReachingKernels = false;
4711     if (!IsKernelEntry) {
4712       updateParallelLevels(A);
4713 
4714       bool AllReachingKernelsKnown = true;
4715       updateReachingKernelEntries(A, AllReachingKernelsKnown);
4716       UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4717 
4718       if (!SPMDCompatibilityTracker.empty()) {
4719         if (!ParallelLevels.isValidState())
4720           SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4721         else if (!ReachingKernelEntries.isValidState())
4722           SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4723         else {
4724           // Check if all reaching kernels agree on the mode as we can otherwise
4725           // not guard instructions. We might not be sure about the mode so we
4726           // we cannot fix the internal spmd-zation state either.
4727           int SPMD = 0, Generic = 0;
4728           for (auto *Kernel : ReachingKernelEntries) {
4729             auto *CBAA = A.getAAFor<AAKernelInfo>(
4730                 *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
4731             if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
4732                 CBAA->SPMDCompatibilityTracker.isAssumed())
4733               ++SPMD;
4734             else
4735               ++Generic;
4736             if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint())
4737               UsedAssumedInformationFromReachingKernels = true;
4738           }
4739           if (SPMD != 0 && Generic != 0)
4740             SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4741         }
4742       }
4743     }
4744 
4745     // Callback to check a call instruction.
4746     bool AllParallelRegionStatesWereFixed = true;
4747     bool AllSPMDStatesWereFixed = true;
4748     auto CheckCallInst = [&](Instruction &I) {
4749       auto &CB = cast<CallBase>(I);
4750       auto *CBAA = A.getAAFor<AAKernelInfo>(
4751           *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4752       if (!CBAA)
4753         return false;
4754       getState() ^= CBAA->getState();
4755       AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint();
4756       AllParallelRegionStatesWereFixed &=
4757           CBAA->ReachedKnownParallelRegions.isAtFixpoint();
4758       AllParallelRegionStatesWereFixed &=
4759           CBAA->ReachedUnknownParallelRegions.isAtFixpoint();
4760       return true;
4761     };
4762 
4763     bool UsedAssumedInformationInCheckCallInst = false;
4764     if (!A.checkForAllCallLikeInstructions(
4765             CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4766       LLVM_DEBUG(dbgs() << TAG
4767                         << "Failed to visit all call-like instructions!\n";);
4768       return indicatePessimisticFixpoint();
4769     }
4770 
4771     // If we haven't used any assumed information for the reached parallel
4772     // region states we can fix it.
4773     if (!UsedAssumedInformationInCheckCallInst &&
4774         AllParallelRegionStatesWereFixed) {
4775       ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4776       ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4777     }
4778 
4779     // If we haven't used any assumed information for the SPMD state we can fix
4780     // it.
4781     if (!UsedAssumedInformationInCheckRWInst &&
4782         !UsedAssumedInformationInCheckCallInst &&
4783         !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4784       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4785 
4786     return StateBefore == getState() ? ChangeStatus::UNCHANGED
4787                                      : ChangeStatus::CHANGED;
4788   }
4789 
4790 private:
4791   /// Update info regarding reaching kernels.
4792   void updateReachingKernelEntries(Attributor &A,
4793                                    bool &AllReachingKernelsKnown) {
4794     auto PredCallSite = [&](AbstractCallSite ACS) {
4795       Function *Caller = ACS.getInstruction()->getFunction();
4796 
4797       assert(Caller && "Caller is nullptr");
4798 
4799       auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
4800           IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
4801       if (CAA && CAA->ReachingKernelEntries.isValidState()) {
4802         ReachingKernelEntries ^= CAA->ReachingKernelEntries;
4803         return true;
4804       }
4805 
4806       // We lost track of the caller of the associated function, any kernel
4807       // could reach now.
4808       ReachingKernelEntries.indicatePessimisticFixpoint();
4809 
4810       return true;
4811     };
4812 
4813     if (!A.checkForAllCallSites(PredCallSite, *this,
4814                                 true /* RequireAllCallSites */,
4815                                 AllReachingKernelsKnown))
4816       ReachingKernelEntries.indicatePessimisticFixpoint();
4817   }
4818 
4819   /// Update info regarding parallel levels.
4820   void updateParallelLevels(Attributor &A) {
4821     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4822     OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
4823         OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
4824 
4825     auto PredCallSite = [&](AbstractCallSite ACS) {
4826       Function *Caller = ACS.getInstruction()->getFunction();
4827 
4828       assert(Caller && "Caller is nullptr");
4829 
4830       auto *CAA =
4831           A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
4832       if (CAA && CAA->ParallelLevels.isValidState()) {
4833         // Any function that is called by `__kmpc_parallel_51` will not be
4834         // folded as the parallel level in the function is updated. In order to
4835         // get it right, all the analysis would depend on the implentation. That
4836         // said, if in the future any change to the implementation, the analysis
4837         // could be wrong. As a consequence, we are just conservative here.
4838         if (Caller == Parallel51RFI.Declaration) {
4839           ParallelLevels.indicatePessimisticFixpoint();
4840           return true;
4841         }
4842 
4843         ParallelLevels ^= CAA->ParallelLevels;
4844 
4845         return true;
4846       }
4847 
4848       // We lost track of the caller of the associated function, any kernel
4849       // could reach now.
4850       ParallelLevels.indicatePessimisticFixpoint();
4851 
4852       return true;
4853     };
4854 
4855     bool AllCallSitesKnown = true;
4856     if (!A.checkForAllCallSites(PredCallSite, *this,
4857                                 true /* RequireAllCallSites */,
4858                                 AllCallSitesKnown))
4859       ParallelLevels.indicatePessimisticFixpoint();
4860   }
4861 };
4862 
4863 /// The call site kernel info abstract attribute, basically, what can we say
4864 /// about a call site with regards to the KernelInfoState. For now this simply
4865 /// forwards the information from the callee.
4866 struct AAKernelInfoCallSite : AAKernelInfo {
4867   AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
4868       : AAKernelInfo(IRP, A) {}
4869 
4870   /// See AbstractAttribute::initialize(...).
4871   void initialize(Attributor &A) override {
4872     AAKernelInfo::initialize(A);
4873 
4874     CallBase &CB = cast<CallBase>(getAssociatedValue());
4875     auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
4876         *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4877 
4878     // Check for SPMD-mode assumptions.
4879     if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) {
4880       indicateOptimisticFixpoint();
4881       return;
4882     }
4883 
4884     // First weed out calls we do not care about, that is readonly/readnone
4885     // calls, intrinsics, and "no_openmp" calls. Neither of these can reach a
4886     // parallel region or anything else we are looking for.
4887     if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
4888       indicateOptimisticFixpoint();
4889       return;
4890     }
4891 
4892     // Next we check if we know the callee. If it is a known OpenMP function
4893     // we will handle them explicitly in the switch below. If it is not, we
4894     // will use an AAKernelInfo object on the callee to gather information and
4895     // merge that into the current state. The latter happens in the updateImpl.
4896     auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
4897       auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4898       const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4899       if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4900         // Unknown caller or declarations are not analyzable, we give up.
4901         if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
4902 
4903           // Unknown callees might contain parallel regions, except if they have
4904           // an appropriate assumption attached.
4905           if (!AssumptionAA ||
4906               !(AssumptionAA->hasAssumption("omp_no_openmp") ||
4907                 AssumptionAA->hasAssumption("omp_no_parallelism")))
4908             ReachedUnknownParallelRegions.insert(&CB);
4909 
4910           // If SPMDCompatibilityTracker is not fixed, we need to give up on the
4911           // idea we can run something unknown in SPMD-mode.
4912           if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4913             SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4914             SPMDCompatibilityTracker.insert(&CB);
4915           }
4916 
4917           // We have updated the state for this unknown call properly, there
4918           // won't be any change so we indicate a fixpoint.
4919           indicateOptimisticFixpoint();
4920         }
4921         // If the callee is known and can be used in IPO, we will update the
4922         // state based on the callee state in updateImpl.
4923         return;
4924       }
4925       if (NumCallees > 1) {
4926         indicatePessimisticFixpoint();
4927         return;
4928       }
4929 
4930       RuntimeFunction RF = It->getSecond();
4931       switch (RF) {
4932       // All the functions we know are compatible with SPMD mode.
4933       case OMPRTL___kmpc_is_spmd_exec_mode:
4934       case OMPRTL___kmpc_distribute_static_fini:
4935       case OMPRTL___kmpc_for_static_fini:
4936       case OMPRTL___kmpc_global_thread_num:
4937       case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4938       case OMPRTL___kmpc_get_hardware_num_blocks:
4939       case OMPRTL___kmpc_single:
4940       case OMPRTL___kmpc_end_single:
4941       case OMPRTL___kmpc_master:
4942       case OMPRTL___kmpc_end_master:
4943       case OMPRTL___kmpc_barrier:
4944       case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4945       case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4946       case OMPRTL___kmpc_error:
4947       case OMPRTL___kmpc_flush:
4948       case OMPRTL___kmpc_get_hardware_thread_id_in_block:
4949       case OMPRTL___kmpc_get_warp_size:
4950       case OMPRTL_omp_get_thread_num:
4951       case OMPRTL_omp_get_num_threads:
4952       case OMPRTL_omp_get_max_threads:
4953       case OMPRTL_omp_in_parallel:
4954       case OMPRTL_omp_get_dynamic:
4955       case OMPRTL_omp_get_cancellation:
4956       case OMPRTL_omp_get_nested:
4957       case OMPRTL_omp_get_schedule:
4958       case OMPRTL_omp_get_thread_limit:
4959       case OMPRTL_omp_get_supported_active_levels:
4960       case OMPRTL_omp_get_max_active_levels:
4961       case OMPRTL_omp_get_level:
4962       case OMPRTL_omp_get_ancestor_thread_num:
4963       case OMPRTL_omp_get_team_size:
4964       case OMPRTL_omp_get_active_level:
4965       case OMPRTL_omp_in_final:
4966       case OMPRTL_omp_get_proc_bind:
4967       case OMPRTL_omp_get_num_places:
4968       case OMPRTL_omp_get_num_procs:
4969       case OMPRTL_omp_get_place_proc_ids:
4970       case OMPRTL_omp_get_place_num:
4971       case OMPRTL_omp_get_partition_num_places:
4972       case OMPRTL_omp_get_partition_place_nums:
4973       case OMPRTL_omp_get_wtime:
4974         break;
4975       case OMPRTL___kmpc_distribute_static_init_4:
4976       case OMPRTL___kmpc_distribute_static_init_4u:
4977       case OMPRTL___kmpc_distribute_static_init_8:
4978       case OMPRTL___kmpc_distribute_static_init_8u:
4979       case OMPRTL___kmpc_for_static_init_4:
4980       case OMPRTL___kmpc_for_static_init_4u:
4981       case OMPRTL___kmpc_for_static_init_8:
4982       case OMPRTL___kmpc_for_static_init_8u: {
4983         // Check the schedule and allow static schedule in SPMD mode.
4984         unsigned ScheduleArgOpNo = 2;
4985         auto *ScheduleTypeCI =
4986             dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
4987         unsigned ScheduleTypeVal =
4988             ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4989         switch (OMPScheduleType(ScheduleTypeVal)) {
4990         case OMPScheduleType::UnorderedStatic:
4991         case OMPScheduleType::UnorderedStaticChunked:
4992         case OMPScheduleType::OrderedDistribute:
4993         case OMPScheduleType::OrderedDistributeChunked:
4994           break;
4995         default:
4996           SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4997           SPMDCompatibilityTracker.insert(&CB);
4998           break;
4999         };
5000       } break;
5001       case OMPRTL___kmpc_target_init:
5002         KernelInitCB = &CB;
5003         break;
5004       case OMPRTL___kmpc_target_deinit:
5005         KernelDeinitCB = &CB;
5006         break;
5007       case OMPRTL___kmpc_parallel_51:
5008         if (!handleParallel51(A, CB))
5009           indicatePessimisticFixpoint();
5010         return;
5011       case OMPRTL___kmpc_omp_task:
5012         // We do not look into tasks right now, just give up.
5013         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5014         SPMDCompatibilityTracker.insert(&CB);
5015         ReachedUnknownParallelRegions.insert(&CB);
5016         break;
5017       case OMPRTL___kmpc_alloc_shared:
5018       case OMPRTL___kmpc_free_shared:
5019         // Return without setting a fixpoint, to be resolved in updateImpl.
5020         return;
5021       default:
5022         // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
5023         // generally. However, they do not hide parallel regions.
5024         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5025         SPMDCompatibilityTracker.insert(&CB);
5026         break;
5027       }
5028       // All other OpenMP runtime calls will not reach parallel regions so they
5029       // can be safely ignored for now. Since it is a known OpenMP runtime call
5030       // we have now modeled all effects and there is no need for any update.
5031       indicateOptimisticFixpoint();
5032     };
5033 
5034     const auto *AACE =
5035         A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5036     if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5037       CheckCallee(getAssociatedFunction(), 1);
5038       return;
5039     }
5040     const auto &OptimisticEdges = AACE->getOptimisticEdges();
5041     for (auto *Callee : OptimisticEdges) {
5042       CheckCallee(Callee, OptimisticEdges.size());
5043       if (isAtFixpoint())
5044         break;
5045     }
5046   }
5047 
5048   ChangeStatus updateImpl(Attributor &A) override {
5049     // TODO: Once we have call site specific value information we can provide
5050     //       call site specific liveness information and then it makes
5051     //       sense to specialize attributes for call sites arguments instead of
5052     //       redirecting requests to the callee argument.
5053     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5054     KernelInfoState StateBefore = getState();
5055 
5056     auto CheckCallee = [&](Function *F, int NumCallees) {
5057       const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
5058 
5059       // If F is not a runtime function, propagate the AAKernelInfo of the
5060       // callee.
5061       if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
5062         const IRPosition &FnPos = IRPosition::function(*F);
5063         auto *FnAA =
5064             A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
5065         if (!FnAA)
5066           return indicatePessimisticFixpoint();
5067         if (getState() == FnAA->getState())
5068           return ChangeStatus::UNCHANGED;
5069         getState() = FnAA->getState();
5070         return ChangeStatus::CHANGED;
5071       }
5072       if (NumCallees > 1)
5073         return indicatePessimisticFixpoint();
5074 
5075       CallBase &CB = cast<CallBase>(getAssociatedValue());
5076       if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
5077         if (!handleParallel51(A, CB))
5078           return indicatePessimisticFixpoint();
5079         return StateBefore == getState() ? ChangeStatus::UNCHANGED
5080                                          : ChangeStatus::CHANGED;
5081       }
5082 
5083       // F is a runtime function that allocates or frees memory, check
5084       // AAHeapToStack and AAHeapToShared.
5085       assert(
5086           (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
5087            It->getSecond() == OMPRTL___kmpc_free_shared) &&
5088           "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
5089 
5090       auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
5091           *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5092       auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
5093           *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5094 
5095       RuntimeFunction RF = It->getSecond();
5096 
5097       switch (RF) {
5098       // If neither HeapToStack nor HeapToShared assume the call is removed,
5099       // assume SPMD incompatibility.
5100       case OMPRTL___kmpc_alloc_shared:
5101         if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
5102             (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
5103           SPMDCompatibilityTracker.insert(&CB);
5104         break;
5105       case OMPRTL___kmpc_free_shared:
5106         if ((!HeapToStackAA ||
5107              !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
5108             (!HeapToSharedAA ||
5109              !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
5110           SPMDCompatibilityTracker.insert(&CB);
5111         break;
5112       default:
5113         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5114         SPMDCompatibilityTracker.insert(&CB);
5115       }
5116       return ChangeStatus::CHANGED;
5117     };
5118 
5119     const auto *AACE =
5120         A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5121     if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5122       if (Function *F = getAssociatedFunction())
5123         CheckCallee(F, /*NumCallees=*/1);
5124     } else {
5125       const auto &OptimisticEdges = AACE->getOptimisticEdges();
5126       for (auto *Callee : OptimisticEdges) {
5127         CheckCallee(Callee, OptimisticEdges.size());
5128         if (isAtFixpoint())
5129           break;
5130       }
5131     }
5132 
5133     return StateBefore == getState() ? ChangeStatus::UNCHANGED
5134                                      : ChangeStatus::CHANGED;
5135   }
5136 
5137   /// Deal with a __kmpc_parallel_51 call (\p CB). Returns true if the call was
5138   /// handled, if a problem occurred, false is returned.
5139   bool handleParallel51(Attributor &A, CallBase &CB) {
5140     const unsigned int NonWrapperFunctionArgNo = 5;
5141     const unsigned int WrapperFunctionArgNo = 6;
5142     auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
5143                                      ? NonWrapperFunctionArgNo
5144                                      : WrapperFunctionArgNo;
5145 
5146     auto *ParallelRegion = dyn_cast<Function>(
5147         CB.getArgOperand(ParallelRegionOpArgNo)->stripPointerCasts());
5148     if (!ParallelRegion)
5149       return false;
5150 
5151     ReachedKnownParallelRegions.insert(&CB);
5152     /// Check nested parallelism
5153     auto *FnAA = A.getAAFor<AAKernelInfo>(
5154         *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
5155     NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
5156                          !FnAA->ReachedKnownParallelRegions.empty() ||
5157                          !FnAA->ReachedKnownParallelRegions.isValidState() ||
5158                          !FnAA->ReachedUnknownParallelRegions.isValidState() ||
5159                          !FnAA->ReachedUnknownParallelRegions.empty();
5160     return true;
5161   }
5162 };
5163 
5164 struct AAFoldRuntimeCall
5165     : public StateWrapper<BooleanState, AbstractAttribute> {
5166   using Base = StateWrapper<BooleanState, AbstractAttribute>;
5167 
5168   AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
5169 
5170   /// Statistics are tracked as part of manifest for now.
5171   void trackStatistics() const override {}
5172 
5173   /// Create an abstract attribute biew for the position \p IRP.
5174   static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
5175                                               Attributor &A);
5176 
5177   /// See AbstractAttribute::getName()
5178   const std::string getName() const override { return "AAFoldRuntimeCall"; }
5179 
5180   /// See AbstractAttribute::getIdAddr()
5181   const char *getIdAddr() const override { return &ID; }
5182 
5183   /// This function should return true if the type of the \p AA is
5184   /// AAFoldRuntimeCall
5185   static bool classof(const AbstractAttribute *AA) {
5186     return (AA->getIdAddr() == &ID);
5187   }
5188 
5189   static const char ID;
5190 };
5191 
5192 struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
5193   AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
5194       : AAFoldRuntimeCall(IRP, A) {}
5195 
5196   /// See AbstractAttribute::getAsStr()
5197   const std::string getAsStr(Attributor *) const override {
5198     if (!isValidState())
5199       return "<invalid>";
5200 
5201     std::string Str("simplified value: ");
5202 
5203     if (!SimplifiedValue)
5204       return Str + std::string("none");
5205 
5206     if (!*SimplifiedValue)
5207       return Str + std::string("nullptr");
5208 
5209     if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
5210       return Str + std::to_string(CI->getSExtValue());
5211 
5212     return Str + std::string("unknown");
5213   }
5214 
5215   void initialize(Attributor &A) override {
5216     if (DisableOpenMPOptFolding)
5217       indicatePessimisticFixpoint();
5218 
5219     Function *Callee = getAssociatedFunction();
5220 
5221     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5222     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
5223     assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
5224            "Expected a known OpenMP runtime function");
5225 
5226     RFKind = It->getSecond();
5227 
5228     CallBase &CB = cast<CallBase>(getAssociatedValue());
5229     A.registerSimplificationCallback(
5230         IRPosition::callsite_returned(CB),
5231         [&](const IRPosition &IRP, const AbstractAttribute *AA,
5232             bool &UsedAssumedInformation) -> std::optional<Value *> {
5233           assert((isValidState() ||
5234                   (SimplifiedValue && *SimplifiedValue == nullptr)) &&
5235                  "Unexpected invalid state!");
5236 
5237           if (!isAtFixpoint()) {
5238             UsedAssumedInformation = true;
5239             if (AA)
5240               A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
5241           }
5242           return SimplifiedValue;
5243         });
5244   }
5245 
5246   ChangeStatus updateImpl(Attributor &A) override {
5247     ChangeStatus Changed = ChangeStatus::UNCHANGED;
5248     switch (RFKind) {
5249     case OMPRTL___kmpc_is_spmd_exec_mode:
5250       Changed |= foldIsSPMDExecMode(A);
5251       break;
5252     case OMPRTL___kmpc_parallel_level:
5253       Changed |= foldParallelLevel(A);
5254       break;
5255     case OMPRTL___kmpc_get_hardware_num_threads_in_block:
5256       Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
5257       break;
5258     case OMPRTL___kmpc_get_hardware_num_blocks:
5259       Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
5260       break;
5261     default:
5262       llvm_unreachable("Unhandled OpenMP runtime function!");
5263     }
5264 
5265     return Changed;
5266   }
5267 
5268   ChangeStatus manifest(Attributor &A) override {
5269     ChangeStatus Changed = ChangeStatus::UNCHANGED;
5270 
5271     if (SimplifiedValue && *SimplifiedValue) {
5272       Instruction &I = *getCtxI();
5273       A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
5274       A.deleteAfterManifest(I);
5275 
5276       CallBase *CB = dyn_cast<CallBase>(&I);
5277       auto Remark = [&](OptimizationRemark OR) {
5278         if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
5279           return OR << "Replacing OpenMP runtime call "
5280                     << CB->getCalledFunction()->getName() << " with "
5281                     << ore::NV("FoldedValue", C->getZExtValue()) << ".";
5282         return OR << "Replacing OpenMP runtime call "
5283                   << CB->getCalledFunction()->getName() << ".";
5284       };
5285 
5286       if (CB && EnableVerboseRemarks)
5287         A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
5288 
5289       LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
5290                         << **SimplifiedValue << "\n");
5291 
5292       Changed = ChangeStatus::CHANGED;
5293     }
5294 
5295     return Changed;
5296   }
5297 
5298   ChangeStatus indicatePessimisticFixpoint() override {
5299     SimplifiedValue = nullptr;
5300     return AAFoldRuntimeCall::indicatePessimisticFixpoint();
5301   }
5302 
5303 private:
5304   /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
5305   ChangeStatus foldIsSPMDExecMode(Attributor &A) {
5306     std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5307 
5308     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5309     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5310     auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5311         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5312 
5313     if (!CallerKernelInfoAA ||
5314         !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5315       return indicatePessimisticFixpoint();
5316 
5317     for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5318       auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5319                                           DepClassTy::REQUIRED);
5320 
5321       if (!AA || !AA->isValidState()) {
5322         SimplifiedValue = nullptr;
5323         return indicatePessimisticFixpoint();
5324       }
5325 
5326       if (AA->SPMDCompatibilityTracker.isAssumed()) {
5327         if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5328           ++KnownSPMDCount;
5329         else
5330           ++AssumedSPMDCount;
5331       } else {
5332         if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5333           ++KnownNonSPMDCount;
5334         else
5335           ++AssumedNonSPMDCount;
5336       }
5337     }
5338 
5339     if ((AssumedSPMDCount + KnownSPMDCount) &&
5340         (AssumedNonSPMDCount + KnownNonSPMDCount))
5341       return indicatePessimisticFixpoint();
5342 
5343     auto &Ctx = getAnchorValue().getContext();
5344     if (KnownSPMDCount || AssumedSPMDCount) {
5345       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5346              "Expected only SPMD kernels!");
5347       // All reaching kernels are in SPMD mode. Update all function calls to
5348       // __kmpc_is_spmd_exec_mode to 1.
5349       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
5350     } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
5351       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5352              "Expected only non-SPMD kernels!");
5353       // All reaching kernels are in non-SPMD mode. Update all function
5354       // calls to __kmpc_is_spmd_exec_mode to 0.
5355       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
5356     } else {
5357       // We have empty reaching kernels, therefore we cannot tell if the
5358       // associated call site can be folded. At this moment, SimplifiedValue
5359       // must be none.
5360       assert(!SimplifiedValue && "SimplifiedValue should be none");
5361     }
5362 
5363     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5364                                                     : ChangeStatus::CHANGED;
5365   }
5366 
5367   /// Fold __kmpc_parallel_level into a constant if possible.
5368   ChangeStatus foldParallelLevel(Attributor &A) {
5369     std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5370 
5371     auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5372         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5373 
5374     if (!CallerKernelInfoAA ||
5375         !CallerKernelInfoAA->ParallelLevels.isValidState())
5376       return indicatePessimisticFixpoint();
5377 
5378     if (!CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5379       return indicatePessimisticFixpoint();
5380 
5381     if (CallerKernelInfoAA->ReachingKernelEntries.empty()) {
5382       assert(!SimplifiedValue &&
5383              "SimplifiedValue should keep none at this point");
5384       return ChangeStatus::UNCHANGED;
5385     }
5386 
5387     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5388     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5389     for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5390       auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5391                                           DepClassTy::REQUIRED);
5392       if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
5393         return indicatePessimisticFixpoint();
5394 
5395       if (AA->SPMDCompatibilityTracker.isAssumed()) {
5396         if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5397           ++KnownSPMDCount;
5398         else
5399           ++AssumedSPMDCount;
5400       } else {
5401         if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5402           ++KnownNonSPMDCount;
5403         else
5404           ++AssumedNonSPMDCount;
5405       }
5406     }
5407 
5408     if ((AssumedSPMDCount + KnownSPMDCount) &&
5409         (AssumedNonSPMDCount + KnownNonSPMDCount))
5410       return indicatePessimisticFixpoint();
5411 
5412     auto &Ctx = getAnchorValue().getContext();
5413     // If the caller can only be reached by SPMD kernel entries, the parallel
5414     // level is 1. Similarly, if the caller can only be reached by non-SPMD
5415     // kernel entries, it is 0.
5416     if (AssumedSPMDCount || KnownSPMDCount) {
5417       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5418              "Expected only SPMD kernels!");
5419       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
5420     } else {
5421       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5422              "Expected only non-SPMD kernels!");
5423       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
5424     }
5425     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5426                                                     : ChangeStatus::CHANGED;
5427   }
5428 
5429   ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
5430     // Specialize only if all the calls agree with the attribute constant value
5431     int32_t CurrentAttrValue = -1;
5432     std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5433 
5434     auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5435         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5436 
5437     if (!CallerKernelInfoAA ||
5438         !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5439       return indicatePessimisticFixpoint();
5440 
5441     // Iterate over the kernels that reach this function
5442     for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5443       int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Attr, -1);
5444 
5445       if (NextAttrVal == -1 ||
5446           (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
5447         return indicatePessimisticFixpoint();
5448       CurrentAttrValue = NextAttrVal;
5449     }
5450 
5451     if (CurrentAttrValue != -1) {
5452       auto &Ctx = getAnchorValue().getContext();
5453       SimplifiedValue =
5454           ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
5455     }
5456     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5457                                                     : ChangeStatus::CHANGED;
5458   }
5459 
5460   /// An optional value the associated value is assumed to fold to. That is, we
5461   /// assume the associated value (which is a call) can be replaced by this
5462   /// simplified value.
5463   std::optional<Value *> SimplifiedValue;
5464 
5465   /// The runtime function kind of the callee of the associated call site.
5466   RuntimeFunction RFKind;
5467 };
5468 
5469 } // namespace
5470 
5471 /// Register folding callsite
5472 void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
5473   auto &RFI = OMPInfoCache.RFIs[RF];
5474   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
5475     CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
5476     if (!CI)
5477       return false;
5478     A.getOrCreateAAFor<AAFoldRuntimeCall>(
5479         IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
5480         DepClassTy::NONE, /* ForceUpdate */ false,
5481         /* UpdateAfterInit */ false);
5482     return false;
5483   });
5484 }
5485 
5486 void OpenMPOpt::registerAAs(bool IsModulePass) {
5487   if (SCC.empty())
5488     return;
5489 
5490   if (IsModulePass) {
5491     // Ensure we create the AAKernelInfo AAs first and without triggering an
5492     // update. This will make sure we register all value simplification
5493     // callbacks before any other AA has the chance to create an AAValueSimplify
5494     // or similar.
5495     auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
5496       A.getOrCreateAAFor<AAKernelInfo>(
5497           IRPosition::function(Kernel), /* QueryingAA */ nullptr,
5498           DepClassTy::NONE, /* ForceUpdate */ false,
5499           /* UpdateAfterInit */ false);
5500       return false;
5501     };
5502     OMPInformationCache::RuntimeFunctionInfo &InitRFI =
5503         OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
5504     InitRFI.foreachUse(SCC, CreateKernelInfoCB);
5505 
5506     registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
5507     registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
5508     registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
5509     registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
5510   }
5511 
5512   // Create CallSite AA for all Getters.
5513   if (DeduceICVValues) {
5514     for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
5515       auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
5516 
5517       auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
5518 
5519       auto CreateAA = [&](Use &U, Function &Caller) {
5520         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
5521         if (!CI)
5522           return false;
5523 
5524         auto &CB = cast<CallBase>(*CI);
5525 
5526         IRPosition CBPos = IRPosition::callsite_function(CB);
5527         A.getOrCreateAAFor<AAICVTracker>(CBPos);
5528         return false;
5529       };
5530 
5531       GetterRFI.foreachUse(SCC, CreateAA);
5532     }
5533   }
5534 
5535   // Create an ExecutionDomain AA for every function and a HeapToStack AA for
5536   // every function if there is a device kernel.
5537   if (!isOpenMPDevice(M))
5538     return;
5539 
5540   for (auto *F : SCC) {
5541     if (F->isDeclaration())
5542       continue;
5543 
5544     // We look at internal functions only on-demand but if any use is not a
5545     // direct call or outside the current set of analyzed functions, we have
5546     // to do it eagerly.
5547     if (F->hasLocalLinkage()) {
5548       if (llvm::all_of(F->uses(), [this](const Use &U) {
5549             const auto *CB = dyn_cast<CallBase>(U.getUser());
5550             return CB && CB->isCallee(&U) &&
5551                    A.isRunOn(const_cast<Function *>(CB->getCaller()));
5552           }))
5553         continue;
5554     }
5555     registerAAsForFunction(A, *F);
5556   }
5557 }
5558 
5559 void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
5560   if (!DisableOpenMPOptDeglobalization)
5561     A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
5562   A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
5563   if (!DisableOpenMPOptDeglobalization)
5564     A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(F));
5565   if (F.hasFnAttribute(Attribute::Convergent))
5566     A.getOrCreateAAFor<AANonConvergent>(IRPosition::function(F));
5567 
5568   for (auto &I : instructions(F)) {
5569     if (auto *LI = dyn_cast<LoadInst>(&I)) {
5570       bool UsedAssumedInformation = false;
5571       A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
5572                              UsedAssumedInformation, AA::Interprocedural);
5573       continue;
5574     }
5575     if (auto *CI = dyn_cast<CallBase>(&I)) {
5576       if (CI->isIndirectCall())
5577         A.getOrCreateAAFor<AAIndirectCallInfo>(
5578             IRPosition::callsite_function(*CI));
5579     }
5580     if (auto *SI = dyn_cast<StoreInst>(&I)) {
5581       A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
5582       continue;
5583     }
5584     if (auto *FI = dyn_cast<FenceInst>(&I)) {
5585       A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*FI));
5586       continue;
5587     }
5588     if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
5589       if (II->getIntrinsicID() == Intrinsic::assume) {
5590         A.getOrCreateAAFor<AAPotentialValues>(
5591             IRPosition::value(*II->getArgOperand(0)));
5592         continue;
5593       }
5594     }
5595   }
5596 }
5597 
5598 const char AAICVTracker::ID = 0;
5599 const char AAKernelInfo::ID = 0;
5600 const char AAExecutionDomain::ID = 0;
5601 const char AAHeapToShared::ID = 0;
5602 const char AAFoldRuntimeCall::ID = 0;
5603 
5604 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
5605                                               Attributor &A) {
5606   AAICVTracker *AA = nullptr;
5607   switch (IRP.getPositionKind()) {
5608   case IRPosition::IRP_INVALID:
5609   case IRPosition::IRP_FLOAT:
5610   case IRPosition::IRP_ARGUMENT:
5611   case IRPosition::IRP_CALL_SITE_ARGUMENT:
5612     llvm_unreachable("ICVTracker can only be created for function position!");
5613   case IRPosition::IRP_RETURNED:
5614     AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
5615     break;
5616   case IRPosition::IRP_CALL_SITE_RETURNED:
5617     AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
5618     break;
5619   case IRPosition::IRP_CALL_SITE:
5620     AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
5621     break;
5622   case IRPosition::IRP_FUNCTION:
5623     AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
5624     break;
5625   }
5626 
5627   return *AA;
5628 }
5629 
5630 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
5631                                                         Attributor &A) {
5632   AAExecutionDomainFunction *AA = nullptr;
5633   switch (IRP.getPositionKind()) {
5634   case IRPosition::IRP_INVALID:
5635   case IRPosition::IRP_FLOAT:
5636   case IRPosition::IRP_ARGUMENT:
5637   case IRPosition::IRP_CALL_SITE_ARGUMENT:
5638   case IRPosition::IRP_RETURNED:
5639   case IRPosition::IRP_CALL_SITE_RETURNED:
5640   case IRPosition::IRP_CALL_SITE:
5641     llvm_unreachable(
5642         "AAExecutionDomain can only be created for function position!");
5643   case IRPosition::IRP_FUNCTION:
5644     AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
5645     break;
5646   }
5647 
5648   return *AA;
5649 }
5650 
5651 AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
5652                                                   Attributor &A) {
5653   AAHeapToSharedFunction *AA = nullptr;
5654   switch (IRP.getPositionKind()) {
5655   case IRPosition::IRP_INVALID:
5656   case IRPosition::IRP_FLOAT:
5657   case IRPosition::IRP_ARGUMENT:
5658   case IRPosition::IRP_CALL_SITE_ARGUMENT:
5659   case IRPosition::IRP_RETURNED:
5660   case IRPosition::IRP_CALL_SITE_RETURNED:
5661   case IRPosition::IRP_CALL_SITE:
5662     llvm_unreachable(
5663         "AAHeapToShared can only be created for function position!");
5664   case IRPosition::IRP_FUNCTION:
5665     AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
5666     break;
5667   }
5668 
5669   return *AA;
5670 }
5671 
5672 AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
5673                                               Attributor &A) {
5674   AAKernelInfo *AA = nullptr;
5675   switch (IRP.getPositionKind()) {
5676   case IRPosition::IRP_INVALID:
5677   case IRPosition::IRP_FLOAT:
5678   case IRPosition::IRP_ARGUMENT:
5679   case IRPosition::IRP_RETURNED:
5680   case IRPosition::IRP_CALL_SITE_RETURNED:
5681   case IRPosition::IRP_CALL_SITE_ARGUMENT:
5682     llvm_unreachable("KernelInfo can only be created for function position!");
5683   case IRPosition::IRP_CALL_SITE:
5684     AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
5685     break;
5686   case IRPosition::IRP_FUNCTION:
5687     AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
5688     break;
5689   }
5690 
5691   return *AA;
5692 }
5693 
5694 AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
5695                                                         Attributor &A) {
5696   AAFoldRuntimeCall *AA = nullptr;
5697   switch (IRP.getPositionKind()) {
5698   case IRPosition::IRP_INVALID:
5699   case IRPosition::IRP_FLOAT:
5700   case IRPosition::IRP_ARGUMENT:
5701   case IRPosition::IRP_RETURNED:
5702   case IRPosition::IRP_FUNCTION:
5703   case IRPosition::IRP_CALL_SITE:
5704   case IRPosition::IRP_CALL_SITE_ARGUMENT:
5705     llvm_unreachable("KernelInfo can only be created for call site position!");
5706   case IRPosition::IRP_CALL_SITE_RETURNED:
5707     AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
5708     break;
5709   }
5710 
5711   return *AA;
5712 }
5713 
/// Module-pass entry point of OpenMPOpt.
///
/// Skipped entirely for modules without the "openmp" module flag or when
/// OpenMP optimizations are disabled. For device modules, called non-kernel
/// definitions are internalized first (when allowed). OpenMPOpt then runs, with
/// an Attributor, over every definition that was not internalized. Device
/// functions may optionally be marked always-inline afterwards.
PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
  if (!containsOpenMP(M))
    return PreservedAnalyses::all();
  if (DisableOpenMPOptimizations)
    return PreservedAnalyses::all();

  FunctionAnalysisManager &FAM =
      AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
  KernelSet Kernels = getDeviceKernels(M);

  if (PrintModuleBeforeOptimizations)
    LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);

  // A function counts as "called" if it is a kernel or has any user other than
  // a blockaddress.
  auto IsCalled = [&](Function &F) {
    if (Kernels.contains(&F))
      return true;
    for (const User *U : F.users())
      if (!isa<BlockAddress>(U))
        return true;
    return false;
  };

  // Remark OMP140: the function could not be internalized.
  auto EmitRemark = [&](Function &F) {
    auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
    ORE.emit([&]() {
      OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
      return ORA << "Could not internalize function. "
                 << "Some optimizations may not be possible. [OMP140]";
    });
  };

  bool Changed = false;

  // Create internal copies of each called non-kernel function if this is a
  // device module. This allows interprocedural passes to see every call edge.
  // Functions that cannot be internalized get remark OMP140, unless they
  // already have local linkage or are marked cold.
  DenseMap<Function *, Function *> InternalizedMap;
  if (isOpenMPDevice(M)) {
    SmallPtrSet<Function *, 16> InternalizeFns;
    for (Function &F : M)
      if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
          !DisableInternalization) {
        if (Attributor::isInternalizable(F)) {
          InternalizeFns.insert(&F);
        } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
          EmitRemark(F);
        }
      }

    Changed |=
        Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
  }

  // Look at every function in the Module unless it was internalized. Only the
  // originals are looked up in InternalizedMap; presumably their internal
  // copies are definitions in M and are picked up here instead.
  SetVector<Function *> Functions;
  SmallVector<Function *, 16> SCC;
  for (Function &F : M)
    if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) {
      SCC.push_back(&F);
      Functions.insert(&F);
    }

  if (SCC.empty())
    return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();

  AnalysisGetter AG(FAM);

  auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
    return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
  };

  BumpPtrAllocator Allocator;
  CallGraphUpdater CGUpdater;

  // NOTE(review): ThinLTOPreLink (not ThinLTOPostLink) is treated as a
  // "post-link" phase here. The pairing with FullLTOPostLink and the variable
  // name suggest ThinLTOPostLink may have been intended. Confirm before
  // relying on PostLink in ThinLTO pipelines.
  bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
                  LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;
  OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, PostLink);

  // Device modules use the SetFixpointIterations option; others use 32.
  unsigned MaxFixpointIterations =
      (isOpenMPDevice(M)) ? SetFixpointIterations : 32;

  AttributorConfig AC(CGUpdater);
  AC.DefaultInitializeLiveInternals = false;
  AC.IsModulePass = true;
  AC.RewriteSignatures = false;
  AC.MaxFixpointIterations = MaxFixpointIterations;
  AC.OREGetter = OREGetter;
  AC.PassName = DEBUG_TYPE;
  AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
  // Functions carrying the "kernel" attribute (the same test as
  // isOpenMPKernel) are reported as IPO-amendable.
  AC.IPOAmendableCB = [](const Function &F) {
    return F.hasFnAttribute("kernel");
  };

  Attributor A(Functions, InfoCache, AC);

  OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
  Changed |= OMPOpt.run(true);

  // Optionally inline device functions for potentially better performance.
  // Kernels and functions marked noinline are left alone.
  if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
    for (Function &F : M)
      if (!F.isDeclaration() && !Kernels.contains(&F) &&
          !F.hasFnAttribute(Attribute::NoInline))
        F.addFnAttr(Attribute::AlwaysInline);

  if (PrintModuleAfterOptimizations)
    LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);

  if (Changed)
    return PreservedAnalyses::none();

  return PreservedAnalyses::all();
}
5826 
/// CGSCC-pass entry point of OpenMPOpt.
///
/// Skipped for modules without the "openmp" module flag or when OpenMP
/// optimizations are disabled. Otherwise OpenMPOpt runs, with an Attributor
/// restricted to the functions of \p C (non-module-pass mode), and the call
/// graph is kept up to date through \p CG / \p UR.
PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
                                          CGSCCAnalysisManager &AM,
                                          LazyCallGraph &CG,
                                          CGSCCUpdateResult &UR) {
  if (!containsOpenMP(*C.begin()->getFunction().getParent()))
    return PreservedAnalyses::all();
  if (DisableOpenMPOptimizations)
    return PreservedAnalyses::all();

  // Collect every function of this SCC. The pass runs on all SCCs; the
  // original note says this is required if there are kernels in the module.
  SmallVector<Function *, 16> SCC;
  // If there are kernels in the module, we have to run on all SCC's.
  for (LazyCallGraph::Node &N : C) {
    Function *Fn = &N.getFunction();
    SCC.push_back(Fn);
  }

  if (SCC.empty())
    return PreservedAnalyses::all();

  Module &M = *C.begin()->getFunction().getParent();

  if (PrintModuleBeforeOptimizations)
    LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M);

  // NOTE(review): Kernels is not referenced below. The call still updates the
  // kernel-count statistics inside getDeviceKernels.
  KernelSet Kernels = getDeviceKernels(M);

  FunctionAnalysisManager &FAM =
      AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();

  AnalysisGetter AG(FAM);

  auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
    return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
  };

  BumpPtrAllocator Allocator;
  CallGraphUpdater CGUpdater;
  CGUpdater.initialize(CG, C, AM, UR);

  // NOTE(review): as in the module pass, ThinLTOPreLink (not ThinLTOPostLink)
  // is treated as "post-link" here. ThinLTOPostLink may have been intended;
  // confirm.
  bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
                  LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;
  SetVector<Function *> Functions(SCC.begin(), SCC.end());
  OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
                                /*CGSCC*/ &Functions, PostLink);

  // Device modules use the SetFixpointIterations option; others use 32.
  unsigned MaxFixpointIterations =
      (isOpenMPDevice(M)) ? SetFixpointIterations : 32;

  AttributorConfig AC(CGUpdater);
  AC.DefaultInitializeLiveInternals = false;
  AC.IsModulePass = false;
  AC.RewriteSignatures = false;
  AC.MaxFixpointIterations = MaxFixpointIterations;
  AC.OREGetter = OREGetter;
  AC.PassName = DEBUG_TYPE;
  AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;

  Attributor A(Functions, InfoCache, AC);

  OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
  bool Changed = OMPOpt.run(false);

  if (PrintModuleAfterOptimizations)
    LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);

  if (Changed)
    return PreservedAnalyses::none();

  return PreservedAnalyses::all();
}
5897 
5898 bool llvm::omp::isOpenMPKernel(Function &Fn) {
5899   return Fn.hasFnAttribute("kernel");
5900 }
5901 
5902 KernelSet llvm::omp::getDeviceKernels(Module &M) {
5903   // TODO: Create a more cross-platform way of determining device kernels.
5904   NamedMDNode *MD = M.getNamedMetadata("nvvm.annotations");
5905   KernelSet Kernels;
5906 
5907   if (!MD)
5908     return Kernels;
5909 
5910   for (auto *Op : MD->operands()) {
5911     if (Op->getNumOperands() < 2)
5912       continue;
5913     MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
5914     if (!KindID || KindID->getString() != "kernel")
5915       continue;
5916 
5917     Function *KernelFn =
5918         mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
5919     if (!KernelFn)
5920       continue;
5921 
5922     // We are only interested in OpenMP target regions. Others, such as kernels
5923     // generated by CUDA but linked together, are not interesting to this pass.
5924     if (isOpenMPKernel(*KernelFn)) {
5925       ++NumOpenMPTargetRegionKernels;
5926       Kernels.insert(KernelFn);
5927     } else
5928       ++NumNonOpenMPTargetRegionKernels;
5929   }
5930 
5931   return Kernels;
5932 }
5933 
5934 bool llvm::omp::containsOpenMP(Module &M) {
5935   Metadata *MD = M.getModuleFlag("openmp");
5936   if (!MD)
5937     return false;
5938 
5939   return true;
5940 }
5941 
5942 bool llvm::omp::isOpenMPDevice(Module &M) {
5943   Metadata *MD = M.getModuleFlag("openmp-device");
5944   if (!MD)
5945     return false;
5946 
5947   return true;
5948 }
5949