xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp (revision 1db9f3b21e39176dd5b67cf8ac378633b172463e)
1 //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpenMP specific optimizations:
10 //
11 // - Deduplication of runtime calls, e.g., omp_get_thread_num.
12 // - Replacing globalized device memory with stack memory.
13 // - Replacing globalized device memory with shared memory.
14 // - Parallel region merging.
15 // - Transforming generic-mode device kernels to SPMD mode.
16 // - Specializing the state machine for generic-mode device kernels.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "llvm/Transforms/IPO/OpenMPOpt.h"
21 
22 #include "llvm/ADT/EnumeratedArray.h"
23 #include "llvm/ADT/PostOrderIterator.h"
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/ADT/SmallPtrSet.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/Statistic.h"
28 #include "llvm/ADT/StringExtras.h"
29 #include "llvm/ADT/StringRef.h"
30 #include "llvm/Analysis/CallGraph.h"
31 #include "llvm/Analysis/CallGraphSCCPass.h"
32 #include "llvm/Analysis/MemoryLocation.h"
33 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
34 #include "llvm/Analysis/ValueTracking.h"
35 #include "llvm/Frontend/OpenMP/OMPConstants.h"
36 #include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
37 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
38 #include "llvm/IR/Assumptions.h"
39 #include "llvm/IR/BasicBlock.h"
40 #include "llvm/IR/Constants.h"
41 #include "llvm/IR/DiagnosticInfo.h"
42 #include "llvm/IR/Dominators.h"
43 #include "llvm/IR/Function.h"
44 #include "llvm/IR/GlobalValue.h"
45 #include "llvm/IR/GlobalVariable.h"
46 #include "llvm/IR/InstrTypes.h"
47 #include "llvm/IR/Instruction.h"
48 #include "llvm/IR/Instructions.h"
49 #include "llvm/IR/IntrinsicInst.h"
50 #include "llvm/IR/IntrinsicsAMDGPU.h"
51 #include "llvm/IR/IntrinsicsNVPTX.h"
52 #include "llvm/IR/LLVMContext.h"
53 #include "llvm/Support/Casting.h"
54 #include "llvm/Support/CommandLine.h"
55 #include "llvm/Support/Debug.h"
56 #include "llvm/Transforms/IPO/Attributor.h"
57 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
58 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
59 
60 #include <algorithm>
61 #include <optional>
62 #include <string>
63 
64 using namespace llvm;
65 using namespace omp;
66 
67 #define DEBUG_TYPE "openmp-opt"
68 
69 static cl::opt<bool> DisableOpenMPOptimizations(
70     "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
71     cl::Hidden, cl::init(false));
72 
73 static cl::opt<bool> EnableParallelRegionMerging(
74     "openmp-opt-enable-merging",
75     cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
76     cl::init(false));
77 
78 static cl::opt<bool>
79     DisableInternalization("openmp-opt-disable-internalization",
80                            cl::desc("Disable function internalization."),
81                            cl::Hidden, cl::init(false));
82 
83 static cl::opt<bool> DeduceICVValues("openmp-deduce-icv-values",
84                                      cl::init(false), cl::Hidden);
85 static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
86                                     cl::Hidden);
87 static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
88                                         cl::init(false), cl::Hidden);
89 
90 static cl::opt<bool> HideMemoryTransferLatency(
91     "openmp-hide-memory-transfer-latency",
92     cl::desc("[WIP] Tries to hide the latency of host to device memory"
93              " transfers"),
94     cl::Hidden, cl::init(false));
95 
96 static cl::opt<bool> DisableOpenMPOptDeglobalization(
97     "openmp-opt-disable-deglobalization",
98     cl::desc("Disable OpenMP optimizations involving deglobalization."),
99     cl::Hidden, cl::init(false));
100 
101 static cl::opt<bool> DisableOpenMPOptSPMDization(
102     "openmp-opt-disable-spmdization",
103     cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
104     cl::Hidden, cl::init(false));
105 
106 static cl::opt<bool> DisableOpenMPOptFolding(
107     "openmp-opt-disable-folding",
108     cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
109     cl::init(false));
110 
111 static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
112     "openmp-opt-disable-state-machine-rewrite",
113     cl::desc("Disable OpenMP optimizations that replace the state machine."),
114     cl::Hidden, cl::init(false));
115 
116 static cl::opt<bool> DisableOpenMPOptBarrierElimination(
117     "openmp-opt-disable-barrier-elimination",
118     cl::desc("Disable OpenMP optimizations that eliminate barriers."),
119     cl::Hidden, cl::init(false));
120 
121 static cl::opt<bool> PrintModuleAfterOptimizations(
122     "openmp-opt-print-module-after",
123     cl::desc("Print the current module after OpenMP optimizations."),
124     cl::Hidden, cl::init(false));
125 
126 static cl::opt<bool> PrintModuleBeforeOptimizations(
127     "openmp-opt-print-module-before",
128     cl::desc("Print the current module before OpenMP optimizations."),
129     cl::Hidden, cl::init(false));
130 
131 static cl::opt<bool> AlwaysInlineDeviceFunctions(
132     "openmp-opt-inline-device",
133     cl::desc("Inline all applicible functions on the device."), cl::Hidden,
134     cl::init(false));
135 
136 static cl::opt<bool>
137     EnableVerboseRemarks("openmp-opt-verbose-remarks",
138                          cl::desc("Enables more verbose remarks."), cl::Hidden,
139                          cl::init(false));
140 
141 static cl::opt<unsigned>
142     SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
143                           cl::desc("Maximal number of attributor iterations."),
144                           cl::init(256));
145 
146 static cl::opt<unsigned>
147     SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
148                       cl::desc("Maximum amount of shared memory to use."),
149                       cl::init(std::numeric_limits<unsigned>::max()));
150 
151 STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
152           "Number of OpenMP runtime calls deduplicated");
153 STATISTIC(NumOpenMPParallelRegionsDeleted,
154           "Number of OpenMP parallel regions deleted");
155 STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
156           "Number of OpenMP runtime functions identified");
157 STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
158           "Number of OpenMP runtime function uses identified");
159 STATISTIC(NumOpenMPTargetRegionKernels,
160           "Number of OpenMP target region entry points (=kernels) identified");
161 STATISTIC(NumNonOpenMPTargetRegionKernels,
162           "Number of non-OpenMP target region kernels identified");
163 STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
164           "Number of OpenMP target region entry points (=kernels) executed in "
165           "SPMD-mode instead of generic-mode");
166 STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
167           "Number of OpenMP target region entry points (=kernels) executed in "
168           "generic-mode without a state machines");
169 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
170           "Number of OpenMP target region entry points (=kernels) executed in "
171           "generic-mode with customized state machines with fallback");
172 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
173           "Number of OpenMP target region entry points (=kernels) executed in "
174           "generic-mode with customized state machines without fallback");
175 STATISTIC(
176     NumOpenMPParallelRegionsReplacedInGPUStateMachine,
177     "Number of OpenMP parallel regions replaced with ID in GPU state machines");
178 STATISTIC(NumOpenMPParallelRegionsMerged,
179           "Number of OpenMP parallel regions merged");
180 STATISTIC(NumBytesMovedToSharedMemory,
181           "Amount of memory pushed to shared memory");
182 STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");
183 
184 #if !defined(NDEBUG)
185 static constexpr auto TAG = "[" DEBUG_TYPE "]";
186 #endif
187 
188 namespace KernelInfo {
189 
190 // struct ConfigurationEnvironmentTy {
191 //   uint8_t UseGenericStateMachine;
192 //   uint8_t MayUseNestedParallelism;
193 //   llvm::omp::OMPTgtExecModeFlags ExecMode;
194 //   int32_t MinThreads;
195 //   int32_t MaxThreads;
196 //   int32_t MinTeams;
197 //   int32_t MaxTeams;
198 // };
199 
200 // struct DynamicEnvironmentTy {
201 //   uint16_t DebugIndentionLevel;
202 // };
203 
204 // struct KernelEnvironmentTy {
205 //   ConfigurationEnvironmentTy Configuration;
206 //   IdentTy *Ident;
207 //   DynamicEnvironmentTy *DynamicEnv;
208 // };
209 
210 #define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX)                                    \
211   constexpr const unsigned MEMBER##Idx = IDX;
212 
213 KERNEL_ENVIRONMENT_IDX(Configuration, 0)
214 KERNEL_ENVIRONMENT_IDX(Ident, 1)
215 
216 #undef KERNEL_ENVIRONMENT_IDX
217 
218 #define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX)                      \
219   constexpr const unsigned MEMBER##Idx = IDX;
220 
221 KERNEL_ENVIRONMENT_CONFIGURATION_IDX(UseGenericStateMachine, 0)
222 KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MayUseNestedParallelism, 1)
223 KERNEL_ENVIRONMENT_CONFIGURATION_IDX(ExecMode, 2)
224 KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinThreads, 3)
225 KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxThreads, 4)
226 KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinTeams, 5)
227 KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxTeams, 6)
228 
229 #undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX
230 
231 #define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE)                          \
232   RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \
233     return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx));     \
234   }
235 
236 KERNEL_ENVIRONMENT_GETTER(Ident, Constant)
237 KERNEL_ENVIRONMENT_GETTER(Configuration, ConstantStruct)
238 
239 #undef KERNEL_ENVIRONMENT_GETTER
240 
241 #define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER)                        \
242   ConstantInt *get##MEMBER##FromKernelEnvironment(                             \
243       ConstantStruct *KernelEnvC) {                                            \
244     ConstantStruct *ConfigC =                                                  \
245         getConfigurationFromKernelEnvironment(KernelEnvC);                     \
246     return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx));   \
247   }
248 
249 KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine)
250 KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MayUseNestedParallelism)
251 KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(ExecMode)
252 KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinThreads)
253 KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxThreads)
254 KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinTeams)
255 KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxTeams)
256 
257 #undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER
258 
259 GlobalVariable *
260 getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB) {
261   constexpr const int InitKernelEnvironmentArgNo = 0;
262   return cast<GlobalVariable>(
263       KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo)
264           ->stripPointerCasts());
265 }
266 
267 ConstantStruct *getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB) {
268   GlobalVariable *KernelEnvGV =
269       getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
270   return cast<ConstantStruct>(KernelEnvGV->getInitializer());
271 }
272 } // namespace KernelInfo
273 
274 namespace {
275 
276 struct AAHeapToShared;
277 
278 struct AAICVTracker;
279 
280 /// OpenMP specific information. For now, stores RFIs and ICVs also needed for
281 /// Attributor runs.
282 struct OMPInformationCache : public InformationCache {
283   OMPInformationCache(Module &M, AnalysisGetter &AG,
284                       BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
285                       bool OpenMPPostLink)
286       : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
287         OpenMPPostLink(OpenMPPostLink) {
288 
289     OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M);
290     OMPBuilder.initialize();
291     initializeRuntimeFunctions(M);
292     initializeInternalControlVars();
293   }
294 
295   /// Generic information that describes an internal control variable.
296   struct InternalControlVarInfo {
297     /// The kind, as described by InternalControlVar enum.
298     InternalControlVar Kind;
299 
300     /// The name of the ICV.
301     StringRef Name;
302 
303     /// Environment variable associated with this ICV.
304     StringRef EnvVarName;
305 
306     /// Initial value kind.
307     ICVInitValue InitKind;
308 
309     /// Initial value.
310     ConstantInt *InitValue;
311 
312     /// Setter RTL function associated with this ICV.
313     RuntimeFunction Setter;
314 
315     /// Getter RTL function associated with this ICV.
316     RuntimeFunction Getter;
317 
318     /// RTL Function corresponding to the override clause of this ICV
319     RuntimeFunction Clause;
320   };
321 
322   /// Generic information that describes a runtime function
323   struct RuntimeFunctionInfo {
324 
325     /// The kind, as described by the RuntimeFunction enum.
326     RuntimeFunction Kind;
327 
328     /// The name of the function.
329     StringRef Name;
330 
331     /// Flag to indicate a variadic function.
332     bool IsVarArg;
333 
334     /// The return type of the function.
335     Type *ReturnType;
336 
337     /// The argument types of the function.
338     SmallVector<Type *, 8> ArgumentTypes;
339 
340     /// The declaration if available.
341     Function *Declaration = nullptr;
342 
343     /// Uses of this runtime function per function containing the use.
344     using UseVector = SmallVector<Use *, 16>;
345 
346     /// Clear UsesMap for runtime function.
347     void clearUsesMap() { UsesMap.clear(); }
348 
349     /// Boolean conversion that is true if the runtime function was found.
350     operator bool() const { return Declaration; }
351 
352     /// Return the vector of uses in function \p F.
353     UseVector &getOrCreateUseVector(Function *F) {
354       std::shared_ptr<UseVector> &UV = UsesMap[F];
355       if (!UV)
356         UV = std::make_shared<UseVector>();
357       return *UV;
358     }
359 
360     /// Return the vector of uses in function \p F or `nullptr` if there are
361     /// none.
362     const UseVector *getUseVector(Function &F) const {
363       auto I = UsesMap.find(&F);
364       if (I != UsesMap.end())
365         return I->second.get();
366       return nullptr;
367     }
368 
369     /// Return how many functions contain uses of this runtime function.
370     size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
371 
372     /// Return the number of arguments (or the minimal number for variadic
373     /// functions).
374     size_t getNumArgs() const { return ArgumentTypes.size(); }
375 
376     /// Run the callback \p CB on each use and forget the use if the result is
377     /// true. The callback will be fed the function in which the use was
378     /// encountered as second argument.
379     void foreachUse(SmallVectorImpl<Function *> &SCC,
380                     function_ref<bool(Use &, Function &)> CB) {
381       for (Function *F : SCC)
382         foreachUse(CB, F);
383     }
384 
385     /// Run the callback \p CB on each use within the function \p F and forget
386     /// the use if the result is true.
387     void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
388       SmallVector<unsigned, 8> ToBeDeleted;
389       ToBeDeleted.clear();
390 
391       unsigned Idx = 0;
392       UseVector &UV = getOrCreateUseVector(F);
393 
394       for (Use *U : UV) {
395         if (CB(*U, *F))
396           ToBeDeleted.push_back(Idx);
397         ++Idx;
398       }
399 
400       // Remove the to-be-deleted indices in reverse order as prior
401       // modifications will not modify the smaller indices.
402       while (!ToBeDeleted.empty()) {
403         unsigned Idx = ToBeDeleted.pop_back_val();
404         UV[Idx] = UV.back();
405         UV.pop_back();
406       }
407     }
408 
409   private:
410     /// Map from functions to all uses of this runtime function contained in
411     /// them.
412     DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
413 
414   public:
415     /// Iterators for the uses of this runtime function.
416     decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
417     decltype(UsesMap)::iterator end() { return UsesMap.end(); }
418   };
419 
420   /// An OpenMP-IR-Builder instance
421   OpenMPIRBuilder OMPBuilder;
422 
423   /// Map from runtime function kind to the runtime function description.
424   EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
425                   RuntimeFunction::OMPRTL___last>
426       RFIs;
427 
428   /// Map from function declarations/definitions to their runtime enum type.
429   DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
430 
431   /// Map from ICV kind to the ICV description.
432   EnumeratedArray<InternalControlVarInfo, InternalControlVar,
433                   InternalControlVar::ICV___last>
434       ICVs;
435 
436   /// Helper to initialize all internal control variable information for those
437   /// defined in OMPKinds.def.
438   void initializeInternalControlVars() {
439 #define ICV_RT_SET(_Name, RTL)                                                 \
440   {                                                                            \
441     auto &ICV = ICVs[_Name];                                                   \
442     ICV.Setter = RTL;                                                          \
443   }
444 #define ICV_RT_GET(Name, RTL)                                                  \
445   {                                                                            \
446     auto &ICV = ICVs[Name];                                                    \
447     ICV.Getter = RTL;                                                          \
448   }
449 #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
450   {                                                                            \
451     auto &ICV = ICVs[Enum];                                                    \
452     ICV.Name = _Name;                                                          \
453     ICV.Kind = Enum;                                                           \
454     ICV.InitKind = Init;                                                       \
455     ICV.EnvVarName = _EnvVarName;                                              \
456     switch (ICV.InitKind) {                                                    \
457     case ICV_IMPLEMENTATION_DEFINED:                                           \
458       ICV.InitValue = nullptr;                                                 \
459       break;                                                                   \
460     case ICV_ZERO:                                                             \
461       ICV.InitValue = ConstantInt::get(                                        \
462           Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
463       break;                                                                   \
464     case ICV_FALSE:                                                            \
465       ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
466       break;                                                                   \
467     case ICV_LAST:                                                             \
468       break;                                                                   \
469     }                                                                          \
470   }
471 #include "llvm/Frontend/OpenMP/OMPKinds.def"
472   }
473 
474   /// Returns true if the function declaration \p F matches the runtime
475   /// function types, that is, return type \p RTFRetType, and argument types
476   /// \p RTFArgTypes.
477   static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
478                                   SmallVector<Type *, 8> &RTFArgTypes) {
479     // TODO: We should output information to the user (under debug output
480     //       and via remarks).
481 
482     if (!F)
483       return false;
484     if (F->getReturnType() != RTFRetType)
485       return false;
486     if (F->arg_size() != RTFArgTypes.size())
487       return false;
488 
489     auto *RTFTyIt = RTFArgTypes.begin();
490     for (Argument &Arg : F->args()) {
491       if (Arg.getType() != *RTFTyIt)
492         return false;
493 
494       ++RTFTyIt;
495     }
496 
497     return true;
498   }
499 
500   // Helper to collect all uses of the declaration in the UsesMap.
501   unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
502     unsigned NumUses = 0;
503     if (!RFI.Declaration)
504       return NumUses;
505     OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
506 
507     if (CollectStats) {
508       NumOpenMPRuntimeFunctionsIdentified += 1;
509       NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
510     }
511 
512     // TODO: We directly convert uses into proper calls and unknown uses.
513     for (Use &U : RFI.Declaration->uses()) {
514       if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
515         if (!CGSCC || CGSCC->empty() || CGSCC->contains(UserI->getFunction())) {
516           RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
517           ++NumUses;
518         }
519       } else {
520         RFI.getOrCreateUseVector(nullptr).push_back(&U);
521         ++NumUses;
522       }
523     }
524     return NumUses;
525   }
526 
527   // Helper function to recollect uses of a runtime function.
528   void recollectUsesForFunction(RuntimeFunction RTF) {
529     auto &RFI = RFIs[RTF];
530     RFI.clearUsesMap();
531     collectUses(RFI, /*CollectStats*/ false);
532   }
533 
534   // Helper function to recollect uses of all runtime functions.
535   void recollectUses() {
536     for (int Idx = 0; Idx < RFIs.size(); ++Idx)
537       recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
538   }
539 
540   // Helper function to inherit the calling convention of the function callee.
541   void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
542     if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
543       CI->setCallingConv(Fn->getCallingConv());
544   }
545 
546   // Helper function to determine if it's legal to create a call to the runtime
547   // functions.
548   bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
549     // We can always emit calls if we haven't yet linked in the runtime.
550     if (!OpenMPPostLink)
551       return true;
552 
553     // Once the runtime has been already been linked in we cannot emit calls to
554     // any undefined functions.
555     for (RuntimeFunction Fn : Fns) {
556       RuntimeFunctionInfo &RFI = RFIs[Fn];
557 
558       if (RFI.Declaration && RFI.Declaration->isDeclaration())
559         return false;
560     }
561     return true;
562   }
563 
564   /// Helper to initialize all runtime function information for those defined
565   /// in OpenMPKinds.def.
566   void initializeRuntimeFunctions(Module &M) {
567 
568     // Helper macros for handling __VA_ARGS__ in OMP_RTL
569 #define OMP_TYPE(VarName, ...)                                                 \
570   Type *VarName = OMPBuilder.VarName;                                          \
571   (void)VarName;
572 
573 #define OMP_ARRAY_TYPE(VarName, ...)                                           \
574   ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
575   (void)VarName##Ty;                                                           \
576   PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
577   (void)VarName##PtrTy;
578 
579 #define OMP_FUNCTION_TYPE(VarName, ...)                                        \
580   FunctionType *VarName = OMPBuilder.VarName;                                  \
581   (void)VarName;                                                               \
582   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
583   (void)VarName##Ptr;
584 
585 #define OMP_STRUCT_TYPE(VarName, ...)                                          \
586   StructType *VarName = OMPBuilder.VarName;                                    \
587   (void)VarName;                                                               \
588   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
589   (void)VarName##Ptr;
590 
591 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
592   {                                                                            \
593     SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
594     Function *F = M.getFunction(_Name);                                        \
595     RTLFunctions.insert(F);                                                    \
596     if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
597       RuntimeFunctionIDMap[F] = _Enum;                                         \
598       auto &RFI = RFIs[_Enum];                                                 \
599       RFI.Kind = _Enum;                                                        \
600       RFI.Name = _Name;                                                        \
601       RFI.IsVarArg = _IsVarArg;                                                \
602       RFI.ReturnType = OMPBuilder._ReturnType;                                 \
603       RFI.ArgumentTypes = std::move(ArgsTypes);                                \
604       RFI.Declaration = F;                                                     \
605       unsigned NumUses = collectUses(RFI);                                     \
606       (void)NumUses;                                                           \
607       LLVM_DEBUG({                                                             \
608         dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
609                << " found\n";                                                  \
610         if (RFI.Declaration)                                                   \
611           dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
612                  << RFI.getNumFunctionsWithUses()                              \
613                  << " different functions.\n";                                 \
614       });                                                                      \
615     }                                                                          \
616   }
617 #include "llvm/Frontend/OpenMP/OMPKinds.def"
618 
619     // Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_`
620     // functions, except if `optnone` is present.
621     if (isOpenMPDevice(M)) {
622       for (Function &F : M) {
623         for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
624           if (F.hasFnAttribute(Attribute::NoInline) &&
625               F.getName().starts_with(Prefix) &&
626               !F.hasFnAttribute(Attribute::OptimizeNone))
627             F.removeFnAttr(Attribute::NoInline);
628       }
629     }
630 
631     // TODO: We should attach the attributes defined in OMPKinds.def.
632   }
633 
634   /// Collection of known OpenMP runtime functions..
635   DenseSet<const Function *> RTLFunctions;
636 
637   /// Indicates if we have already linked in the OpenMP device library.
638   bool OpenMPPostLink = false;
639 };
640 
641 template <typename Ty, bool InsertInvalidates = true>
642 struct BooleanStateWithSetVector : public BooleanState {
643   bool contains(const Ty &Elem) const { return Set.contains(Elem); }
644   bool insert(const Ty &Elem) {
645     if (InsertInvalidates)
646       BooleanState::indicatePessimisticFixpoint();
647     return Set.insert(Elem);
648   }
649 
650   const Ty &operator[](int Idx) const { return Set[Idx]; }
651   bool operator==(const BooleanStateWithSetVector &RHS) const {
652     return BooleanState::operator==(RHS) && Set == RHS.Set;
653   }
654   bool operator!=(const BooleanStateWithSetVector &RHS) const {
655     return !(*this == RHS);
656   }
657 
658   bool empty() const { return Set.empty(); }
659   size_t size() const { return Set.size(); }
660 
661   /// "Clamp" this state with \p RHS.
662   BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
663     BooleanState::operator^=(RHS);
664     Set.insert(RHS.Set.begin(), RHS.Set.end());
665     return *this;
666   }
667 
668 private:
669   /// A set to keep track of elements.
670   SetVector<Ty> Set;
671 
672 public:
673   typename decltype(Set)::iterator begin() { return Set.begin(); }
674   typename decltype(Set)::iterator end() { return Set.end(); }
675   typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
676   typename decltype(Set)::const_iterator end() const { return Set.end(); }
677 };
678 
679 template <typename Ty, bool InsertInvalidates = true>
680 using BooleanStateWithPtrSetVector =
681     BooleanStateWithSetVector<Ty *, InsertInvalidates>;
682 
683 struct KernelInfoState : AbstractState {
684   /// Flag to track if we reached a fixpoint.
685   bool IsAtFixpoint = false;
686 
687   /// The parallel regions (identified by the outlined parallel functions) that
688   /// can be reached from the associated function.
689   BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
690       ReachedKnownParallelRegions;
691 
692   /// State to track what parallel region we might reach.
693   BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
694 
695   /// State to track if we are in SPMD-mode, assumed or know, and why we decided
696   /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
697   /// false.
698   BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
699 
700   /// The __kmpc_target_init call in this kernel, if any. If we find more than
701   /// one we abort as the kernel is malformed.
702   CallBase *KernelInitCB = nullptr;
703 
704   /// The constant kernel environement as taken from and passed to
705   /// __kmpc_target_init.
706   ConstantStruct *KernelEnvC = nullptr;
707 
708   /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
709   /// one we abort as the kernel is malformed.
710   CallBase *KernelDeinitCB = nullptr;
711 
712   /// Flag to indicate if the associated function is a kernel entry.
713   bool IsKernelEntry = false;
714 
715   /// State to track what kernel entries can reach the associated function.
716   BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
717 
718   /// State to indicate if we can track parallel level of the associated
719   /// function. We will give up tracking if we encounter unknown caller or the
720   /// caller is __kmpc_parallel_51.
721   BooleanStateWithSetVector<uint8_t> ParallelLevels;
722 
723   /// Flag that indicates if the kernel has nested Parallelism
724   bool NestedParallelism = false;
725 
726   /// Abstract State interface
727   ///{
728 
729   KernelInfoState() = default;
730   KernelInfoState(bool BestState) {
731     if (!BestState)
732       indicatePessimisticFixpoint();
733   }
734 
735   /// See AbstractState::isValidState(...)
736   bool isValidState() const override { return true; }
737 
738   /// See AbstractState::isAtFixpoint(...)
739   bool isAtFixpoint() const override { return IsAtFixpoint; }
740 
741   /// See AbstractState::indicatePessimisticFixpoint(...)
742   ChangeStatus indicatePessimisticFixpoint() override {
743     IsAtFixpoint = true;
744     ParallelLevels.indicatePessimisticFixpoint();
745     ReachingKernelEntries.indicatePessimisticFixpoint();
746     SPMDCompatibilityTracker.indicatePessimisticFixpoint();
747     ReachedKnownParallelRegions.indicatePessimisticFixpoint();
748     ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
749     NestedParallelism = true;
750     return ChangeStatus::CHANGED;
751   }
752 
753   /// See AbstractState::indicateOptimisticFixpoint(...)
754   ChangeStatus indicateOptimisticFixpoint() override {
755     IsAtFixpoint = true;
756     ParallelLevels.indicateOptimisticFixpoint();
757     ReachingKernelEntries.indicateOptimisticFixpoint();
758     SPMDCompatibilityTracker.indicateOptimisticFixpoint();
759     ReachedKnownParallelRegions.indicateOptimisticFixpoint();
760     ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
761     return ChangeStatus::UNCHANGED;
762   }
763 
764   /// Return the assumed state
765   KernelInfoState &getAssumed() { return *this; }
766   const KernelInfoState &getAssumed() const { return *this; }
767 
768   bool operator==(const KernelInfoState &RHS) const {
769     if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
770       return false;
771     if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
772       return false;
773     if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
774       return false;
775     if (ReachingKernelEntries != RHS.ReachingKernelEntries)
776       return false;
777     if (ParallelLevels != RHS.ParallelLevels)
778       return false;
779     if (NestedParallelism != RHS.NestedParallelism)
780       return false;
781     return true;
782   }
783 
784   /// Returns true if this kernel contains any OpenMP parallel regions.
785   bool mayContainParallelRegion() {
786     return !ReachedKnownParallelRegions.empty() ||
787            !ReachedUnknownParallelRegions.empty();
788   }
789 
790   /// Return empty set as the best state of potential values.
791   static KernelInfoState getBestState() { return KernelInfoState(true); }
792 
793   static KernelInfoState getBestState(KernelInfoState &KIS) {
794     return getBestState();
795   }
796 
797   /// Return full set as the worst state of potential values.
798   static KernelInfoState getWorstState() { return KernelInfoState(false); }
799 
800   /// "Clamp" this state with \p KIS.
801   KernelInfoState operator^=(const KernelInfoState &KIS) {
802     // Do not merge two different _init and _deinit call sites.
803     if (KIS.KernelInitCB) {
804       if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
805         llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
806                          "assumptions.");
807       KernelInitCB = KIS.KernelInitCB;
808     }
809     if (KIS.KernelDeinitCB) {
810       if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
811         llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
812                          "assumptions.");
813       KernelDeinitCB = KIS.KernelDeinitCB;
814     }
815     if (KIS.KernelEnvC) {
816       if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
817         llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
818                          "assumptions.");
819       KernelEnvC = KIS.KernelEnvC;
820     }
821     SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
822     ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
823     ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
824     NestedParallelism |= KIS.NestedParallelism;
825     return *this;
826   }
827 
828   KernelInfoState operator&=(const KernelInfoState &KIS) {
829     return (*this ^= KIS);
830   }
831 
832   ///}
833 };
834 
835 /// Used to map the values physically (in the IR) stored in an offload
836 /// array, to a vector in memory.
837 struct OffloadArray {
838   /// Physical array (in the IR).
839   AllocaInst *Array = nullptr;
840   /// Mapped values.
841   SmallVector<Value *, 8> StoredValues;
842   /// Last stores made in the offload array.
843   SmallVector<StoreInst *, 8> LastAccesses;
844 
845   OffloadArray() = default;
846 
847   /// Initializes the OffloadArray with the values stored in \p Array before
848   /// instruction \p Before is reached. Returns false if the initialization
849   /// fails.
850   /// This MUST be used immediately after the construction of the object.
851   bool initialize(AllocaInst &Array, Instruction &Before) {
852     if (!Array.getAllocatedType()->isArrayTy())
853       return false;
854 
855     if (!getValues(Array, Before))
856       return false;
857 
858     this->Array = &Array;
859     return true;
860   }
861 
862   static const unsigned DeviceIDArgNum = 1;
863   static const unsigned BasePtrsArgNum = 3;
864   static const unsigned PtrsArgNum = 4;
865   static const unsigned SizesArgNum = 5;
866 
867 private:
868   /// Traverses the BasicBlock where \p Array is, collecting the stores made to
869   /// \p Array, leaving StoredValues with the values stored before the
870   /// instruction \p Before is reached.
871   bool getValues(AllocaInst &Array, Instruction &Before) {
872     // Initialize container.
873     const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
874     StoredValues.assign(NumValues, nullptr);
875     LastAccesses.assign(NumValues, nullptr);
876 
877     // TODO: This assumes the instruction \p Before is in the same
878     //  BasicBlock as Array. Make it general, for any control flow graph.
879     BasicBlock *BB = Array.getParent();
880     if (BB != Before.getParent())
881       return false;
882 
883     const DataLayout &DL = Array.getModule()->getDataLayout();
884     const unsigned int PointerSize = DL.getPointerSize();
885 
886     for (Instruction &I : *BB) {
887       if (&I == &Before)
888         break;
889 
890       if (!isa<StoreInst>(&I))
891         continue;
892 
893       auto *S = cast<StoreInst>(&I);
894       int64_t Offset = -1;
895       auto *Dst =
896           GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
897       if (Dst == &Array) {
898         int64_t Idx = Offset / PointerSize;
899         StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
900         LastAccesses[Idx] = S;
901       }
902     }
903 
904     return isFilled();
905   }
906 
907   /// Returns true if all values in StoredValues and
908   /// LastAccesses are not nullptrs.
909   bool isFilled() {
910     const unsigned NumValues = StoredValues.size();
911     for (unsigned I = 0; I < NumValues; ++I) {
912       if (!StoredValues[I] || !LastAccesses[I])
913         return false;
914     }
915 
916     return true;
917   }
918 };
919 
920 struct OpenMPOpt {
921 
922   using OptimizationRemarkGetter =
923       function_ref<OptimizationRemarkEmitter &(Function *)>;
924 
925   OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
926             OptimizationRemarkGetter OREGetter,
927             OMPInformationCache &OMPInfoCache, Attributor &A)
928       : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
929         OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
930 
931   /// Check if any remarks are enabled for openmp-opt
932   bool remarksEnabled() {
933     auto &Ctx = M.getContext();
934     return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
935   }
936 
937   /// Run all OpenMP optimizations on the underlying SCC.
938   bool run(bool IsModulePass) {
939     if (SCC.empty())
940       return false;
941 
942     bool Changed = false;
943 
944     LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
945                       << " functions\n");
946 
947     if (IsModulePass) {
948       Changed |= runAttributor(IsModulePass);
949 
950       // Recollect uses, in case Attributor deleted any.
951       OMPInfoCache.recollectUses();
952 
953       // TODO: This should be folded into buildCustomStateMachine.
954       Changed |= rewriteDeviceCodeStateMachine();
955 
956       if (remarksEnabled())
957         analysisGlobalization();
958     } else {
959       if (PrintICVValues)
960         printICVs();
961       if (PrintOpenMPKernels)
962         printKernels();
963 
964       Changed |= runAttributor(IsModulePass);
965 
966       // Recollect uses, in case Attributor deleted any.
967       OMPInfoCache.recollectUses();
968 
969       Changed |= deleteParallelRegions();
970 
971       if (HideMemoryTransferLatency)
972         Changed |= hideMemTransfersLatency();
973       Changed |= deduplicateRuntimeCalls();
974       if (EnableParallelRegionMerging) {
975         if (mergeParallelRegions()) {
976           deduplicateRuntimeCalls();
977           Changed = true;
978         }
979       }
980     }
981 
982     if (OMPInfoCache.OpenMPPostLink)
983       Changed |= removeRuntimeSymbols();
984 
985     return Changed;
986   }
987 
988   /// Print initial ICV values for testing.
989   /// FIXME: This should be done from the Attributor once it is added.
990   void printICVs() const {
991     InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
992                                  ICV_proc_bind};
993 
994     for (Function *F : SCC) {
995       for (auto ICV : ICVs) {
996         auto ICVInfo = OMPInfoCache.ICVs[ICV];
997         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
998           return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
999                      << " Value: "
1000                      << (ICVInfo.InitValue
1001                              ? toString(ICVInfo.InitValue->getValue(), 10, true)
1002                              : "IMPLEMENTATION_DEFINED");
1003         };
1004 
1005         emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
1006       }
1007     }
1008   }
1009 
1010   /// Print OpenMP GPU kernels for testing.
1011   void printKernels() const {
1012     for (Function *F : SCC) {
1013       if (!omp::isOpenMPKernel(*F))
1014         continue;
1015 
1016       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1017         return ORA << "OpenMP GPU kernel "
1018                    << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
1019       };
1020 
1021       emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
1022     }
1023   }
1024 
1025   /// Return the call if \p U is a callee use in a regular call. If \p RFI is
1026   /// given it has to be the callee or a nullptr is returned.
1027   static CallInst *getCallIfRegularCall(
1028       Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1029     CallInst *CI = dyn_cast<CallInst>(U.getUser());
1030     if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
1031         (!RFI ||
1032          (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1033       return CI;
1034     return nullptr;
1035   }
1036 
1037   /// Return the call if \p V is a regular call. If \p RFI is given it has to be
1038   /// the callee or a nullptr is returned.
1039   static CallInst *getCallIfRegularCall(
1040       Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1041     CallInst *CI = dyn_cast<CallInst>(&V);
1042     if (CI && !CI->hasOperandBundles() &&
1043         (!RFI ||
1044          (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1045       return CI;
1046     return nullptr;
1047   }
1048 
1049 private:
1050   /// Merge parallel regions when it is safe.
1051   bool mergeParallelRegions() {
1052     const unsigned CallbackCalleeOperand = 2;
1053     const unsigned CallbackFirstArgOperand = 3;
1054     using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1055 
1056     // Check if there are any __kmpc_fork_call calls to merge.
1057     OMPInformationCache::RuntimeFunctionInfo &RFI =
1058         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1059 
1060     if (!RFI.Declaration)
1061       return false;
1062 
1063     // Unmergable calls that prevent merging a parallel region.
1064     OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
1065         OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
1066         OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
1067     };
1068 
1069     bool Changed = false;
1070     LoopInfo *LI = nullptr;
1071     DominatorTree *DT = nullptr;
1072 
1073     SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
1074 
1075     BasicBlock *StartBB = nullptr, *EndBB = nullptr;
1076     auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1077       BasicBlock *CGStartBB = CodeGenIP.getBlock();
1078       BasicBlock *CGEndBB =
1079           SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1080       assert(StartBB != nullptr && "StartBB should not be null");
1081       CGStartBB->getTerminator()->setSuccessor(0, StartBB);
1082       assert(EndBB != nullptr && "EndBB should not be null");
1083       EndBB->getTerminator()->setSuccessor(0, CGEndBB);
1084     };
1085 
1086     auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
1087                       Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
1088       ReplacementValue = &Inner;
1089       return CodeGenIP;
1090     };
1091 
1092     auto FiniCB = [&](InsertPointTy CodeGenIP) {};
1093 
1094     /// Create a sequential execution region within a merged parallel region,
1095     /// encapsulated in a master construct with a barrier for synchronization.
1096     auto CreateSequentialRegion = [&](Function *OuterFn,
1097                                       BasicBlock *OuterPredBB,
1098                                       Instruction *SeqStartI,
1099                                       Instruction *SeqEndI) {
1100       // Isolate the instructions of the sequential region to a separate
1101       // block.
1102       BasicBlock *ParentBB = SeqStartI->getParent();
1103       BasicBlock *SeqEndBB =
1104           SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
1105       BasicBlock *SeqAfterBB =
1106           SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
1107       BasicBlock *SeqStartBB =
1108           SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
1109 
1110       assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
1111              "Expected a different CFG");
1112       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
1113       ParentBB->getTerminator()->eraseFromParent();
1114 
1115       auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1116         BasicBlock *CGStartBB = CodeGenIP.getBlock();
1117         BasicBlock *CGEndBB =
1118             SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1119         assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
1120         CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
1121         assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
1122         SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
1123       };
1124       auto FiniCB = [&](InsertPointTy CodeGenIP) {};
1125 
1126       // Find outputs from the sequential region to outside users and
1127       // broadcast their values to them.
1128       for (Instruction &I : *SeqStartBB) {
1129         SmallPtrSet<Instruction *, 4> OutsideUsers;
1130         for (User *Usr : I.users()) {
1131           Instruction &UsrI = *cast<Instruction>(Usr);
1132           // Ignore outputs to LT intrinsics, code extraction for the merged
1133           // parallel region will fix them.
1134           if (UsrI.isLifetimeStartOrEnd())
1135             continue;
1136 
1137           if (UsrI.getParent() != SeqStartBB)
1138             OutsideUsers.insert(&UsrI);
1139         }
1140 
1141         if (OutsideUsers.empty())
1142           continue;
1143 
1144         // Emit an alloca in the outer region to store the broadcasted
1145         // value.
1146         const DataLayout &DL = M.getDataLayout();
1147         AllocaInst *AllocaI = new AllocaInst(
1148             I.getType(), DL.getAllocaAddrSpace(), nullptr,
1149             I.getName() + ".seq.output.alloc", &OuterFn->front().front());
1150 
1151         // Emit a store instruction in the sequential BB to update the
1152         // value.
1153         new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
1154 
1155         // Emit a load instruction and replace the use of the output value
1156         // with it.
1157         for (Instruction *UsrI : OutsideUsers) {
1158           LoadInst *LoadI = new LoadInst(
1159               I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
1160           UsrI->replaceUsesOfWith(&I, LoadI);
1161         }
1162       }
1163 
1164       OpenMPIRBuilder::LocationDescription Loc(
1165           InsertPointTy(ParentBB, ParentBB->end()), DL);
1166       InsertPointTy SeqAfterIP =
1167           OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
1168 
1169       OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
1170 
1171       BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
1172 
1173       LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
1174                         << "\n");
1175     };
1176 
1177     // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
1178     // contained in BB and only separated by instructions that can be
1179     // redundantly executed in parallel. The block BB is split before the first
1180     // call (in MergableCIs) and after the last so the entire region we merge
1181     // into a single parallel region is contained in a single basic block
1182     // without any other instructions. We use the OpenMPIRBuilder to outline
1183     // that block and call the resulting function via __kmpc_fork_call.
1184     auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
1185                      BasicBlock *BB) {
1186       // TODO: Change the interface to allow single CIs expanded, e.g, to
1187       // include an outer loop.
1188       assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1189 
1190       auto Remark = [&](OptimizationRemark OR) {
1191         OR << "Parallel region merged with parallel region"
1192            << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1193         for (auto *CI : llvm::drop_begin(MergableCIs)) {
1194           OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1195           if (CI != MergableCIs.back())
1196             OR << ", ";
1197         }
1198         return OR << ".";
1199       };
1200 
1201       emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1202 
1203       Function *OriginalFn = BB->getParent();
1204       LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1205                         << " parallel regions in " << OriginalFn->getName()
1206                         << "\n");
1207 
1208       // Isolate the calls to merge in a separate block.
1209       EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1210       BasicBlock *AfterBB =
1211           SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1212       StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1213                            "omp.par.merged");
1214 
1215       assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1216       const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1217       BB->getTerminator()->eraseFromParent();
1218 
1219       // Create sequential regions for sequential instructions that are
1220       // in-between mergable parallel regions.
1221       for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1222            It != End; ++It) {
1223         Instruction *ForkCI = *It;
1224         Instruction *NextForkCI = *(It + 1);
1225 
1226         // Continue if there are not in-between instructions.
1227         if (ForkCI->getNextNode() == NextForkCI)
1228           continue;
1229 
1230         CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1231                                NextForkCI->getPrevNode());
1232       }
1233 
1234       OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1235                                                DL);
1236       IRBuilder<>::InsertPoint AllocaIP(
1237           &OriginalFn->getEntryBlock(),
1238           OriginalFn->getEntryBlock().getFirstInsertionPt());
1239       // Create the merged parallel region with default proc binding, to
1240       // avoid overriding binding settings, and without explicit cancellation.
1241       InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1242           Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1243           OMP_PROC_BIND_default, /* IsCancellable */ false);
1244       BranchInst::Create(AfterBB, AfterIP.getBlock());
1245 
1246       // Perform the actual outlining.
1247       OMPInfoCache.OMPBuilder.finalize(OriginalFn);
1248 
1249       Function *OutlinedFn = MergableCIs.front()->getCaller();
1250 
1251       // Replace the __kmpc_fork_call calls with direct calls to the outlined
1252       // callbacks.
1253       SmallVector<Value *, 8> Args;
1254       for (auto *CI : MergableCIs) {
1255         Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
1256         FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
1257         Args.clear();
1258         Args.push_back(OutlinedFn->getArg(0));
1259         Args.push_back(OutlinedFn->getArg(1));
1260         for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1261              ++U)
1262           Args.push_back(CI->getArgOperand(U));
1263 
1264         CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
1265         if (CI->getDebugLoc())
1266           NewCI->setDebugLoc(CI->getDebugLoc());
1267 
1268         // Forward parameter attributes from the callback to the callee.
1269         for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1270              ++U)
1271           for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1272             NewCI->addParamAttr(
1273                 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1274 
1275         // Emit an explicit barrier to replace the implicit fork-join barrier.
1276         if (CI != MergableCIs.back()) {
1277           // TODO: Remove barrier if the merged parallel region includes the
1278           // 'nowait' clause.
1279           OMPInfoCache.OMPBuilder.createBarrier(
1280               InsertPointTy(NewCI->getParent(),
1281                             NewCI->getNextNode()->getIterator()),
1282               OMPD_parallel);
1283         }
1284 
1285         CI->eraseFromParent();
1286       }
1287 
1288       assert(OutlinedFn != OriginalFn && "Outlining failed");
1289       CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1290       CGUpdater.reanalyzeFunction(*OriginalFn);
1291 
1292       NumOpenMPParallelRegionsMerged += MergableCIs.size();
1293 
1294       return true;
1295     };
1296 
1297     // Helper function that identifes sequences of
1298     // __kmpc_fork_call uses in a basic block.
1299     auto DetectPRsCB = [&](Use &U, Function &F) {
1300       CallInst *CI = getCallIfRegularCall(U, &RFI);
1301       BB2PRMap[CI->getParent()].insert(CI);
1302 
1303       return false;
1304     };
1305 
1306     BB2PRMap.clear();
1307     RFI.foreachUse(SCC, DetectPRsCB);
1308     SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1309     // Find mergable parallel regions within a basic block that are
1310     // safe to merge, that is any in-between instructions can safely
1311     // execute in parallel after merging.
1312     // TODO: support merging across basic-blocks.
1313     for (auto &It : BB2PRMap) {
1314       auto &CIs = It.getSecond();
1315       if (CIs.size() < 2)
1316         continue;
1317 
1318       BasicBlock *BB = It.getFirst();
1319       SmallVector<CallInst *, 4> MergableCIs;
1320 
1321       /// Returns true if the instruction is mergable, false otherwise.
1322       /// A terminator instruction is unmergable by definition since merging
1323       /// works within a BB. Instructions before the mergable region are
1324       /// mergable if they are not calls to OpenMP runtime functions that may
1325       /// set different execution parameters for subsequent parallel regions.
1326       /// Instructions in-between parallel regions are mergable if they are not
1327       /// calls to any non-intrinsic function since that may call a non-mergable
1328       /// OpenMP runtime function.
1329       auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1330         // We do not merge across BBs, hence return false (unmergable) if the
1331         // instruction is a terminator.
1332         if (I.isTerminator())
1333           return false;
1334 
1335         if (!isa<CallInst>(&I))
1336           return true;
1337 
1338         CallInst *CI = cast<CallInst>(&I);
1339         if (IsBeforeMergableRegion) {
1340           Function *CalledFunction = CI->getCalledFunction();
1341           if (!CalledFunction)
1342             return false;
1343           // Return false (unmergable) if the call before the parallel
1344           // region calls an explicit affinity (proc_bind) or number of
1345           // threads (num_threads) compiler-generated function. Those settings
1346           // may be incompatible with following parallel regions.
1347           // TODO: ICV tracking to detect compatibility.
1348           for (const auto &RFI : UnmergableCallsInfo) {
1349             if (CalledFunction == RFI.Declaration)
1350               return false;
1351           }
1352         } else {
1353           // Return false (unmergable) if there is a call instruction
1354           // in-between parallel regions when it is not an intrinsic. It
1355           // may call an unmergable OpenMP runtime function in its callpath.
1356           // TODO: Keep track of possible OpenMP calls in the callpath.
1357           if (!isa<IntrinsicInst>(CI))
1358             return false;
1359         }
1360 
1361         return true;
1362       };
1363       // Find maximal number of parallel region CIs that are safe to merge.
1364       for (auto It = BB->begin(), End = BB->end(); It != End;) {
1365         Instruction &I = *It;
1366         ++It;
1367 
1368         if (CIs.count(&I)) {
1369           MergableCIs.push_back(cast<CallInst>(&I));
1370           continue;
1371         }
1372 
1373         // Continue expanding if the instruction is mergable.
1374         if (IsMergable(I, MergableCIs.empty()))
1375           continue;
1376 
1377         // Forward the instruction iterator to skip the next parallel region
1378         // since there is an unmergable instruction which can affect it.
1379         for (; It != End; ++It) {
1380           Instruction &SkipI = *It;
1381           if (CIs.count(&SkipI)) {
1382             LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1383                               << " due to " << I << "\n");
1384             ++It;
1385             break;
1386           }
1387         }
1388 
1389         // Store mergable regions found.
1390         if (MergableCIs.size() > 1) {
1391           MergableCIsVector.push_back(MergableCIs);
1392           LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1393                             << " parallel regions in block " << BB->getName()
1394                             << " of function " << BB->getParent()->getName()
1395                             << "\n";);
1396         }
1397 
1398         MergableCIs.clear();
1399       }
1400 
1401       if (!MergableCIsVector.empty()) {
1402         Changed = true;
1403 
1404         for (auto &MergableCIs : MergableCIsVector)
1405           Merge(MergableCIs, BB);
1406         MergableCIsVector.clear();
1407       }
1408     }
1409 
1410     if (Changed) {
1411       /// Re-collect use for fork calls, emitted barrier calls, and
1412       /// any emitted master/end_master calls.
1413       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1414       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1415       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1416       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1417     }
1418 
1419     return Changed;
1420   }
1421 
1422   /// Try to delete parallel regions if possible.
1423   bool deleteParallelRegions() {
1424     const unsigned CallbackCalleeOperand = 2;
1425 
1426     OMPInformationCache::RuntimeFunctionInfo &RFI =
1427         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1428 
1429     if (!RFI.Declaration)
1430       return false;
1431 
1432     bool Changed = false;
1433     auto DeleteCallCB = [&](Use &U, Function &) {
1434       CallInst *CI = getCallIfRegularCall(U);
1435       if (!CI)
1436         return false;
1437       auto *Fn = dyn_cast<Function>(
1438           CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1439       if (!Fn)
1440         return false;
1441       if (!Fn->onlyReadsMemory())
1442         return false;
1443       if (!Fn->hasFnAttribute(Attribute::WillReturn))
1444         return false;
1445 
1446       LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1447                         << CI->getCaller()->getName() << "\n");
1448 
1449       auto Remark = [&](OptimizationRemark OR) {
1450         return OR << "Removing parallel region with no side-effects.";
1451       };
1452       emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1453 
1454       CGUpdater.removeCallSite(*CI);
1455       CI->eraseFromParent();
1456       Changed = true;
1457       ++NumOpenMPParallelRegionsDeleted;
1458       return true;
1459     };
1460 
1461     RFI.foreachUse(SCC, DeleteCallCB);
1462 
1463     return Changed;
1464   }
1465 
1466   /// Try to eliminate runtime calls by reusing existing ones.
1467   bool deduplicateRuntimeCalls() {
1468     bool Changed = false;
1469 
1470     RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1471         OMPRTL_omp_get_num_threads,
1472         OMPRTL_omp_in_parallel,
1473         OMPRTL_omp_get_cancellation,
1474         OMPRTL_omp_get_thread_limit,
1475         OMPRTL_omp_get_supported_active_levels,
1476         OMPRTL_omp_get_level,
1477         OMPRTL_omp_get_ancestor_thread_num,
1478         OMPRTL_omp_get_team_size,
1479         OMPRTL_omp_get_active_level,
1480         OMPRTL_omp_in_final,
1481         OMPRTL_omp_get_proc_bind,
1482         OMPRTL_omp_get_num_places,
1483         OMPRTL_omp_get_num_procs,
1484         OMPRTL_omp_get_place_num,
1485         OMPRTL_omp_get_partition_num_places,
1486         OMPRTL_omp_get_partition_place_nums};
1487 
1488     // Global-tid is handled separately.
1489     SmallSetVector<Value *, 16> GTIdArgs;
1490     collectGlobalThreadIdArguments(GTIdArgs);
1491     LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1492                       << " global thread ID arguments\n");
1493 
1494     for (Function *F : SCC) {
1495       for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1496         Changed |= deduplicateRuntimeCalls(
1497             *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1498 
1499       // __kmpc_global_thread_num is special as we can replace it with an
1500       // argument in enough cases to make it worth trying.
1501       Value *GTIdArg = nullptr;
1502       for (Argument &Arg : F->args())
1503         if (GTIdArgs.count(&Arg)) {
1504           GTIdArg = &Arg;
1505           break;
1506         }
1507       Changed |= deduplicateRuntimeCalls(
1508           *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1509     }
1510 
1511     return Changed;
1512   }
1513 
1514   /// Tries to remove known runtime symbols that are optional from the module.
1515   bool removeRuntimeSymbols() {
1516     // The RPC client symbol is defined in `libc` and indicates that something
1517     // required an RPC server. If its users were all optimized out then we can
1518     // safely remove it.
1519     // TODO: This should be somewhere more common in the future.
1520     if (GlobalVariable *GV = M.getNamedGlobal("__llvm_libc_rpc_client")) {
1521       if (!GV->getType()->isPointerTy())
1522         return false;
1523 
1524       Constant *C = GV->getInitializer();
1525       if (!C)
1526         return false;
1527 
1528       // Check to see if the only user of the RPC client is the external handle.
1529       GlobalVariable *Client = dyn_cast<GlobalVariable>(C->stripPointerCasts());
1530       if (!Client || Client->getNumUses() > 1 ||
1531           Client->user_back() != GV->getInitializer())
1532         return false;
1533 
1534       Client->replaceAllUsesWith(PoisonValue::get(Client->getType()));
1535       Client->eraseFromParent();
1536 
1537       GV->replaceAllUsesWith(PoisonValue::get(GV->getType()));
1538       GV->eraseFromParent();
1539 
1540       return true;
1541     }
1542     return false;
1543   }
1544 
1545   /// Tries to hide the latency of runtime calls that involve host to
1546   /// device memory transfers by splitting them into their "issue" and "wait"
1547   /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1548   /// moved downards as much as possible. The "issue" issues the memory transfer
1549   /// asynchronously, returning a handle. The "wait" waits in the returned
1550   /// handle for the memory transfer to finish.
1551   bool hideMemTransfersLatency() {
1552     auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1553     bool Changed = false;
1554     auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1555       auto *RTCall = getCallIfRegularCall(U, &RFI);
1556       if (!RTCall)
1557         return false;
1558 
1559       OffloadArray OffloadArrays[3];
1560       if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1561         return false;
1562 
1563       LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1564 
1565       // TODO: Check if can be moved upwards.
1566       bool WasSplit = false;
1567       Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1568       if (WaitMovementPoint)
1569         WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1570 
1571       Changed |= WasSplit;
1572       return WasSplit;
1573     };
1574     if (OMPInfoCache.runtimeFnsAvailable(
1575             {OMPRTL___tgt_target_data_begin_mapper_issue,
1576              OMPRTL___tgt_target_data_begin_mapper_wait}))
1577       RFI.foreachUse(SCC, SplitMemTransfers);
1578 
1579     return Changed;
1580   }
1581 
1582   void analysisGlobalization() {
1583     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1584 
1585     auto CheckGlobalization = [&](Use &U, Function &Decl) {
1586       if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1587         auto Remark = [&](OptimizationRemarkMissed ORM) {
1588           return ORM
1589                  << "Found thread data sharing on the GPU. "
1590                  << "Expect degraded performance due to data globalization.";
1591         };
1592         emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1593       }
1594 
1595       return false;
1596     };
1597 
1598     RFI.foreachUse(SCC, CheckGlobalization);
1599   }
1600 
1601   /// Maps the values stored in the offload arrays passed as arguments to
1602   /// \p RuntimeCall into the offload arrays in \p OAs.
1603   bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1604                                 MutableArrayRef<OffloadArray> OAs) {
1605     assert(OAs.size() == 3 && "Need space for three offload arrays!");
1606 
1607     // A runtime call that involves memory offloading looks something like:
1608     // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1609     //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1610     // ...)
1611     // So, the idea is to access the allocas that allocate space for these
1612     // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1613     // Therefore:
1614     // i8** %offload_baseptrs.
1615     Value *BasePtrsArg =
1616         RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1617     // i8** %offload_ptrs.
1618     Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1619     // i8** %offload_sizes.
1620     Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1621 
1622     // Get values stored in **offload_baseptrs.
1623     auto *V = getUnderlyingObject(BasePtrsArg);
1624     if (!isa<AllocaInst>(V))
1625       return false;
1626     auto *BasePtrsArray = cast<AllocaInst>(V);
1627     if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1628       return false;
1629 
1630     // Get values stored in **offload_baseptrs.
1631     V = getUnderlyingObject(PtrsArg);
1632     if (!isa<AllocaInst>(V))
1633       return false;
1634     auto *PtrsArray = cast<AllocaInst>(V);
1635     if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1636       return false;
1637 
1638     // Get values stored in **offload_sizes.
1639     V = getUnderlyingObject(SizesArg);
1640     // If it's a [constant] global array don't analyze it.
1641     if (isa<GlobalValue>(V))
1642       return isa<Constant>(V);
1643     if (!isa<AllocaInst>(V))
1644       return false;
1645 
1646     auto *SizesArray = cast<AllocaInst>(V);
1647     if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1648       return false;
1649 
1650     return true;
1651   }
1652 
1653   /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1654   /// For now this is a way to test that the function getValuesInOffloadArrays
1655   /// is working properly.
1656   /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1657   void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1658     assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1659 
1660     LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1661     std::string ValuesStr;
1662     raw_string_ostream Printer(ValuesStr);
1663     std::string Separator = " --- ";
1664 
1665     for (auto *BP : OAs[0].StoredValues) {
1666       BP->print(Printer);
1667       Printer << Separator;
1668     }
1669     LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
1670     ValuesStr.clear();
1671 
1672     for (auto *P : OAs[1].StoredValues) {
1673       P->print(Printer);
1674       Printer << Separator;
1675     }
1676     LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
1677     ValuesStr.clear();
1678 
1679     for (auto *S : OAs[2].StoredValues) {
1680       S->print(Printer);
1681       Printer << Separator;
1682     }
1683     LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
1684   }
1685 
1686   /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be
1687   /// moved. Returns nullptr if the movement is not possible, or not worth it.
1688   Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1689     // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1690     //  Make it traverse the CFG.
1691 
1692     Instruction *CurrentI = &RuntimeCall;
1693     bool IsWorthIt = false;
1694     while ((CurrentI = CurrentI->getNextNode())) {
1695 
1696       // TODO: Once we detect the regions to be offloaded we should use the
1697       //  alias analysis manager to check if CurrentI may modify one of
1698       //  the offloaded regions.
1699       if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1700         if (IsWorthIt)
1701           return CurrentI;
1702 
1703         return nullptr;
1704       }
1705 
1706       // FIXME: For now if we move it over anything without side effect
1707       //  is worth it.
1708       IsWorthIt = true;
1709     }
1710 
1711     // Return end of BasicBlock.
1712     return RuntimeCall.getParent()->getTerminator();
1713   }
1714 
1715   /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1716   bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1717                                Instruction &WaitMovementPoint) {
1718     // Create stack allocated handle (__tgt_async_info) at the beginning of the
1719     // function. Used for storing information of the async transfer, allowing to
1720     // wait on it later.
1721     auto &IRBuilder = OMPInfoCache.OMPBuilder;
1722     Function *F = RuntimeCall.getCaller();
1723     BasicBlock &Entry = F->getEntryBlock();
1724     IRBuilder.Builder.SetInsertPoint(&Entry,
1725                                      Entry.getFirstNonPHIOrDbgOrAlloca());
1726     Value *Handle = IRBuilder.Builder.CreateAlloca(
1727         IRBuilder.AsyncInfo, /*ArraySize=*/nullptr, "handle");
1728     Handle =
1729         IRBuilder.Builder.CreateAddrSpaceCast(Handle, IRBuilder.AsyncInfoPtr);
1730 
1731     // Add "issue" runtime call declaration:
1732     // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1733     //   i8**, i8**, i64*, i64*)
1734     FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1735         M, OMPRTL___tgt_target_data_begin_mapper_issue);
1736 
1737     // Change RuntimeCall call site for its asynchronous version.
1738     SmallVector<Value *, 16> Args;
1739     for (auto &Arg : RuntimeCall.args())
1740       Args.push_back(Arg.get());
1741     Args.push_back(Handle);
1742 
1743     CallInst *IssueCallsite =
1744         CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
1745     OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1746     RuntimeCall.eraseFromParent();
1747 
1748     // Add "wait" runtime call declaration:
1749     // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1750     FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1751         M, OMPRTL___tgt_target_data_begin_mapper_wait);
1752 
1753     Value *WaitParams[2] = {
1754         IssueCallsite->getArgOperand(
1755             OffloadArray::DeviceIDArgNum), // device_id.
1756         Handle                             // handle to wait on.
1757     };
1758     CallInst *WaitCallsite = CallInst::Create(
1759         WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
1760     OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1761 
1762     return true;
1763   }
1764 
1765   static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1766                                     bool GlobalOnly, bool &SingleChoice) {
1767     if (CurrentIdent == NextIdent)
1768       return CurrentIdent;
1769 
1770     // TODO: Figure out how to actually combine multiple debug locations. For
1771     //       now we just keep an existing one if there is a single choice.
1772     if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1773       SingleChoice = !CurrentIdent;
1774       return NextIdent;
1775     }
1776     return nullptr;
1777   }
1778 
1779   /// Return an `struct ident_t*` value that represents the ones used in the
1780   /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1781   /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1782   /// return value we create one from scratch. We also do not yet combine
1783   /// information, e.g., the source locations, see combinedIdentStruct.
1784   Value *
1785   getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1786                                  Function &F, bool GlobalOnly) {
1787     bool SingleChoice = true;
1788     Value *Ident = nullptr;
1789     auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1790       CallInst *CI = getCallIfRegularCall(U, &RFI);
1791       if (!CI || &F != &Caller)
1792         return false;
1793       Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1794                                   /* GlobalOnly */ true, SingleChoice);
1795       return false;
1796     };
1797     RFI.foreachUse(SCC, CombineIdentStruct);
1798 
1799     if (!Ident || !SingleChoice) {
1800       // The IRBuilder uses the insertion block to get to the module, this is
1801       // unfortunate but we work around it for now.
1802       if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1803         OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1804             &F.getEntryBlock(), F.getEntryBlock().begin()));
1805       // Create a fallback location if non was found.
1806       // TODO: Use the debug locations of the calls instead.
1807       uint32_t SrcLocStrSize;
1808       Constant *Loc =
1809           OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1810       Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1811     }
1812     return Ident;
1813   }
1814 
1815   /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1816   /// \p ReplVal if given.
1817   bool deduplicateRuntimeCalls(Function &F,
1818                                OMPInformationCache::RuntimeFunctionInfo &RFI,
1819                                Value *ReplVal = nullptr) {
1820     auto *UV = RFI.getUseVector(F);
1821     if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1822       return false;
1823 
1824     LLVM_DEBUG(
1825         dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1826                << (ReplVal ? " with an existing value\n" : "\n") << "\n");
1827 
1828     assert((!ReplVal || (isa<Argument>(ReplVal) &&
1829                          cast<Argument>(ReplVal)->getParent() == &F)) &&
1830            "Unexpected replacement value!");
1831 
1832     // TODO: Use dominance to find a good position instead.
1833     auto CanBeMoved = [this](CallBase &CB) {
1834       unsigned NumArgs = CB.arg_size();
1835       if (NumArgs == 0)
1836         return true;
1837       if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1838         return false;
1839       for (unsigned U = 1; U < NumArgs; ++U)
1840         if (isa<Instruction>(CB.getArgOperand(U)))
1841           return false;
1842       return true;
1843     };
1844 
1845     if (!ReplVal) {
1846       auto *DT =
1847           OMPInfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(F);
1848       if (!DT)
1849         return false;
1850       Instruction *IP = nullptr;
1851       for (Use *U : *UV) {
1852         if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1853           if (IP)
1854             IP = DT->findNearestCommonDominator(IP, CI);
1855           else
1856             IP = CI;
1857           if (!CanBeMoved(*CI))
1858             continue;
1859           if (!ReplVal)
1860             ReplVal = CI;
1861         }
1862       }
1863       if (!ReplVal)
1864         return false;
1865       assert(IP && "Expected insertion point!");
1866       cast<Instruction>(ReplVal)->moveBefore(IP);
1867     }
1868 
1869     // If we use a call as a replacement value we need to make sure the ident is
1870     // valid at the new location. For now we just pick a global one, either
1871     // existing and used by one of the calls, or created from scratch.
1872     if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1873       if (!CI->arg_empty() &&
1874           CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1875         Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1876                                                       /* GlobalOnly */ true);
1877         CI->setArgOperand(0, Ident);
1878       }
1879     }
1880 
1881     bool Changed = false;
1882     auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1883       CallInst *CI = getCallIfRegularCall(U, &RFI);
1884       if (!CI || CI == ReplVal || &F != &Caller)
1885         return false;
1886       assert(CI->getCaller() == &F && "Unexpected call!");
1887 
1888       auto Remark = [&](OptimizationRemark OR) {
1889         return OR << "OpenMP runtime call "
1890                   << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1891       };
1892       if (CI->getDebugLoc())
1893         emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1894       else
1895         emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1896 
1897       CGUpdater.removeCallSite(*CI);
1898       CI->replaceAllUsesWith(ReplVal);
1899       CI->eraseFromParent();
1900       ++NumOpenMPRuntimeCallsDeduplicated;
1901       Changed = true;
1902       return true;
1903     };
1904     RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1905 
1906     return Changed;
1907   }
1908 
1909   /// Collect arguments that represent the global thread id in \p GTIdArgs.
1910   void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1911     // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1912     //       initialization. We could define an AbstractAttribute instead and
1913     //       run the Attributor here once it can be run as an SCC pass.
1914 
1915     // Helper to check the argument \p ArgNo at all call sites of \p F for
1916     // a GTId.
1917     auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1918       if (!F.hasLocalLinkage())
1919         return false;
1920       for (Use &U : F.uses()) {
1921         if (CallInst *CI = getCallIfRegularCall(U)) {
1922           Value *ArgOp = CI->getArgOperand(ArgNo);
1923           if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1924               getCallIfRegularCall(
1925                   *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1926             continue;
1927         }
1928         return false;
1929       }
1930       return true;
1931     };
1932 
1933     // Helper to identify uses of a GTId as GTId arguments.
1934     auto AddUserArgs = [&](Value &GTId) {
1935       for (Use &U : GTId.uses())
1936         if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1937           if (CI->isArgOperand(&U))
1938             if (Function *Callee = CI->getCalledFunction())
1939               if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1940                 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1941     };
1942 
1943     // The argument users of __kmpc_global_thread_num calls are GTIds.
1944     OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1945         OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1946 
1947     GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1948       if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1949         AddUserArgs(*CI);
1950       return false;
1951     });
1952 
1953     // Transitively search for more arguments by looking at the users of the
1954     // ones we know already. During the search the GTIdArgs vector is extended
1955     // so we cannot cache the size nor can we use a range based for.
1956     for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1957       AddUserArgs(*GTIdArgs[U]);
1958   }
1959 
1960   /// Kernel (=GPU) optimizations and utility functions
1961   ///
1962   ///{{
1963 
1964   /// Cache to remember the unique kernel for a function.
1965   DenseMap<Function *, std::optional<Kernel>> UniqueKernelMap;
1966 
1967   /// Find the unique kernel that will execute \p F, if any.
1968   Kernel getUniqueKernelFor(Function &F);
1969 
1970   /// Find the unique kernel that will execute \p I, if any.
1971   Kernel getUniqueKernelFor(Instruction &I) {
1972     return getUniqueKernelFor(*I.getFunction());
1973   }
1974 
1975   /// Rewrite the device (=GPU) code state machine create in non-SPMD mode in
1976   /// the cases we can avoid taking the address of a function.
1977   bool rewriteDeviceCodeStateMachine();
1978 
1979   ///
1980   ///}}
1981 
1982   /// Emit a remark generically
1983   ///
1984   /// This template function can be used to generically emit a remark. The
1985   /// RemarkKind should be one of the following:
1986   ///   - OptimizationRemark to indicate a successful optimization attempt
1987   ///   - OptimizationRemarkMissed to report a failed optimization attempt
1988   ///   - OptimizationRemarkAnalysis to provide additional information about an
1989   ///     optimization attempt
1990   ///
1991   /// The remark is built using a callback function provided by the caller that
1992   /// takes a RemarkKind as input and returns a RemarkKind.
1993   template <typename RemarkKind, typename RemarkCallBack>
1994   void emitRemark(Instruction *I, StringRef RemarkName,
1995                   RemarkCallBack &&RemarkCB) const {
1996     Function *F = I->getParent()->getParent();
1997     auto &ORE = OREGetter(F);
1998 
1999     if (RemarkName.starts_with("OMP"))
2000       ORE.emit([&]() {
2001         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
2002                << " [" << RemarkName << "]";
2003       });
2004     else
2005       ORE.emit(
2006           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
2007   }
2008 
2009   /// Emit a remark on a function.
2010   template <typename RemarkKind, typename RemarkCallBack>
2011   void emitRemark(Function *F, StringRef RemarkName,
2012                   RemarkCallBack &&RemarkCB) const {
2013     auto &ORE = OREGetter(F);
2014 
2015     if (RemarkName.starts_with("OMP"))
2016       ORE.emit([&]() {
2017         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
2018                << " [" << RemarkName << "]";
2019       });
2020     else
2021       ORE.emit(
2022           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
2023   }
2024 
2025   /// The underlying module.
2026   Module &M;
2027 
2028   /// The SCC we are operating on.
2029   SmallVectorImpl<Function *> &SCC;
2030 
2031   /// Callback to update the call graph, the first argument is a removed call,
2032   /// the second an optional replacement call.
2033   CallGraphUpdater &CGUpdater;
2034 
2035   /// Callback to get an OptimizationRemarkEmitter from a Function *
2036   OptimizationRemarkGetter OREGetter;
2037 
2038   /// OpenMP-specific information cache. Also Used for Attributor runs.
2039   OMPInformationCache &OMPInfoCache;
2040 
2041   /// Attributor instance.
2042   Attributor &A;
2043 
2044   /// Helper function to run Attributor on SCC.
2045   bool runAttributor(bool IsModulePass) {
2046     if (SCC.empty())
2047       return false;
2048 
2049     registerAAs(IsModulePass);
2050 
2051     ChangeStatus Changed = A.run();
2052 
2053     LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
2054                       << " functions, result: " << Changed << ".\n");
2055 
2056     if (Changed == ChangeStatus::CHANGED)
2057       OMPInfoCache.invalidateAnalyses();
2058 
2059     return Changed == ChangeStatus::CHANGED;
2060   }
2061 
2062   void registerFoldRuntimeCall(RuntimeFunction RF);
2063 
2064   /// Populate the Attributor with abstract attribute opportunities in the
2065   /// functions.
2066   void registerAAs(bool IsModulePass);
2067 
2068 public:
2069   /// Callback to register AAs for live functions, including internal functions
2070   /// marked live during the traversal.
2071   static void registerAAsForFunction(Attributor &A, const Function &F);
2072 };
2073 
2074 Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
2075   if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() &&
2076       !OMPInfoCache.CGSCC->contains(&F))
2077     return nullptr;
2078 
2079   // Use a scope to keep the lifetime of the CachedKernel short.
2080   {
2081     std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
2082     if (CachedKernel)
2083       return *CachedKernel;
2084 
2085     // TODO: We should use an AA to create an (optimistic and callback
2086     //       call-aware) call graph. For now we stick to simple patterns that
2087     //       are less powerful, basically the worst fixpoint.
2088     if (isOpenMPKernel(F)) {
2089       CachedKernel = Kernel(&F);
2090       return *CachedKernel;
2091     }
2092 
2093     CachedKernel = nullptr;
2094     if (!F.hasLocalLinkage()) {
2095 
2096       // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
2097       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2098         return ORA << "Potentially unknown OpenMP target region caller.";
2099       };
2100       emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
2101 
2102       return nullptr;
2103     }
2104   }
2105 
2106   auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
2107     if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
2108       // Allow use in equality comparisons.
2109       if (Cmp->isEquality())
2110         return getUniqueKernelFor(*Cmp);
2111       return nullptr;
2112     }
2113     if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
2114       // Allow direct calls.
2115       if (CB->isCallee(&U))
2116         return getUniqueKernelFor(*CB);
2117 
2118       OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2119           OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2120       // Allow the use in __kmpc_parallel_51 calls.
2121       if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
2122         return getUniqueKernelFor(*CB);
2123       return nullptr;
2124     }
2125     // Disallow every other use.
2126     return nullptr;
2127   };
2128 
2129   // TODO: In the future we want to track more than just a unique kernel.
2130   SmallPtrSet<Kernel, 2> PotentialKernels;
2131   OMPInformationCache::foreachUse(F, [&](const Use &U) {
2132     PotentialKernels.insert(GetUniqueKernelForUse(U));
2133   });
2134 
2135   Kernel K = nullptr;
2136   if (PotentialKernels.size() == 1)
2137     K = *PotentialKernels.begin();
2138 
2139   // Cache the result.
2140   UniqueKernelMap[&F] = K;
2141 
2142   return K;
2143 }
2144 
2145 bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2146   OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2147       OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2148 
2149   bool Changed = false;
2150   if (!KernelParallelRFI)
2151     return Changed;
2152 
2153   // If we have disabled state machine changes, exit
2154   if (DisableOpenMPOptStateMachineRewrite)
2155     return Changed;
2156 
2157   for (Function *F : SCC) {
2158 
2159     // Check if the function is a use in a __kmpc_parallel_51 call at
2160     // all.
2161     bool UnknownUse = false;
2162     bool KernelParallelUse = false;
2163     unsigned NumDirectCalls = 0;
2164 
2165     SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2166     OMPInformationCache::foreachUse(*F, [&](Use &U) {
2167       if (auto *CB = dyn_cast<CallBase>(U.getUser()))
2168         if (CB->isCallee(&U)) {
2169           ++NumDirectCalls;
2170           return;
2171         }
2172 
2173       if (isa<ICmpInst>(U.getUser())) {
2174         ToBeReplacedStateMachineUses.push_back(&U);
2175         return;
2176       }
2177 
2178       // Find wrapper functions that represent parallel kernels.
2179       CallInst *CI =
2180           OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
2181       const unsigned int WrapperFunctionArgNo = 6;
2182       if (!KernelParallelUse && CI &&
2183           CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
2184         KernelParallelUse = true;
2185         ToBeReplacedStateMachineUses.push_back(&U);
2186         return;
2187       }
2188       UnknownUse = true;
2189     });
2190 
2191     // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2192     // use.
2193     if (!KernelParallelUse)
2194       continue;
2195 
2196     // If this ever hits, we should investigate.
2197     // TODO: Checking the number of uses is not a necessary restriction and
2198     // should be lifted.
2199     if (UnknownUse || NumDirectCalls != 1 ||
2200         ToBeReplacedStateMachineUses.size() > 2) {
2201       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2202         return ORA << "Parallel region is used in "
2203                    << (UnknownUse ? "unknown" : "unexpected")
2204                    << " ways. Will not attempt to rewrite the state machine.";
2205       };
2206       emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2207       continue;
2208     }
2209 
2210     // Even if we have __kmpc_parallel_51 calls, we (for now) give
2211     // up if the function is not called from a unique kernel.
2212     Kernel K = getUniqueKernelFor(*F);
2213     if (!K) {
2214       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2215         return ORA << "Parallel region is not called from a unique kernel. "
2216                       "Will not attempt to rewrite the state machine.";
2217       };
2218       emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2219       continue;
2220     }
2221 
2222     // We now know F is a parallel body function called only from the kernel K.
2223     // We also identified the state machine uses in which we replace the
2224     // function pointer by a new global symbol for identification purposes. This
2225     // ensures only direct calls to the function are left.
2226 
2227     Module &M = *F->getParent();
2228     Type *Int8Ty = Type::getInt8Ty(M.getContext());
2229 
2230     auto *ID = new GlobalVariable(
2231         M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2232         UndefValue::get(Int8Ty), F->getName() + ".ID");
2233 
2234     for (Use *U : ToBeReplacedStateMachineUses)
2235       U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2236           ID, U->get()->getType()));
2237 
2238     ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2239 
2240     Changed = true;
2241   }
2242 
2243   return Changed;
2244 }
2245 
2246 /// Abstract Attribute for tracking ICV values.
2247 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2248   using Base = StateWrapper<BooleanState, AbstractAttribute>;
2249   AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2250 
2251   /// Returns true if value is assumed to be tracked.
2252   bool isAssumedTracked() const { return getAssumed(); }
2253 
2254   /// Returns true if value is known to be tracked.
2255   bool isKnownTracked() const { return getAssumed(); }
2256 
2257   /// Create an abstract attribute biew for the position \p IRP.
2258   static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2259 
2260   /// Return the value with which \p I can be replaced for specific \p ICV.
2261   virtual std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2262                                                      const Instruction *I,
2263                                                      Attributor &A) const {
2264     return std::nullopt;
2265   }
2266 
2267   /// Return an assumed unique ICV value if a single candidate is found. If
2268   /// there cannot be one, return a nullptr. If it is not clear yet, return
2269   /// std::nullopt.
2270   virtual std::optional<Value *>
2271   getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2272 
2273   // Currently only nthreads is being tracked.
2274   // this array will only grow with time.
2275   InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2276 
2277   /// See AbstractAttribute::getName()
2278   const std::string getName() const override { return "AAICVTracker"; }
2279 
2280   /// See AbstractAttribute::getIdAddr()
2281   const char *getIdAddr() const override { return &ID; }
2282 
2283   /// This function should return true if the type of the \p AA is AAICVTracker
2284   static bool classof(const AbstractAttribute *AA) {
2285     return (AA->getIdAddr() == &ID);
2286   }
2287 
2288   static const char ID;
2289 };
2290 
2291 struct AAICVTrackerFunction : public AAICVTracker {
2292   AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2293       : AAICVTracker(IRP, A) {}
2294 
2295   // FIXME: come up with better string.
2296   const std::string getAsStr(Attributor *) const override {
2297     return "ICVTrackerFunction";
2298   }
2299 
2300   // FIXME: come up with some stats.
2301   void trackStatistics() const override {}
2302 
2303   /// We don't manifest anything for this AA.
2304   ChangeStatus manifest(Attributor &A) override {
2305     return ChangeStatus::UNCHANGED;
2306   }
2307 
2308   // Map of ICV to their values at specific program point.
2309   EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2310                   InternalControlVar::ICV___last>
2311       ICVReplacementValuesMap;
2312 
2313   ChangeStatus updateImpl(Attributor &A) override {
2314     ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2315 
2316     Function *F = getAnchorScope();
2317 
2318     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2319 
2320     for (InternalControlVar ICV : TrackableICVs) {
2321       auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2322 
2323       auto &ValuesMap = ICVReplacementValuesMap[ICV];
2324       auto TrackValues = [&](Use &U, Function &) {
2325         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2326         if (!CI)
2327           return false;
2328 
2329         // FIXME: handle setters with more that 1 arguments.
2330         /// Track new value.
2331         if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2332           HasChanged = ChangeStatus::CHANGED;
2333 
2334         return false;
2335       };
2336 
2337       auto CallCheck = [&](Instruction &I) {
2338         std::optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2339         if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2340           HasChanged = ChangeStatus::CHANGED;
2341 
2342         return true;
2343       };
2344 
2345       // Track all changes of an ICV.
2346       SetterRFI.foreachUse(TrackValues, F);
2347 
2348       bool UsedAssumedInformation = false;
2349       A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2350                                 UsedAssumedInformation,
2351                                 /* CheckBBLivenessOnly */ true);
2352 
2353       /// TODO: Figure out a way to avoid adding entry in
2354       /// ICVReplacementValuesMap
2355       Instruction *Entry = &F->getEntryBlock().front();
2356       if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2357         ValuesMap.insert(std::make_pair(Entry, nullptr));
2358     }
2359 
2360     return HasChanged;
2361   }
2362 
2363   /// Helper to check if \p I is a call and get the value for it if it is
2364   /// unique.
2365   std::optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2366                                          InternalControlVar &ICV) const {
2367 
2368     const auto *CB = dyn_cast<CallBase>(&I);
2369     if (!CB || CB->hasFnAttr("no_openmp") ||
2370         CB->hasFnAttr("no_openmp_routines"))
2371       return std::nullopt;
2372 
2373     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2374     auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2375     auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2376     Function *CalledFunction = CB->getCalledFunction();
2377 
2378     // Indirect call, assume ICV changes.
2379     if (CalledFunction == nullptr)
2380       return nullptr;
2381     if (CalledFunction == GetterRFI.Declaration)
2382       return std::nullopt;
2383     if (CalledFunction == SetterRFI.Declaration) {
2384       if (ICVReplacementValuesMap[ICV].count(&I))
2385         return ICVReplacementValuesMap[ICV].lookup(&I);
2386 
2387       return nullptr;
2388     }
2389 
2390     // Since we don't know, assume it changes the ICV.
2391     if (CalledFunction->isDeclaration())
2392       return nullptr;
2393 
2394     const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2395         *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2396 
2397     if (ICVTrackingAA->isAssumedTracked()) {
2398       std::optional<Value *> URV =
2399           ICVTrackingAA->getUniqueReplacementValue(ICV);
2400       if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),
2401                                                  OMPInfoCache)))
2402         return URV;
2403     }
2404 
2405     // If we don't know, assume it changes.
2406     return nullptr;
2407   }
2408 
2409   // We don't check unique value for a function, so return std::nullopt.
2410   std::optional<Value *>
2411   getUniqueReplacementValue(InternalControlVar ICV) const override {
2412     return std::nullopt;
2413   }
2414 
2415   /// Return the value with which \p I can be replaced for specific \p ICV.
2416   std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2417                                              const Instruction *I,
2418                                              Attributor &A) const override {
2419     const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2420     if (ValuesMap.count(I))
2421       return ValuesMap.lookup(I);
2422 
2423     SmallVector<const Instruction *, 16> Worklist;
2424     SmallPtrSet<const Instruction *, 16> Visited;
2425     Worklist.push_back(I);
2426 
2427     std::optional<Value *> ReplVal;
2428 
2429     while (!Worklist.empty()) {
2430       const Instruction *CurrInst = Worklist.pop_back_val();
2431       if (!Visited.insert(CurrInst).second)
2432         continue;
2433 
2434       const BasicBlock *CurrBB = CurrInst->getParent();
2435 
2436       // Go up and look for all potential setters/calls that might change the
2437       // ICV.
2438       while ((CurrInst = CurrInst->getPrevNode())) {
2439         if (ValuesMap.count(CurrInst)) {
2440           std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2441           // Unknown value, track new.
2442           if (!ReplVal) {
2443             ReplVal = NewReplVal;
2444             break;
2445           }
2446 
2447           // If we found a new value, we can't know the icv value anymore.
2448           if (NewReplVal)
2449             if (ReplVal != NewReplVal)
2450               return nullptr;
2451 
2452           break;
2453         }
2454 
2455         std::optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
2456         if (!NewReplVal)
2457           continue;
2458 
2459         // Unknown value, track new.
2460         if (!ReplVal) {
2461           ReplVal = NewReplVal;
2462           break;
2463         }
2464 
2465         // if (NewReplVal.hasValue())
2466         // We found a new value, we can't know the icv value anymore.
2467         if (ReplVal != NewReplVal)
2468           return nullptr;
2469       }
2470 
2471       // If we are in the same BB and we have a value, we are done.
2472       if (CurrBB == I->getParent() && ReplVal)
2473         return ReplVal;
2474 
2475       // Go through all predecessors and add terminators for analysis.
2476       for (const BasicBlock *Pred : predecessors(CurrBB))
2477         if (const Instruction *Terminator = Pred->getTerminator())
2478           Worklist.push_back(Terminator);
2479     }
2480 
2481     return ReplVal;
2482   }
2483 };
2484 
2485 struct AAICVTrackerFunctionReturned : AAICVTracker {
2486   AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2487       : AAICVTracker(IRP, A) {}
2488 
2489   // FIXME: come up with better string.
2490   const std::string getAsStr(Attributor *) const override {
2491     return "ICVTrackerFunctionReturned";
2492   }
2493 
2494   // FIXME: come up with some stats.
2495   void trackStatistics() const override {}
2496 
2497   /// We don't manifest anything for this AA.
2498   ChangeStatus manifest(Attributor &A) override {
2499     return ChangeStatus::UNCHANGED;
2500   }
2501 
2502   // Map of ICV to their values at specific program point.
2503   EnumeratedArray<std::optional<Value *>, InternalControlVar,
2504                   InternalControlVar::ICV___last>
2505       ICVReplacementValuesMap;
2506 
2507   /// Return the value with which \p I can be replaced for specific \p ICV.
2508   std::optional<Value *>
2509   getUniqueReplacementValue(InternalControlVar ICV) const override {
2510     return ICVReplacementValuesMap[ICV];
2511   }
2512 
2513   ChangeStatus updateImpl(Attributor &A) override {
2514     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2515     const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2516         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2517 
2518     if (!ICVTrackingAA->isAssumedTracked())
2519       return indicatePessimisticFixpoint();
2520 
2521     for (InternalControlVar ICV : TrackableICVs) {
2522       std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2523       std::optional<Value *> UniqueICVValue;
2524 
2525       auto CheckReturnInst = [&](Instruction &I) {
2526         std::optional<Value *> NewReplVal =
2527             ICVTrackingAA->getReplacementValue(ICV, &I, A);
2528 
2529         // If we found a second ICV value there is no unique returned value.
2530         if (UniqueICVValue && UniqueICVValue != NewReplVal)
2531           return false;
2532 
2533         UniqueICVValue = NewReplVal;
2534 
2535         return true;
2536       };
2537 
2538       bool UsedAssumedInformation = false;
2539       if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2540                                      UsedAssumedInformation,
2541                                      /* CheckBBLivenessOnly */ true))
2542         UniqueICVValue = nullptr;
2543 
2544       if (UniqueICVValue == ReplVal)
2545         continue;
2546 
2547       ReplVal = UniqueICVValue;
2548       Changed = ChangeStatus::CHANGED;
2549     }
2550 
2551     return Changed;
2552   }
2553 };
2554 
2555 struct AAICVTrackerCallSite : AAICVTracker {
2556   AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2557       : AAICVTracker(IRP, A) {}
2558 
2559   void initialize(Attributor &A) override {
2560     assert(getAnchorScope() && "Expected anchor function");
2561 
2562     // We only initialize this AA for getters, so we need to know which ICV it
2563     // gets.
2564     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2565     for (InternalControlVar ICV : TrackableICVs) {
2566       auto ICVInfo = OMPInfoCache.ICVs[ICV];
2567       auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2568       if (Getter.Declaration == getAssociatedFunction()) {
2569         AssociatedICV = ICVInfo.Kind;
2570         return;
2571       }
2572     }
2573 
2574     /// Unknown ICV.
2575     indicatePessimisticFixpoint();
2576   }
2577 
2578   ChangeStatus manifest(Attributor &A) override {
2579     if (!ReplVal || !*ReplVal)
2580       return ChangeStatus::UNCHANGED;
2581 
2582     A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal);
2583     A.deleteAfterManifest(*getCtxI());
2584 
2585     return ChangeStatus::CHANGED;
2586   }
2587 
2588   // FIXME: come up with better string.
2589   const std::string getAsStr(Attributor *) const override {
2590     return "ICVTrackerCallSite";
2591   }
2592 
2593   // FIXME: come up with some stats.
2594   void trackStatistics() const override {}
2595 
2596   InternalControlVar AssociatedICV;
2597   std::optional<Value *> ReplVal;
2598 
2599   ChangeStatus updateImpl(Attributor &A) override {
2600     const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2601         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2602 
2603     // We don't have any information, so we assume it changes the ICV.
2604     if (!ICVTrackingAA->isAssumedTracked())
2605       return indicatePessimisticFixpoint();
2606 
2607     std::optional<Value *> NewReplVal =
2608         ICVTrackingAA->getReplacementValue(AssociatedICV, getCtxI(), A);
2609 
2610     if (ReplVal == NewReplVal)
2611       return ChangeStatus::UNCHANGED;
2612 
2613     ReplVal = NewReplVal;
2614     return ChangeStatus::CHANGED;
2615   }
2616 
2617   // Return the value with which associated value can be replaced for specific
2618   // \p ICV.
2619   std::optional<Value *>
2620   getUniqueReplacementValue(InternalControlVar ICV) const override {
2621     return ReplVal;
2622   }
2623 };
2624 
2625 struct AAICVTrackerCallSiteReturned : AAICVTracker {
2626   AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2627       : AAICVTracker(IRP, A) {}
2628 
2629   // FIXME: come up with better string.
2630   const std::string getAsStr(Attributor *) const override {
2631     return "ICVTrackerCallSiteReturned";
2632   }
2633 
2634   // FIXME: come up with some stats.
2635   void trackStatistics() const override {}
2636 
2637   /// We don't manifest anything for this AA.
2638   ChangeStatus manifest(Attributor &A) override {
2639     return ChangeStatus::UNCHANGED;
2640   }
2641 
2642   // Map of ICV to their values at specific program point.
2643   EnumeratedArray<std::optional<Value *>, InternalControlVar,
2644                   InternalControlVar::ICV___last>
2645       ICVReplacementValuesMap;
2646 
2647   /// Return the value with which associated value can be replaced for specific
2648   /// \p ICV.
2649   std::optional<Value *>
2650   getUniqueReplacementValue(InternalControlVar ICV) const override {
2651     return ICVReplacementValuesMap[ICV];
2652   }
2653 
2654   ChangeStatus updateImpl(Attributor &A) override {
2655     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2656     const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2657         *this, IRPosition::returned(*getAssociatedFunction()),
2658         DepClassTy::REQUIRED);
2659 
2660     // We don't have any information, so we assume it changes the ICV.
2661     if (!ICVTrackingAA->isAssumedTracked())
2662       return indicatePessimisticFixpoint();
2663 
2664     for (InternalControlVar ICV : TrackableICVs) {
2665       std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2666       std::optional<Value *> NewReplVal =
2667           ICVTrackingAA->getUniqueReplacementValue(ICV);
2668 
2669       if (ReplVal == NewReplVal)
2670         continue;
2671 
2672       ReplVal = NewReplVal;
2673       Changed = ChangeStatus::CHANGED;
2674     }
2675     return Changed;
2676   }
2677 };
2678 
2679 /// Determines if \p BB exits the function unconditionally itself or reaches a
2680 /// block that does through only unique successors.
2681 static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) {
2682   if (succ_empty(BB))
2683     return true;
2684   const BasicBlock *const Successor = BB->getUniqueSuccessor();
2685   if (!Successor)
2686     return false;
2687   return hasFunctionEndAsUniqueSuccessor(Successor);
2688 }
2689 
/// Function-level implementation of AAExecutionDomain.
///
/// Keeps one ExecutionDomainTy per basic block (BEDMap) and two per tracked
/// call (CEDMap, before/after the call), plus the set of aligned barriers
/// found in the function. manifest() uses this information to delete
/// redundant aligned barriers.
struct AAExecutionDomainFunction : public AAExecutionDomain {
  AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
      : AAExecutionDomain(IRP, A) {}

  // RPOT is allocated in initialize() and owned by this object.
  ~AAExecutionDomainFunction() { delete RPOT; }

  /// Compute the reverse post-order traversal of the anchor function once.
  void initialize(Attributor &A) override {
    Function *F = getAnchorScope();
    assert(F && "Expected anchor function");
    RPOT = new ReversePostOrderTraversal<Function *>(F);
  }

  /// Summarize how many blocks are executed by the initial thread only and
  /// how many are both reached from and reaching aligned barriers only.
  const std::string getAsStr(Attributor *) const override {
    unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;
    for (auto &It : BEDMap) {
      // The nullptr key is not a real block (handleCallees stores
      // function-level information under it), skip it.
      if (!It.getFirst())
        continue;
      TotalBlocks++;
      InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
      AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&
                       It.getSecond().IsReachingAlignedBarrierOnly;
    }
    return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" +
           std::to_string(AlignedBlocks) + " of " +
           std::to_string(TotalBlocks) +
           " executed by initial thread / aligned";
  }

  /// See AbstractAttribute::trackStatistics().
  void trackStatistics() const override {}

  /// Delete aligned barriers (and the llvm.assume calls associated with them)
  /// that the collected execution domains show to be redundant. Does nothing
  /// if DisableOpenMPOptBarrierElimination is set.
  ChangeStatus manifest(Attributor &A) override {
    LLVM_DEBUG({
      for (const BasicBlock &BB : *getAnchorScope()) {
        if (!isExecutedByInitialThreadOnly(BB))
          continue;
        dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
               << BB.getName() << " is executed by a single thread.\n";
      }
    });

    ChangeStatus Changed = ChangeStatus::UNCHANGED;

    if (DisableOpenMPOptBarrierElimination)
      return Changed;

    // Barriers already scheduled for deletion by the per-barrier handling
    // below; consulted when handling the kernel end.
    SmallPtrSet<CallBase *, 16> DeletedBarriers;
    // \p CB is the aligned barrier to handle, or nullptr to handle the
    // "kernel end", whose execution domain is stored under BEDMap[nullptr].
    auto HandleAlignedBarrier = [&](CallBase *CB) {
      const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[nullptr];
      // Only barriers reached exclusively from aligned barriers, without
      // non-local side effects in between, are candidates.
      if (!ED.IsReachedFromAlignedBarrierOnly ||
          ED.EncounteredNonLocalSideEffect)
        return;
      // Encountered assumes would have to be deleted as well (see below);
      // that is only done when running as a module pass.
      if (!ED.EncounteredAssumes.empty() && !A.isModulePass())
        return;

      // We can remove this barrier, if it is one, or aligned barriers reaching
      // the kernel end (if CB is nullptr). Aligned barriers reaching the kernel
      // end should only be removed if the kernel end is their unique successor;
      // otherwise, they may have side-effects that aren't accounted for in the
      // kernel end in their other successors. If those barriers have other
      // barriers reaching them, those can be transitively removed as well as
      // long as the kernel end is also their unique successor.
      if (CB) {
        DeletedBarriers.insert(CB);
        A.deleteAfterManifest(*CB);
        ++NumBarriersEliminated;
        Changed = ChangeStatus::CHANGED;
      } else if (!ED.AlignedBarriers.empty()) {
        Changed = ChangeStatus::CHANGED;
        SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(),
                                         ED.AlignedBarriers.end());
        SmallSetVector<CallBase *, 16> Visited;
        while (!Worklist.empty()) {
          CallBase *LastCB = Worklist.pop_back_val();
          if (!Visited.insert(LastCB))
            continue;
          // Only barriers inside the anchor function are handled here.
          if (LastCB->getFunction() != getAnchorScope())
            continue;
          if (!hasFunctionEndAsUniqueSuccessor(LastCB->getParent()))
            continue;
          if (!DeletedBarriers.count(LastCB)) {
            ++NumBarriersEliminated;
            A.deleteAfterManifest(*LastCB);
            continue;
          }
          // The final aligned barrier (LastCB) reaching the kernel end was
          // removed already. This means we can go one step further and remove
          // the barriers encountered last before (LastCB).
          const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];
          Worklist.append(LastED.AlignedBarriers.begin(),
                          LastED.AlignedBarriers.end());
        }
      }

      // If we actually eliminated a barrier we need to eliminate the associated
      // llvm.assumes as well to avoid creating UB.
      if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
        for (auto *AssumeCB : ED.EncounteredAssumes)
          A.deleteAfterManifest(*AssumeCB);
    };

    for (auto *CB : AlignedBarriers)
      HandleAlignedBarrier(CB);

    // Handle the "kernel end barrier" for kernels too.
    if (omp::isOpenMPKernel(*getAnchorScope()))
      HandleAlignedBarrier(nullptr);

    return Changed;
  }

  /// A fence is a no-op if the state is valid and the fence was not recorded
  /// in NonNoOpFences.
  bool isNoOpFence(const FenceInst &FI) const override {
    return getState().isValidState() && !NonNoOpFences.count(&FI);
  }

  /// Merge barrier and assumption information from \p PredED into the successor
  /// \p ED.
  void
  mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED,
                                           const ExecutionDomainTy &PredED);

  /// Merge all information from \p PredED into the successor \p ED. If
  /// \p InitialEdgeOnly is set, only the initial edge will enter the block
  /// represented by \p ED from this predecessor.
  bool mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
                          const ExecutionDomainTy &PredED,
                          bool InitialEdgeOnly = false);

  /// Accumulate information for the entry block in \p EntryBBED.
  bool handleCallees(Attributor &A, ExecutionDomainTy &EntryBBED);

  /// See AbstractAttribute::updateImpl.
  ChangeStatus updateImpl(Attributor &A) override;

  /// Query interface, see AAExecutionDomain
  ///{
  bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
    if (!isValidState())
      return false;
    assert(BB.getParent() == getAnchorScope() && "Block is out of scope!");
    return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;
  }

  /// Check whether \p I is executed in an aligned region by scanning the
  /// block of \p I forward and backward from \p I. Each scan stops at the
  /// first call with a recorded execution domain in CEDMap, or at the block
  /// boundary, in which case the block-level information in BEDMap is used.
  /// Finding another aligned barrier in either direction answers true
  /// immediately.
  bool isExecutedInAlignedRegion(Attributor &A,
                                 const Instruction &I) const override {
    assert(I.getFunction() == getAnchorScope() &&
           "Instruction is out of scope!");
    if (!isValidState())
      return false;

    bool ForwardIsOk = true;
    const Instruction *CurI;

    // Check forward until a call or the block end is reached.
    // Note: `continue` in a do-while jumps to the loop condition, i.e., it
    // advances to the next instruction.
    CurI = &I;
    do {
      auto *CB = dyn_cast<CallBase>(CurI);
      if (!CB)
        continue;
      if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
        return true;
      const auto &It = CEDMap.find({CB, PRE});
      if (It == CEDMap.end())
        continue;
      if (!It->getSecond().IsReachingAlignedBarrierOnly)
        ForwardIsOk = false;
      break;
    } while ((CurI = CurI->getNextNonDebugInstruction()));

    // CurI is null if the block end was reached without hitting a tracked
    // call; fall back to the block-level information.
    if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)
      ForwardIsOk = false;

    // Check backward until a call or the block beginning is reached.
    CurI = &I;
    do {
      auto *CB = dyn_cast<CallBase>(CurI);
      if (!CB)
        continue;
      if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
        return true;
      const auto &It = CEDMap.find({CB, POST});
      if (It == CEDMap.end())
        continue;
      if (It->getSecond().IsReachedFromAlignedBarrierOnly)
        break;
      return false;
    } while ((CurI = CurI->getPrevNonDebugInstruction()));

    // Delayed decision on the forward pass to allow aligned barrier detection
    // in the backwards traversal.
    if (!ForwardIsOk)
      return false;

    // CurI is null if the block beginning was reached; consult the
    // function-level entry (for the entry block) or all predecessors.
    if (!CurI) {
      const BasicBlock *BB = I.getParent();
      if (BB == &BB->getParent()->getEntryBlock())
        return BEDMap.lookup(nullptr).IsReachedFromAlignedBarrierOnly;
      if (!llvm::all_of(predecessors(BB), [&](const BasicBlock *PredBB) {
            return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
          })) {
        return false;
      }
    }

    // On neither traversal did we find anything but aligned barriers.
    return true;
  }

  /// Return a copy of the execution domain recorded for \p BB.
  ExecutionDomainTy getExecutionDomain(const BasicBlock &BB) const override {
    assert(isValidState() &&
           "No request should be made against an invalid state!");
    return BEDMap.lookup(&BB);
  }
  /// Return the execution domains recorded before (first) and after (second)
  /// the call \p CB.
  std::pair<ExecutionDomainTy, ExecutionDomainTy>
  getExecutionDomain(const CallBase &CB) const override {
    assert(isValidState() &&
           "No request should be made against an invalid state!");
    return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})};
  }
  /// Return the function-level (inter-procedural) execution domain.
  ExecutionDomainTy getFunctionExecutionDomain() const override {
    assert(isValidState() &&
           "No request should be made against an invalid state!");
    return InterProceduralED;
  }
  ///}

  // Check if the edge into the successor block contains a condition that only
  // lets the main thread execute it.
  //
  // Only a conditional branch whose *first* successor is \p SuccessorBB and
  // whose condition is an equality comparison against a constant second
  // operand is recognized.
  static bool isInitialThreadOnlyEdge(Attributor &A, BranchInst *Edge,
                                      BasicBlock &SuccessorBB) {
    if (!Edge || !Edge->isConditional())
      return false;
    if (Edge->getSuccessor(0) != &SuccessorBB)
      return false;

    auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
    if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
      return false;

    ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
    if (!C)
      return false;

    // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
    if (C->isAllOnesValue()) {
      auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
      auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
      CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
      if (!CB)
        return false;
      // The execution mode is read from the kernel environment passed to the
      // init call; the edge only qualifies if the generic-mode bit is set.
      ConstantStruct *KernelEnvC =
          KernelInfo::getKernelEnvironementFromKernelInitCB(CB);
      ConstantInt *ExecModeC =
          KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
      return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC;
    }

    if (C->isZero()) {
      // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
      if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
        if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
          return true;

      // Match: 0 == llvm.amdgcn.workitem.id.x()
      if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
        if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
          return true;
    }

    return false;
  };

  /// Mapping containing information about the function for other AAs.
  ExecutionDomainTy InterProceduralED;

  /// Selects the execution domain before (PRE) or after (POST) a call in
  /// CEDMap.
  enum Direction { PRE = 0, POST = 1 };
  /// Mapping containing information per block. The nullptr key holds
  /// function-level information (see handleCallees and manifest).
  DenseMap<const BasicBlock *, ExecutionDomainTy> BEDMap;
  /// Mapping containing information per call and direction.
  DenseMap<PointerIntPair<const CallBase *, 1, Direction>, ExecutionDomainTy>
      CEDMap;
  /// Aligned barriers found in the function, in insertion order.
  SmallSetVector<CallBase *, 16> AlignedBarriers;

  /// Reverse post-order traversal of the anchor function; created in
  /// initialize() and deleted in the destructor.
  ReversePostOrderTraversal<Function *> *RPOT = nullptr;

  /// Set \p R to \p V and report true if that changed \p R.
  static bool setAndRecord(bool &R, bool V) {
    bool Eq = (R == V);
    R = V;
    return !Eq;
  }

  /// Collection of fences known to be non-no-op. All fences not in this set
  /// can be assumed no-op.
  SmallPtrSet<const FenceInst *, 8> NonNoOpFences;
};
2986 
2987 void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
2988     Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) {
2989   for (auto *EA : PredED.EncounteredAssumes)
2990     ED.addAssumeInst(A, *EA);
2991 
2992   for (auto *AB : PredED.AlignedBarriers)
2993     ED.addAlignedBarrier(A, *AB);
2994 }
2995 
2996 bool AAExecutionDomainFunction::mergeInPredecessor(
2997     Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,
2998     bool InitialEdgeOnly) {
2999 
3000   bool Changed = false;
3001   Changed |=
3002       setAndRecord(ED.IsExecutedByInitialThreadOnly,
3003                    InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
3004                                        ED.IsExecutedByInitialThreadOnly));
3005 
3006   Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly,
3007                           ED.IsReachedFromAlignedBarrierOnly &&
3008                               PredED.IsReachedFromAlignedBarrierOnly);
3009   Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect,
3010                           ED.EncounteredNonLocalSideEffect |
3011                               PredED.EncounteredNonLocalSideEffect);
3012   // Do not track assumptions and barriers as part of Changed.
3013   if (ED.IsReachedFromAlignedBarrierOnly)
3014     mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);
3015   else
3016     ED.clearAssumeInstAndAlignedBarriers();
3017   return Changed;
3018 }
3019 
3020 bool AAExecutionDomainFunction::handleCallees(Attributor &A,
3021                                               ExecutionDomainTy &EntryBBED) {
3022   SmallVector<std::pair<ExecutionDomainTy, ExecutionDomainTy>, 4> CallSiteEDs;
3023   auto PredForCallSite = [&](AbstractCallSite ACS) {
3024     const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3025         *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
3026         DepClassTy::OPTIONAL);
3027     if (!EDAA || !EDAA->getState().isValidState())
3028       return false;
3029     CallSiteEDs.emplace_back(
3030         EDAA->getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
3031     return true;
3032   };
3033 
3034   ExecutionDomainTy ExitED;
3035   bool AllCallSitesKnown;
3036   if (A.checkForAllCallSites(PredForCallSite, *this,
3037                              /* RequiresAllCallSites */ true,
3038                              AllCallSitesKnown)) {
3039     for (const auto &[CSInED, CSOutED] : CallSiteEDs) {
3040       mergeInPredecessor(A, EntryBBED, CSInED);
3041       ExitED.IsReachingAlignedBarrierOnly &=
3042           CSOutED.IsReachingAlignedBarrierOnly;
3043     }
3044 
3045   } else {
3046     // We could not find all predecessors, so this is either a kernel or a
3047     // function with external linkage (or with some other weird uses).
3048     if (omp::isOpenMPKernel(*getAnchorScope())) {
3049       EntryBBED.IsExecutedByInitialThreadOnly = false;
3050       EntryBBED.IsReachedFromAlignedBarrierOnly = true;
3051       EntryBBED.EncounteredNonLocalSideEffect = false;
3052       ExitED.IsReachingAlignedBarrierOnly = false;
3053     } else {
3054       EntryBBED.IsExecutedByInitialThreadOnly = false;
3055       EntryBBED.IsReachedFromAlignedBarrierOnly = false;
3056       EntryBBED.EncounteredNonLocalSideEffect = true;
3057       ExitED.IsReachingAlignedBarrierOnly = false;
3058     }
3059   }
3060 
3061   bool Changed = false;
3062   auto &FnED = BEDMap[nullptr];
3063   Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly,
3064                           FnED.IsReachedFromAlignedBarrierOnly &
3065                               EntryBBED.IsReachedFromAlignedBarrierOnly);
3066   Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly,
3067                           FnED.IsReachingAlignedBarrierOnly &
3068                               ExitED.IsReachingAlignedBarrierOnly);
3069   Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly,
3070                           EntryBBED.IsExecutedByInitialThreadOnly);
3071   return Changed;
3072 }
3073 
3074 ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
3075 
3076   bool Changed = false;
3077 
3078   // Helper to deal with an aligned barrier encountered during the forward
3079   // traversal. \p CB is the aligned barrier, \p ED is the execution domain when
3080   // it was encountered.
3081   auto HandleAlignedBarrier = [&](CallBase &CB, ExecutionDomainTy &ED) {
3082     Changed |= AlignedBarriers.insert(&CB);
3083     // First, update the barrier ED kept in the separate CEDMap.
3084     auto &CallInED = CEDMap[{&CB, PRE}];
3085     Changed |= mergeInPredecessor(A, CallInED, ED);
3086     CallInED.IsReachingAlignedBarrierOnly = true;
3087     // Next adjust the ED we use for the traversal.
3088     ED.EncounteredNonLocalSideEffect = false;
3089     ED.IsReachedFromAlignedBarrierOnly = true;
3090     // Aligned barrier collection has to come last.
3091     ED.clearAssumeInstAndAlignedBarriers();
3092     ED.addAlignedBarrier(A, CB);
3093     auto &CallOutED = CEDMap[{&CB, POST}];
3094     Changed |= mergeInPredecessor(A, CallOutED, ED);
3095   };
3096 
3097   auto *LivenessAA =
3098       A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
3099 
3100   Function *F = getAnchorScope();
3101   BasicBlock &EntryBB = F->getEntryBlock();
3102   bool IsKernel = omp::isOpenMPKernel(*F);
3103 
3104   SmallVector<Instruction *> SyncInstWorklist;
3105   for (auto &RIt : *RPOT) {
3106     BasicBlock &BB = *RIt;
3107 
3108     bool IsEntryBB = &BB == &EntryBB;
3109     // TODO: We use local reasoning since we don't have a divergence analysis
3110     // 	     running as well. We could basically allow uniform branches here.
3111     bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
3112     bool IsExplicitlyAligned = IsEntryBB && IsKernel;
3113     ExecutionDomainTy ED;
3114     // Propagate "incoming edges" into information about this block.
3115     if (IsEntryBB) {
3116       Changed |= handleCallees(A, ED);
3117     } else {
3118       // For live non-entry blocks we only propagate
3119       // information via live edges.
3120       if (LivenessAA && LivenessAA->isAssumedDead(&BB))
3121         continue;
3122 
3123       for (auto *PredBB : predecessors(&BB)) {
3124         if (LivenessAA && LivenessAA->isEdgeDead(PredBB, &BB))
3125           continue;
3126         bool InitialEdgeOnly = isInitialThreadOnlyEdge(
3127             A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB);
3128         mergeInPredecessor(A, ED, BEDMap[PredBB], InitialEdgeOnly);
3129       }
3130     }
3131 
3132     // Now we traverse the block, accumulate effects in ED and attach
3133     // information to calls.
3134     for (Instruction &I : BB) {
3135       bool UsedAssumedInformation;
3136       if (A.isAssumedDead(I, *this, LivenessAA, UsedAssumedInformation,
3137                           /* CheckBBLivenessOnly */ false, DepClassTy::OPTIONAL,
3138                           /* CheckForDeadStore */ true))
3139         continue;
3140 
3141       // Asummes and "assume-like" (dbg, lifetime, ...) are handled first, the
3142       // former is collected the latter is ignored.
3143       if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
3144         if (auto *AI = dyn_cast_or_null<AssumeInst>(II)) {
3145           ED.addAssumeInst(A, *AI);
3146           continue;
3147         }
3148         // TODO: Should we also collect and delete lifetime markers?
3149         if (II->isAssumeLikeIntrinsic())
3150           continue;
3151       }
3152 
3153       if (auto *FI = dyn_cast<FenceInst>(&I)) {
3154         if (!ED.EncounteredNonLocalSideEffect) {
3155           // An aligned fence without non-local side-effects is a no-op.
3156           if (ED.IsReachedFromAlignedBarrierOnly)
3157             continue;
3158           // A non-aligned fence without non-local side-effects is a no-op
3159           // if the ordering only publishes non-local side-effects (or less).
3160           switch (FI->getOrdering()) {
3161           case AtomicOrdering::NotAtomic:
3162             continue;
3163           case AtomicOrdering::Unordered:
3164             continue;
3165           case AtomicOrdering::Monotonic:
3166             continue;
3167           case AtomicOrdering::Acquire:
3168             break;
3169           case AtomicOrdering::Release:
3170             continue;
3171           case AtomicOrdering::AcquireRelease:
3172             break;
3173           case AtomicOrdering::SequentiallyConsistent:
3174             break;
3175           };
3176         }
3177         NonNoOpFences.insert(FI);
3178       }
3179 
3180       auto *CB = dyn_cast<CallBase>(&I);
3181       bool IsNoSync = AA::isNoSyncInst(A, I, *this);
3182       bool IsAlignedBarrier =
3183           !IsNoSync && CB &&
3184           AANoSync::isAlignedBarrier(*CB, AlignedBarrierLastInBlock);
3185 
3186       AlignedBarrierLastInBlock &= IsNoSync;
3187       IsExplicitlyAligned &= IsNoSync;
3188 
3189       // Next we check for calls. Aligned barriers are handled
3190       // explicitly, everything else is kept for the backward traversal and will
3191       // also affect our state.
3192       if (CB) {
3193         if (IsAlignedBarrier) {
3194           HandleAlignedBarrier(*CB, ED);
3195           AlignedBarrierLastInBlock = true;
3196           IsExplicitlyAligned = true;
3197           continue;
3198         }
3199 
3200         // Check the pointer(s) of a memory intrinsic explicitly.
3201         if (isa<MemIntrinsic>(&I)) {
3202           if (!ED.EncounteredNonLocalSideEffect &&
3203               AA::isPotentiallyAffectedByBarrier(A, I, *this))
3204             ED.EncounteredNonLocalSideEffect = true;
3205           if (!IsNoSync) {
3206             ED.IsReachedFromAlignedBarrierOnly = false;
3207             SyncInstWorklist.push_back(&I);
3208           }
3209           continue;
3210         }
3211 
3212         // Record how we entered the call, then accumulate the effect of the
3213         // call in ED for potential use by the callee.
3214         auto &CallInED = CEDMap[{CB, PRE}];
3215         Changed |= mergeInPredecessor(A, CallInED, ED);
3216 
3217         // If we have a sync-definition we can check if it starts/ends in an
3218         // aligned barrier. If we are unsure we assume any sync breaks
3219         // alignment.
3220         Function *Callee = CB->getCalledFunction();
3221         if (!IsNoSync && Callee && !Callee->isDeclaration()) {
3222           const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3223               *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);
3224           if (EDAA && EDAA->getState().isValidState()) {
3225             const auto &CalleeED = EDAA->getFunctionExecutionDomain();
3226             ED.IsReachedFromAlignedBarrierOnly =
3227                 CalleeED.IsReachedFromAlignedBarrierOnly;
3228             AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
3229             if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
3230               ED.EncounteredNonLocalSideEffect |=
3231                   CalleeED.EncounteredNonLocalSideEffect;
3232             else
3233               ED.EncounteredNonLocalSideEffect =
3234                   CalleeED.EncounteredNonLocalSideEffect;
3235             if (!CalleeED.IsReachingAlignedBarrierOnly) {
3236               Changed |=
3237                   setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3238               SyncInstWorklist.push_back(&I);
3239             }
3240             if (CalleeED.IsReachedFromAlignedBarrierOnly)
3241               mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED);
3242             auto &CallOutED = CEDMap[{CB, POST}];
3243             Changed |= mergeInPredecessor(A, CallOutED, ED);
3244             continue;
3245           }
3246         }
3247         if (!IsNoSync) {
3248           ED.IsReachedFromAlignedBarrierOnly = false;
3249           Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3250           SyncInstWorklist.push_back(&I);
3251         }
3252         AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
3253         ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory();
3254         auto &CallOutED = CEDMap[{CB, POST}];
3255         Changed |= mergeInPredecessor(A, CallOutED, ED);
3256       }
3257 
3258       if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())
3259         continue;
3260 
3261       // If we have a callee we try to use fine-grained information to
3262       // determine local side-effects.
3263       if (CB) {
3264         const auto *MemAA = A.getAAFor<AAMemoryLocation>(
3265             *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);
3266 
3267         auto AccessPred = [&](const Instruction *I, const Value *Ptr,
3268                               AAMemoryLocation::AccessKind,
3269                               AAMemoryLocation::MemoryLocationsKind) {
3270           return !AA::isPotentiallyAffectedByBarrier(A, {Ptr}, *this, I);
3271         };
3272         if (MemAA && MemAA->getState().isValidState() &&
3273             MemAA->checkForAllAccessesToMemoryKind(
3274                 AccessPred, AAMemoryLocation::ALL_LOCATIONS))
3275           continue;
3276       }
3277 
3278       auto &InfoCache = A.getInfoCache();
3279       if (!I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(I))
3280         continue;
3281 
3282       if (auto *LI = dyn_cast<LoadInst>(&I))
3283         if (LI->hasMetadata(LLVMContext::MD_invariant_load))
3284           continue;
3285 
3286       if (!ED.EncounteredNonLocalSideEffect &&
3287           AA::isPotentiallyAffectedByBarrier(A, I, *this))
3288         ED.EncounteredNonLocalSideEffect = true;
3289     }
3290 
3291     bool IsEndAndNotReachingAlignedBarriersOnly = false;
3292     if (!isa<UnreachableInst>(BB.getTerminator()) &&
3293         !BB.getTerminator()->getNumSuccessors()) {
3294 
3295       Changed |= mergeInPredecessor(A, InterProceduralED, ED);
3296 
3297       auto &FnED = BEDMap[nullptr];
3298       if (IsKernel && !IsExplicitlyAligned)
3299         FnED.IsReachingAlignedBarrierOnly = false;
3300       Changed |= mergeInPredecessor(A, FnED, ED);
3301 
3302       if (!FnED.IsReachingAlignedBarrierOnly) {
3303         IsEndAndNotReachingAlignedBarriersOnly = true;
3304         SyncInstWorklist.push_back(BB.getTerminator());
3305         auto &BBED = BEDMap[&BB];
3306         Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly, false);
3307       }
3308     }
3309 
3310     ExecutionDomainTy &StoredED = BEDMap[&BB];
3311     ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &
3312                                       !IsEndAndNotReachingAlignedBarriersOnly;
3313 
3314     // Check if we computed anything different as part of the forward
3315     // traversal. We do not take assumptions and aligned barriers into account
3316     // as they do not influence the state we iterate. Backward traversal values
3317     // are handled later on.
3318     if (ED.IsExecutedByInitialThreadOnly !=
3319             StoredED.IsExecutedByInitialThreadOnly ||
3320         ED.IsReachedFromAlignedBarrierOnly !=
3321             StoredED.IsReachedFromAlignedBarrierOnly ||
3322         ED.EncounteredNonLocalSideEffect !=
3323             StoredED.EncounteredNonLocalSideEffect)
3324       Changed = true;
3325 
3326     // Update the state with the new value.
3327     StoredED = std::move(ED);
3328   }
3329 
3330   // Propagate (non-aligned) sync instruction effects backwards until the
3331   // entry is hit or an aligned barrier.
3332   SmallSetVector<BasicBlock *, 16> Visited;
3333   while (!SyncInstWorklist.empty()) {
3334     Instruction *SyncInst = SyncInstWorklist.pop_back_val();
3335     Instruction *CurInst = SyncInst;
3336     bool HitAlignedBarrierOrKnownEnd = false;
3337     while ((CurInst = CurInst->getPrevNode())) {
3338       auto *CB = dyn_cast<CallBase>(CurInst);
3339       if (!CB)
3340         continue;
3341       auto &CallOutED = CEDMap[{CB, POST}];
3342       Changed |= setAndRecord(CallOutED.IsReachingAlignedBarrierOnly, false);
3343       auto &CallInED = CEDMap[{CB, PRE}];
3344       HitAlignedBarrierOrKnownEnd =
3345           AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly;
3346       if (HitAlignedBarrierOrKnownEnd)
3347         break;
3348       Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3349     }
3350     if (HitAlignedBarrierOrKnownEnd)
3351       continue;
3352     BasicBlock *SyncBB = SyncInst->getParent();
3353     for (auto *PredBB : predecessors(SyncBB)) {
3354       if (LivenessAA && LivenessAA->isEdgeDead(PredBB, SyncBB))
3355         continue;
3356       if (!Visited.insert(PredBB))
3357         continue;
3358       auto &PredED = BEDMap[PredBB];
3359       if (setAndRecord(PredED.IsReachingAlignedBarrierOnly, false)) {
3360         Changed = true;
3361         SyncInstWorklist.push_back(PredBB->getTerminator());
3362       }
3363     }
3364     if (SyncBB != &EntryBB)
3365       continue;
3366     Changed |=
3367         setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly, false);
3368   }
3369 
3370   return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
3371 }
3372 
3373 /// Try to replace memory allocation calls called by a single thread with a
3374 /// static buffer of shared memory.
3375 struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
3376   using Base = StateWrapper<BooleanState, AbstractAttribute>;
3377   AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3378 
3379   /// Create an abstract attribute view for the position \p IRP.
3380   static AAHeapToShared &createForPosition(const IRPosition &IRP,
3381                                            Attributor &A);
3382 
3383   /// Returns true if HeapToShared conversion is assumed to be possible.
3384   virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
3385 
3386   /// Returns true if HeapToShared conversion is assumed and the CB is a
3387   /// callsite to a free operation to be removed.
3388   virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
3389 
3390   /// See AbstractAttribute::getName().
3391   const std::string getName() const override { return "AAHeapToShared"; }
3392 
3393   /// See AbstractAttribute::getIdAddr().
3394   const char *getIdAddr() const override { return &ID; }
3395 
3396   /// This function should return true if the type of the \p AA is
3397   /// AAHeapToShared.
3398   static bool classof(const AbstractAttribute *AA) {
3399     return (AA->getIdAddr() == &ID);
3400   }
3401 
3402   /// Unique ID (due to the unique address)
3403   static const char ID;
3404 };
3405 
3406 struct AAHeapToSharedFunction : public AAHeapToShared {
3407   AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
3408       : AAHeapToShared(IRP, A) {}
3409 
3410   const std::string getAsStr(Attributor *) const override {
3411     return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
3412            " malloc calls eligible.";
3413   }
3414 
3415   /// See AbstractAttribute::trackStatistics().
3416   void trackStatistics() const override {}
3417 
3418   /// This functions finds free calls that will be removed by the
3419   /// HeapToShared transformation.
3420   void findPotentialRemovedFreeCalls(Attributor &A) {
3421     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3422     auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3423 
3424     PotentialRemovedFreeCalls.clear();
3425     // Update free call users of found malloc calls.
3426     for (CallBase *CB : MallocCalls) {
3427       SmallVector<CallBase *, 4> FreeCalls;
3428       for (auto *U : CB->users()) {
3429         CallBase *C = dyn_cast<CallBase>(U);
3430         if (C && C->getCalledFunction() == FreeRFI.Declaration)
3431           FreeCalls.push_back(C);
3432       }
3433 
3434       if (FreeCalls.size() != 1)
3435         continue;
3436 
3437       PotentialRemovedFreeCalls.insert(FreeCalls.front());
3438     }
3439   }
3440 
3441   void initialize(Attributor &A) override {
3442     if (DisableOpenMPOptDeglobalization) {
3443       indicatePessimisticFixpoint();
3444       return;
3445     }
3446 
3447     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3448     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3449     if (!RFI.Declaration)
3450       return;
3451 
3452     Attributor::SimplifictionCallbackTy SCB =
3453         [](const IRPosition &, const AbstractAttribute *,
3454            bool &) -> std::optional<Value *> { return nullptr; };
3455 
3456     Function *F = getAnchorScope();
3457     for (User *U : RFI.Declaration->users())
3458       if (CallBase *CB = dyn_cast<CallBase>(U)) {
3459         if (CB->getFunction() != F)
3460           continue;
3461         MallocCalls.insert(CB);
3462         A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),
3463                                          SCB);
3464       }
3465 
3466     findPotentialRemovedFreeCalls(A);
3467   }
3468 
3469   bool isAssumedHeapToShared(CallBase &CB) const override {
3470     return isValidState() && MallocCalls.count(&CB);
3471   }
3472 
3473   bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
3474     return isValidState() && PotentialRemovedFreeCalls.count(&CB);
3475   }
3476 
3477   ChangeStatus manifest(Attributor &A) override {
3478     if (MallocCalls.empty())
3479       return ChangeStatus::UNCHANGED;
3480 
3481     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3482     auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3483 
3484     Function *F = getAnchorScope();
3485     auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
3486                                             DepClassTy::OPTIONAL);
3487 
3488     ChangeStatus Changed = ChangeStatus::UNCHANGED;
3489     for (CallBase *CB : MallocCalls) {
3490       // Skip replacing this if HeapToStack has already claimed it.
3491       if (HS && HS->isAssumedHeapToStack(*CB))
3492         continue;
3493 
3494       // Find the unique free call to remove it.
3495       SmallVector<CallBase *, 4> FreeCalls;
3496       for (auto *U : CB->users()) {
3497         CallBase *C = dyn_cast<CallBase>(U);
3498         if (C && C->getCalledFunction() == FreeCall.Declaration)
3499           FreeCalls.push_back(C);
3500       }
3501       if (FreeCalls.size() != 1)
3502         continue;
3503 
3504       auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));
3505 
3506       if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
3507         LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
3508                           << " with shared memory."
3509                           << " Shared memory usage is limited to "
3510                           << SharedMemoryLimit << " bytes\n");
3511         continue;
3512       }
3513 
3514       LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
3515                         << " with " << AllocSize->getZExtValue()
3516                         << " bytes of shared memory\n");
3517 
3518       // Create a new shared memory buffer of the same size as the allocation
3519       // and replace all the uses of the original allocation with it.
3520       Module *M = CB->getModule();
3521       Type *Int8Ty = Type::getInt8Ty(M->getContext());
3522       Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
3523       auto *SharedMem = new GlobalVariable(
3524           *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
3525           PoisonValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
3526           GlobalValue::NotThreadLocal,
3527           static_cast<unsigned>(AddressSpace::Shared));
3528       auto *NewBuffer =
3529           ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
3530 
3531       auto Remark = [&](OptimizationRemark OR) {
3532         return OR << "Replaced globalized variable with "
3533                   << ore::NV("SharedMemory", AllocSize->getZExtValue())
3534                   << (AllocSize->isOne() ? " byte " : " bytes ")
3535                   << "of shared memory.";
3536       };
3537       A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
3538 
3539       MaybeAlign Alignment = CB->getRetAlign();
3540       assert(Alignment &&
3541              "HeapToShared on allocation without alignment attribute");
3542       SharedMem->setAlignment(*Alignment);
3543 
3544       A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);
3545       A.deleteAfterManifest(*CB);
3546       A.deleteAfterManifest(*FreeCalls.front());
3547 
3548       SharedMemoryUsed += AllocSize->getZExtValue();
3549       NumBytesMovedToSharedMemory = SharedMemoryUsed;
3550       Changed = ChangeStatus::CHANGED;
3551     }
3552 
3553     return Changed;
3554   }
3555 
3556   ChangeStatus updateImpl(Attributor &A) override {
3557     if (MallocCalls.empty())
3558       return indicatePessimisticFixpoint();
3559     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3560     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3561     if (!RFI.Declaration)
3562       return ChangeStatus::UNCHANGED;
3563 
3564     Function *F = getAnchorScope();
3565 
3566     auto NumMallocCalls = MallocCalls.size();
3567 
3568     // Only consider malloc calls executed by a single thread with a constant.
3569     for (User *U : RFI.Declaration->users()) {
3570       if (CallBase *CB = dyn_cast<CallBase>(U)) {
3571         if (CB->getCaller() != F)
3572           continue;
3573         if (!MallocCalls.count(CB))
3574           continue;
3575         if (!isa<ConstantInt>(CB->getArgOperand(0))) {
3576           MallocCalls.remove(CB);
3577           continue;
3578         }
3579         const auto *ED = A.getAAFor<AAExecutionDomain>(
3580             *this, IRPosition::function(*F), DepClassTy::REQUIRED);
3581         if (!ED || !ED->isExecutedByInitialThreadOnly(*CB))
3582           MallocCalls.remove(CB);
3583       }
3584     }
3585 
3586     findPotentialRemovedFreeCalls(A);
3587 
3588     if (NumMallocCalls != MallocCalls.size())
3589       return ChangeStatus::CHANGED;
3590 
3591     return ChangeStatus::UNCHANGED;
3592   }
3593 
3594   /// Collection of all malloc calls in a function.
3595   SmallSetVector<CallBase *, 4> MallocCalls;
3596   /// Collection of potentially removed free calls in a function.
3597   SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
3598   /// The total amount of shared memory that has been used for HeapToShared.
3599   unsigned SharedMemoryUsed = 0;
3600 };
3601 
3602 struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
3603   using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
3604   AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3605 
3606   /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
3607   /// unknown callees.
3608   static bool requiresCalleeForCallBase() { return false; }
3609 
3610   /// Statistics are tracked as part of manifest for now.
3611   void trackStatistics() const override {}
3612 
3613   /// See AbstractAttribute::getAsStr()
3614   const std::string getAsStr(Attributor *) const override {
3615     if (!isValidState())
3616       return "<invalid>";
3617     return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
3618                                                             : "generic") +
3619            std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
3620                                                                : "") +
3621            std::string(" #PRs: ") +
3622            (ReachedKnownParallelRegions.isValidState()
3623                 ? std::to_string(ReachedKnownParallelRegions.size())
3624                 : "<invalid>") +
3625            ", #Unknown PRs: " +
3626            (ReachedUnknownParallelRegions.isValidState()
3627                 ? std::to_string(ReachedUnknownParallelRegions.size())
3628                 : "<invalid>") +
3629            ", #Reaching Kernels: " +
3630            (ReachingKernelEntries.isValidState()
3631                 ? std::to_string(ReachingKernelEntries.size())
3632                 : "<invalid>") +
3633            ", #ParLevels: " +
3634            (ParallelLevels.isValidState()
3635                 ? std::to_string(ParallelLevels.size())
3636                 : "<invalid>") +
3637            ", NestedPar: " + (NestedParallelism ? "yes" : "no");
3638   }
3639 
3640   /// Create an abstract attribute biew for the position \p IRP.
3641   static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
3642 
3643   /// See AbstractAttribute::getName()
3644   const std::string getName() const override { return "AAKernelInfo"; }
3645 
3646   /// See AbstractAttribute::getIdAddr()
3647   const char *getIdAddr() const override { return &ID; }
3648 
3649   /// This function should return true if the type of the \p AA is AAKernelInfo
3650   static bool classof(const AbstractAttribute *AA) {
3651     return (AA->getIdAddr() == &ID);
3652   }
3653 
3654   static const char ID;
3655 };
3656 
3657 /// The function kernel info abstract attribute, basically, what can we say
3658 /// about a function with regards to the KernelInfoState.
3659 struct AAKernelInfoFunction : AAKernelInfo {
3660   AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
3661       : AAKernelInfo(IRP, A) {}
3662 
3663   SmallPtrSet<Instruction *, 4> GuardedInstructions;
3664 
3665   SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
3666     return GuardedInstructions;
3667   }
3668 
3669   void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) {
3670     Constant *NewKernelEnvC = ConstantFoldInsertValueInstruction(
3671         KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx});
3672     assert(NewKernelEnvC && "Failed to create new kernel environment");
3673     KernelEnvC = cast<ConstantStruct>(NewKernelEnvC);
3674   }
3675 
3676 #define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER)                        \
3677   void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) {                 \
3678     ConstantStruct *ConfigC =                                                  \
3679         KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC);         \
3680     Constant *NewConfigC = ConstantFoldInsertValueInstruction(                 \
3681         ConfigC, NewVal, {KernelInfo::MEMBER##Idx});                           \
3682     assert(NewConfigC && "Failed to create new configuration environment");    \
3683     setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC));     \
3684   }
3685 
3686   KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(UseGenericStateMachine)
3687   KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MayUseNestedParallelism)
3688   KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode)
3689   KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinThreads)
3690   KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxThreads)
3691   KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinTeams)
3692   KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxTeams)
3693 
3694 #undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER
3695 
3696   /// See AbstractAttribute::initialize(...).
3697   void initialize(Attributor &A) override {
3698     // This is a high-level transform that might change the constant arguments
3699     // of the init and deinit calls. We need to tell the Attributor about this
3700     // to avoid other parts using the current constant value for simplification.
3701     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3702 
3703     Function *Fn = getAnchorScope();
3704 
3705     OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3706         OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3707     OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3708         OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3709 
3710     // For kernels we perform more initialization work, first we find the init
3711     // and deinit calls.
3712     auto StoreCallBase = [](Use &U,
3713                             OMPInformationCache::RuntimeFunctionInfo &RFI,
3714                             CallBase *&Storage) {
3715       CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3716       assert(CB &&
3717              "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3718       assert(!Storage &&
3719              "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3720       Storage = CB;
3721       return false;
3722     };
3723     InitRFI.foreachUse(
3724         [&](Use &U, Function &) {
3725           StoreCallBase(U, InitRFI, KernelInitCB);
3726           return false;
3727         },
3728         Fn);
3729     DeinitRFI.foreachUse(
3730         [&](Use &U, Function &) {
3731           StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3732           return false;
3733         },
3734         Fn);
3735 
3736     // Ignore kernels without initializers such as global constructors.
3737     if (!KernelInitCB || !KernelDeinitCB)
3738       return;
3739 
3740     // Add itself to the reaching kernel and set IsKernelEntry.
3741     ReachingKernelEntries.insert(Fn);
3742     IsKernelEntry = true;
3743 
3744     KernelEnvC =
3745         KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3746     GlobalVariable *KernelEnvGV =
3747         KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3748 
3749     Attributor::GlobalVariableSimplifictionCallbackTy
3750         KernelConfigurationSimplifyCB =
3751             [&](const GlobalVariable &GV, const AbstractAttribute *AA,
3752                 bool &UsedAssumedInformation) -> std::optional<Constant *> {
3753       if (!isAtFixpoint()) {
3754         if (!AA)
3755           return nullptr;
3756         UsedAssumedInformation = true;
3757         A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3758       }
3759       return KernelEnvC;
3760     };
3761 
3762     A.registerGlobalVariableSimplificationCallback(
3763         *KernelEnvGV, KernelConfigurationSimplifyCB);
3764 
3765     // Check if we know we are in SPMD-mode already.
3766     ConstantInt *ExecModeC =
3767         KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
3768     ConstantInt *AssumedExecModeC = ConstantInt::get(
3769         ExecModeC->getIntegerType(),
3770         ExecModeC->getSExtValue() | OMP_TGT_EXEC_MODE_GENERIC_SPMD);
3771     if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
3772       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3773     else if (DisableOpenMPOptSPMDization)
3774       // This is a generic region but SPMDization is disabled so stop
3775       // tracking.
3776       SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3777     else
3778       setExecModeOfKernelEnvironment(AssumedExecModeC);
3779 
3780     const Triple T(Fn->getParent()->getTargetTriple());
3781     auto *Int32Ty = Type::getInt32Ty(Fn->getContext());
3782     auto [MinThreads, MaxThreads] =
3783         OpenMPIRBuilder::readThreadBoundsForKernel(T, *Fn);
3784     if (MinThreads)
3785       setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads));
3786     if (MaxThreads)
3787       setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads));
3788     auto [MinTeams, MaxTeams] =
3789         OpenMPIRBuilder::readTeamBoundsForKernel(T, *Fn);
3790     if (MinTeams)
3791       setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams));
3792     if (MaxTeams)
3793       setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams));
3794 
3795     ConstantInt *MayUseNestedParallelismC =
3796         KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
3797     ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
3798         MayUseNestedParallelismC->getIntegerType(), NestedParallelism);
3799     setMayUseNestedParallelismOfKernelEnvironment(
3800         AssumedMayUseNestedParallelismC);
3801 
3802     if (!DisableOpenMPOptStateMachineRewrite) {
3803       ConstantInt *UseGenericStateMachineC =
3804           KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3805               KernelEnvC);
3806       ConstantInt *AssumedUseGenericStateMachineC =
3807           ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);
3808       setUseGenericStateMachineOfKernelEnvironment(
3809           AssumedUseGenericStateMachineC);
3810     }
3811 
3812     // Register virtual uses of functions we might need to preserve.
3813     auto RegisterVirtualUse = [&](RuntimeFunction RFKind,
3814                                   Attributor::VirtualUseCallbackTy &CB) {
3815       if (!OMPInfoCache.RFIs[RFKind].Declaration)
3816         return;
3817       A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
3818     };
3819 
3820     // Add a dependence to ensure updates if the state changes.
3821     auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
3822                             const AbstractAttribute *QueryingAA) {
3823       if (QueryingAA) {
3824         A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
3825       }
3826       return true;
3827     };
3828 
3829     Attributor::VirtualUseCallbackTy CustomStateMachineUseCB =
3830         [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3831           // Whenever we create a custom state machine we will insert calls to
3832           // __kmpc_get_hardware_num_threads_in_block,
3833           // __kmpc_get_warp_size,
3834           // __kmpc_barrier_simple_generic,
3835           // __kmpc_kernel_parallel, and
3836           // __kmpc_kernel_end_parallel.
3837           // Not needed if we are on track for SPMDzation.
3838           if (SPMDCompatibilityTracker.isValidState())
3839             return AddDependence(A, this, QueryingAA);
3840           // Not needed if we can't rewrite due to an invalid state.
3841           if (!ReachedKnownParallelRegions.isValidState())
3842             return AddDependence(A, this, QueryingAA);
3843           return false;
3844         };
3845 
3846     // Not needed if we are pre-runtime merge.
3847     if (!KernelInitCB->getCalledFunction()->isDeclaration()) {
3848       RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
3849                          CustomStateMachineUseCB);
3850       RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
3851       RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
3852                          CustomStateMachineUseCB);
3853       RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
3854                          CustomStateMachineUseCB);
3855       RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
3856                          CustomStateMachineUseCB);
3857     }
3858 
3859     // If we do not perform SPMDzation we do not need the virtual uses below.
3860     if (SPMDCompatibilityTracker.isAtFixpoint())
3861       return;
3862 
3863     Attributor::VirtualUseCallbackTy HWThreadIdUseCB =
3864         [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3865           // Whenever we perform SPMDzation we will insert
3866           // __kmpc_get_hardware_thread_id_in_block calls.
3867           if (!SPMDCompatibilityTracker.isValidState())
3868             return AddDependence(A, this, QueryingAA);
3869           return false;
3870         };
3871     RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
3872                        HWThreadIdUseCB);
3873 
3874     Attributor::VirtualUseCallbackTy SPMDBarrierUseCB =
3875         [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3876           // Whenever we perform SPMDzation with guarding we will insert
3877           // __kmpc_simple_barrier_spmd calls. If SPMDzation failed, there is
3878           // nothing to guard, or there are no parallel regions, we don't need
3879           // the calls.
3880           if (!SPMDCompatibilityTracker.isValidState())
3881             return AddDependence(A, this, QueryingAA);
3882           if (SPMDCompatibilityTracker.empty())
3883             return AddDependence(A, this, QueryingAA);
3884           if (!mayContainParallelRegion())
3885             return AddDependence(A, this, QueryingAA);
3886           return false;
3887         };
3888     RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
3889   }
3890 
3891   /// Sanitize the string \p S such that it is a suitable global symbol name.
3892   static std::string sanitizeForGlobalName(std::string S) {
3893     std::replace_if(
3894         S.begin(), S.end(),
3895         [](const char C) {
3896           return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3897                    (C >= '0' && C <= '9') || C == '_');
3898         },
3899         '.');
3900     return S;
3901   }
3902 
3903   /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3904   /// finished now.
3905   ChangeStatus manifest(Attributor &A) override {
3906     // If we are not looking at a kernel with __kmpc_target_init and
3907     // __kmpc_target_deinit call we cannot actually manifest the information.
3908     if (!KernelInitCB || !KernelDeinitCB)
3909       return ChangeStatus::UNCHANGED;
3910 
3911     ChangeStatus Changed = ChangeStatus::UNCHANGED;
3912 
3913     bool HasBuiltStateMachine = true;
3914     if (!changeToSPMDMode(A, Changed)) {
3915       if (!KernelInitCB->getCalledFunction()->isDeclaration())
3916         HasBuiltStateMachine = buildCustomStateMachine(A, Changed);
3917       else
3918         HasBuiltStateMachine = false;
3919     }
3920 
3921     // We need to reset KernelEnvC if specific rewriting is not done.
3922     ConstantStruct *ExistingKernelEnvC =
3923         KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3924     ConstantInt *OldUseGenericStateMachineVal =
3925         KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3926             ExistingKernelEnvC);
3927     if (!HasBuiltStateMachine)
3928       setUseGenericStateMachineOfKernelEnvironment(
3929           OldUseGenericStateMachineVal);
3930 
3931     // At last, update the KernelEnvc
3932     GlobalVariable *KernelEnvGV =
3933         KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3934     if (KernelEnvGV->getInitializer() != KernelEnvC) {
3935       KernelEnvGV->setInitializer(KernelEnvC);
3936       Changed = ChangeStatus::CHANGED;
3937     }
3938 
3939     return Changed;
3940   }
3941 
3942   void insertInstructionGuardsHelper(Attributor &A) {
3943     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3944 
3945     auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3946                                    Instruction *RegionEndI) {
3947       LoopInfo *LI = nullptr;
3948       DominatorTree *DT = nullptr;
3949       MemorySSAUpdater *MSU = nullptr;
3950       using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3951 
3952       BasicBlock *ParentBB = RegionStartI->getParent();
3953       Function *Fn = ParentBB->getParent();
3954       Module &M = *Fn->getParent();
3955 
3956       // Create all the blocks and logic.
3957       // ParentBB:
3958       //    goto RegionCheckTidBB
3959       // RegionCheckTidBB:
3960       //    Tid = __kmpc_hardware_thread_id()
3961       //    if (Tid != 0)
3962       //        goto RegionBarrierBB
3963       // RegionStartBB:
3964       //    <execute instructions guarded>
3965       //    goto RegionEndBB
3966       // RegionEndBB:
3967       //    <store escaping values to shared mem>
3968       //    goto RegionBarrierBB
3969       //  RegionBarrierBB:
3970       //    __kmpc_simple_barrier_spmd()
3971       //    // second barrier is omitted if lacking escaping values.
3972       //    <load escaping values from shared mem>
3973       //    __kmpc_simple_barrier_spmd()
3974       //    goto RegionExitBB
3975       // RegionExitBB:
3976       //    <execute rest of instructions>
3977 
3978       BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3979                                            DT, LI, MSU, "region.guarded.end");
3980       BasicBlock *RegionBarrierBB =
3981           SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3982                      MSU, "region.barrier");
3983       BasicBlock *RegionExitBB =
3984           SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3985                      DT, LI, MSU, "region.exit");
3986       BasicBlock *RegionStartBB =
3987           SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3988 
3989       assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3990              "Expected a different CFG");
3991 
3992       BasicBlock *RegionCheckTidBB = SplitBlock(
3993           ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3994 
3995       // Register basic blocks with the Attributor.
3996       A.registerManifestAddedBasicBlock(*RegionEndBB);
3997       A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3998       A.registerManifestAddedBasicBlock(*RegionExitBB);
3999       A.registerManifestAddedBasicBlock(*RegionStartBB);
4000       A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
4001 
4002       bool HasBroadcastValues = false;
4003       // Find escaping outputs from the guarded region to outside users and
4004       // broadcast their values to them.
4005       for (Instruction &I : *RegionStartBB) {
4006         SmallVector<Use *, 4> OutsideUses;
4007         for (Use &U : I.uses()) {
4008           Instruction &UsrI = *cast<Instruction>(U.getUser());
4009           if (UsrI.getParent() != RegionStartBB)
4010             OutsideUses.push_back(&U);
4011         }
4012 
4013         if (OutsideUses.empty())
4014           continue;
4015 
4016         HasBroadcastValues = true;
4017 
4018         // Emit a global variable in shared memory to store the broadcasted
4019         // value.
4020         auto *SharedMem = new GlobalVariable(
4021             M, I.getType(), /* IsConstant */ false,
4022             GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
4023             sanitizeForGlobalName(
4024                 (I.getName() + ".guarded.output.alloc").str()),
4025             nullptr, GlobalValue::NotThreadLocal,
4026             static_cast<unsigned>(AddressSpace::Shared));
4027 
4028         // Emit a store instruction to update the value.
4029         new StoreInst(&I, SharedMem, RegionEndBB->getTerminator());
4030 
4031         LoadInst *LoadI = new LoadInst(I.getType(), SharedMem,
4032                                        I.getName() + ".guarded.output.load",
4033                                        RegionBarrierBB->getTerminator());
4034 
4035         // Emit a load instruction and replace uses of the output value.
4036         for (Use *U : OutsideUses)
4037           A.changeUseAfterManifest(*U, *LoadI);
4038       }
4039 
4040       auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4041 
4042       // Go to tid check BB in ParentBB.
4043       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
4044       ParentBB->getTerminator()->eraseFromParent();
4045       OpenMPIRBuilder::LocationDescription Loc(
4046           InsertPointTy(ParentBB, ParentBB->end()), DL);
4047       OMPInfoCache.OMPBuilder.updateToLocation(Loc);
4048       uint32_t SrcLocStrSize;
4049       auto *SrcLocStr =
4050           OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4051       Value *Ident =
4052           OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4053       BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
4054 
4055       // Add check for Tid in RegionCheckTidBB
4056       RegionCheckTidBB->getTerminator()->eraseFromParent();
4057       OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
4058           InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
4059       OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
4060       FunctionCallee HardwareTidFn =
4061           OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4062               M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4063       CallInst *Tid =
4064           OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
4065       Tid->setDebugLoc(DL);
4066       OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
4067       Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
4068       OMPInfoCache.OMPBuilder.Builder
4069           .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
4070           ->setDebugLoc(DL);
4071 
4072       // First barrier for synchronization, ensures main thread has updated
4073       // values.
4074       FunctionCallee BarrierFn =
4075           OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4076               M, OMPRTL___kmpc_barrier_simple_spmd);
4077       OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
4078           RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
4079       CallInst *Barrier =
4080           OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
4081       Barrier->setDebugLoc(DL);
4082       OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4083 
4084       // Second barrier ensures workers have read broadcast values.
4085       if (HasBroadcastValues) {
4086         CallInst *Barrier = CallInst::Create(BarrierFn, {Ident, Tid}, "",
4087                                              RegionBarrierBB->getTerminator());
4088         Barrier->setDebugLoc(DL);
4089         OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4090       }
4091     };
4092 
4093     auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4094     SmallPtrSet<BasicBlock *, 8> Visited;
4095     for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4096       BasicBlock *BB = GuardedI->getParent();
4097       if (!Visited.insert(BB).second)
4098         continue;
4099 
4100       SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
4101       Instruction *LastEffect = nullptr;
4102       BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
4103       while (++IP != IPEnd) {
4104         if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
4105           continue;
4106         Instruction *I = &*IP;
4107         if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
4108           continue;
4109         if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
4110           LastEffect = nullptr;
4111           continue;
4112         }
4113         if (LastEffect)
4114           Reorders.push_back({I, LastEffect});
4115         LastEffect = &*IP;
4116       }
4117       for (auto &Reorder : Reorders)
4118         Reorder.first->moveBefore(Reorder.second);
4119     }
4120 
4121     SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
4122 
4123     for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4124       BasicBlock *BB = GuardedI->getParent();
4125       auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
4126           IRPosition::function(*GuardedI->getFunction()), nullptr,
4127           DepClassTy::NONE);
4128       assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
4129       auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
4130       // Continue if instruction is already guarded.
4131       if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
4132         continue;
4133 
4134       Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
4135       for (Instruction &I : *BB) {
4136         // If instruction I needs to be guarded update the guarded region
4137         // bounds.
4138         if (SPMDCompatibilityTracker.contains(&I)) {
4139           CalleeAAFunction.getGuardedInstructions().insert(&I);
4140           if (GuardedRegionStart)
4141             GuardedRegionEnd = &I;
4142           else
4143             GuardedRegionStart = GuardedRegionEnd = &I;
4144 
4145           continue;
4146         }
4147 
4148         // Instruction I does not need guarding, store
4149         // any region found and reset bounds.
4150         if (GuardedRegionStart) {
4151           GuardedRegions.push_back(
4152               std::make_pair(GuardedRegionStart, GuardedRegionEnd));
4153           GuardedRegionStart = nullptr;
4154           GuardedRegionEnd = nullptr;
4155         }
4156       }
4157     }
4158 
4159     for (auto &GR : GuardedRegions)
4160       CreateGuardedRegion(GR.first, GR.second);
4161   }
4162 
4163   void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
4164     // Only allow 1 thread per workgroup to continue executing the user code.
4165     //
4166     //     InitCB = __kmpc_target_init(...)
4167     //     ThreadIdInBlock = __kmpc_get_hardware_thread_id_in_block();
4168     //     if (ThreadIdInBlock != 0) return;
4169     // UserCode:
4170     //     // user code
4171     //
4172     auto &Ctx = getAnchorValue().getContext();
4173     Function *Kernel = getAssociatedFunction();
4174     assert(Kernel && "Expected an associated function!");
4175 
4176     // Create block for user code to branch to from initial block.
4177     BasicBlock *InitBB = KernelInitCB->getParent();
4178     BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
4179         KernelInitCB->getNextNode(), "main.thread.user_code");
4180     BasicBlock *ReturnBB =
4181         BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB);
4182 
4183     // Register blocks with attributor:
4184     A.registerManifestAddedBasicBlock(*InitBB);
4185     A.registerManifestAddedBasicBlock(*UserCodeBB);
4186     A.registerManifestAddedBasicBlock(*ReturnBB);
4187 
4188     // Debug location:
4189     const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4190     ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc);
4191     InitBB->getTerminator()->eraseFromParent();
4192 
4193     // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block.
4194     Module &M = *Kernel->getParent();
4195     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4196     FunctionCallee ThreadIdInBlockFn =
4197         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4198             M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4199 
4200     // Get thread ID in block.
4201     CallInst *ThreadIdInBlock =
4202         CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB);
4203     OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
4204     ThreadIdInBlock->setDebugLoc(DLoc);
4205 
4206     // Eliminate all threads in the block with ID not equal to 0:
4207     Instruction *IsMainThread =
4208         ICmpInst::Create(ICmpInst::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock,
4209                          ConstantInt::get(ThreadIdInBlock->getType(), 0),
4210                          "thread.is_main", InitBB);
4211     IsMainThread->setDebugLoc(DLoc);
4212     BranchInst::Create(ReturnBB, UserCodeBB, IsMainThread, InitBB);
4213   }
4214 
4215   bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
4216     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4217 
4218     // We cannot change to SPMD mode if the runtime functions aren't availible.
4219     if (!OMPInfoCache.runtimeFnsAvailable(
4220             {OMPRTL___kmpc_get_hardware_thread_id_in_block,
4221              OMPRTL___kmpc_barrier_simple_spmd}))
4222       return false;
4223 
4224     if (!SPMDCompatibilityTracker.isAssumed()) {
4225       for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
4226         if (!NonCompatibleI)
4227           continue;
4228 
4229         // Skip diagnostics on calls to known OpenMP runtime functions for now.
4230         if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
4231           if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
4232             continue;
4233 
4234         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4235           ORA << "Value has potential side effects preventing SPMD-mode "
4236                  "execution";
4237           if (isa<CallBase>(NonCompatibleI)) {
4238             ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
4239                    "the called function to override";
4240           }
4241           return ORA << ".";
4242         };
4243         A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
4244                                                  Remark);
4245 
4246         LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
4247                           << *NonCompatibleI << "\n");
4248       }
4249 
4250       return false;
4251     }
4252 
4253     // Get the actual kernel, could be the caller of the anchor scope if we have
4254     // a debug wrapper.
4255     Function *Kernel = getAnchorScope();
4256     if (Kernel->hasLocalLinkage()) {
4257       assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
4258       auto *CB = cast<CallBase>(Kernel->user_back());
4259       Kernel = CB->getCaller();
4260     }
4261     assert(omp::isOpenMPKernel(*Kernel) && "Expected kernel function!");
4262 
4263     // Check if the kernel is already in SPMD mode, if so, return success.
4264     ConstantStruct *ExistingKernelEnvC =
4265         KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4266     auto *ExecModeC =
4267         KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4268     const int8_t ExecModeVal = ExecModeC->getSExtValue();
4269     if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
4270       return true;
4271 
4272     // We will now unconditionally modify the IR, indicate a change.
4273     Changed = ChangeStatus::CHANGED;
4274 
4275     // Do not use instruction guards when no parallel is present inside
4276     // the target region.
4277     if (mayContainParallelRegion())
4278       insertInstructionGuardsHelper(A);
4279     else
4280       forceSingleThreadPerWorkgroupHelper(A);
4281 
4282     // Adjust the global exec mode flag that tells the runtime what mode this
4283     // kernel is executed in.
4284     assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
4285            "Initially non-SPMD kernel has SPMD exec mode!");
4286     setExecModeOfKernelEnvironment(
4287         ConstantInt::get(ExecModeC->getIntegerType(),
4288                          ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
4289 
4290     ++NumOpenMPTargetRegionKernelsSPMD;
4291 
4292     auto Remark = [&](OptimizationRemark OR) {
4293       return OR << "Transformed generic-mode kernel to SPMD-mode.";
4294     };
4295     A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
4296     return true;
4297   };
4298 
4299   bool buildCustomStateMachine(Attributor &A, ChangeStatus &Changed) {
4300     // If we have disabled state machine rewrites, don't make a custom one
4301     if (DisableOpenMPOptStateMachineRewrite)
4302       return false;
4303 
4304     // Don't rewrite the state machine if we are not in a valid state.
4305     if (!ReachedKnownParallelRegions.isValidState())
4306       return false;
4307 
4308     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4309     if (!OMPInfoCache.runtimeFnsAvailable(
4310             {OMPRTL___kmpc_get_hardware_num_threads_in_block,
4311              OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
4312              OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
4313       return false;
4314 
4315     ConstantStruct *ExistingKernelEnvC =
4316         KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4317 
4318     // Check if the current configuration is non-SPMD and generic state machine.
4319     // If we already have SPMD mode or a custom state machine we do not need to
4320     // go any further. If it is anything but a constant something is weird and
4321     // we give up.
4322     ConstantInt *UseStateMachineC =
4323         KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4324             ExistingKernelEnvC);
4325     ConstantInt *ModeC =
4326         KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4327 
4328     // If we are stuck with generic mode, try to create a custom device (=GPU)
4329     // state machine which is specialized for the parallel regions that are
4330     // reachable by the kernel.
4331     if (UseStateMachineC->isZero() ||
4332         (ModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
4333       return false;
4334 
4335     Changed = ChangeStatus::CHANGED;
4336 
4337     // If not SPMD mode, indicate we use a custom state machine now.
4338     setUseGenericStateMachineOfKernelEnvironment(
4339         ConstantInt::get(UseStateMachineC->getIntegerType(), false));
4340 
4341     // If we don't actually need a state machine we are done here. This can
4342     // happen if there simply are no parallel regions. In the resulting kernel
4343     // all worker threads will simply exit right away, leaving the main thread
4344     // to do the work alone.
4345     if (!mayContainParallelRegion()) {
4346       ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
4347 
4348       auto Remark = [&](OptimizationRemark OR) {
4349         return OR << "Removing unused state machine from generic-mode kernel.";
4350       };
4351       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
4352 
4353       return true;
4354     }
4355 
4356     // Keep track in the statistics of our new shiny custom state machine.
4357     if (ReachedUnknownParallelRegions.empty()) {
4358       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
4359 
4360       auto Remark = [&](OptimizationRemark OR) {
4361         return OR << "Rewriting generic-mode kernel with a customized state "
4362                      "machine.";
4363       };
4364       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
4365     } else {
4366       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
4367 
4368       auto Remark = [&](OptimizationRemarkAnalysis OR) {
4369         return OR << "Generic-mode kernel is executed with a customized state "
4370                      "machine that requires a fallback.";
4371       };
4372       A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
4373 
4374       // Tell the user why we ended up with a fallback.
4375       for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
4376         if (!UnknownParallelRegionCB)
4377           continue;
4378         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4379           return ORA << "Call may contain unknown parallel regions. Use "
4380                      << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
4381                         "override.";
4382         };
4383         A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
4384                                                  "OMP133", Remark);
4385       }
4386     }
4387 
4388     // Create all the blocks:
4389     //
4390     //                       InitCB = __kmpc_target_init(...)
4391     //                       BlockHwSize =
4392     //                         __kmpc_get_hardware_num_threads_in_block();
4393     //                       WarpSize = __kmpc_get_warp_size();
4394     //                       BlockSize = BlockHwSize - WarpSize;
4395     // IsWorkerCheckBB:      bool IsWorker = InitCB != -1;
4396     //                       if (IsWorker) {
4397     //                         if (InitCB >= BlockSize) return;
4398     // SMBeginBB:               __kmpc_barrier_simple_generic(...);
4399     //                         void *WorkFn;
4400     //                         bool Active = __kmpc_kernel_parallel(&WorkFn);
4401     //                         if (!WorkFn) return;
4402     // SMIsActiveCheckBB:       if (Active) {
4403     // SMIfCascadeCurrentBB:      if      (WorkFn == <ParFn0>)
4404     //                              ParFn0(...);
4405     // SMIfCascadeCurrentBB:      else if (WorkFn == <ParFn1>)
4406     //                              ParFn1(...);
4407     //                            ...
4408     // SMIfCascadeCurrentBB:      else
4409     //                              ((WorkFnTy*)WorkFn)(...);
4410     // SMEndParallelBB:           __kmpc_kernel_end_parallel(...);
4411     //                          }
4412     // SMDoneBB:                __kmpc_barrier_simple_generic(...);
4413     //                          goto SMBeginBB;
4414     //                       }
4415     // UserCodeEntryBB:      // user code
4416     //                       __kmpc_target_deinit(...)
4417     //
4418     auto &Ctx = getAnchorValue().getContext();
4419     Function *Kernel = getAssociatedFunction();
4420     assert(Kernel && "Expected an associated function!");
4421 
4422     BasicBlock *InitBB = KernelInitCB->getParent();
4423     BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
4424         KernelInitCB->getNextNode(), "thread.user_code.check");
4425     BasicBlock *IsWorkerCheckBB =
4426         BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
4427     BasicBlock *StateMachineBeginBB = BasicBlock::Create(
4428         Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
4429     BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
4430         Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
4431     BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
4432         Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
4433     BasicBlock *StateMachineIfCascadeCurrentBB =
4434         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4435                            Kernel, UserCodeEntryBB);
4436     BasicBlock *StateMachineEndParallelBB =
4437         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
4438                            Kernel, UserCodeEntryBB);
4439     BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
4440         Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
4441     A.registerManifestAddedBasicBlock(*InitBB);
4442     A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
4443     A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
4444     A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
4445     A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
4446     A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
4447     A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
4448     A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
4449     A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
4450 
4451     const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4452     ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
4453     InitBB->getTerminator()->eraseFromParent();
4454 
4455     Instruction *IsWorker =
4456         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
4457                          ConstantInt::get(KernelInitCB->getType(), -1),
4458                          "thread.is_worker", InitBB);
4459     IsWorker->setDebugLoc(DLoc);
4460     BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
4461 
4462     Module &M = *Kernel->getParent();
4463     FunctionCallee BlockHwSizeFn =
4464         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4465             M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
4466     FunctionCallee WarpSizeFn =
4467         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4468             M, OMPRTL___kmpc_get_warp_size);
4469     CallInst *BlockHwSize =
4470         CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
4471     OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
4472     BlockHwSize->setDebugLoc(DLoc);
4473     CallInst *WarpSize =
4474         CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
4475     OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
4476     WarpSize->setDebugLoc(DLoc);
4477     Instruction *BlockSize = BinaryOperator::CreateSub(
4478         BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
4479     BlockSize->setDebugLoc(DLoc);
4480     Instruction *IsMainOrWorker = ICmpInst::Create(
4481         ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
4482         "thread.is_main_or_worker", IsWorkerCheckBB);
4483     IsMainOrWorker->setDebugLoc(DLoc);
4484     BranchInst::Create(StateMachineBeginBB, StateMachineFinishedBB,
4485                        IsMainOrWorker, IsWorkerCheckBB);
4486 
4487     // Create local storage for the work function pointer.
4488     const DataLayout &DL = M.getDataLayout();
4489     Type *VoidPtrTy = PointerType::getUnqual(Ctx);
4490     Instruction *WorkFnAI =
4491         new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
4492                        "worker.work_fn.addr", &Kernel->getEntryBlock().front());
4493     WorkFnAI->setDebugLoc(DLoc);
4494 
4495     OMPInfoCache.OMPBuilder.updateToLocation(
4496         OpenMPIRBuilder::LocationDescription(
4497             IRBuilder<>::InsertPoint(StateMachineBeginBB,
4498                                      StateMachineBeginBB->end()),
4499             DLoc));
4500 
4501     Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC);
4502     Value *GTid = KernelInitCB;
4503 
4504     FunctionCallee BarrierFn =
4505         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4506             M, OMPRTL___kmpc_barrier_simple_generic);
4507     CallInst *Barrier =
4508         CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
4509     OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4510     Barrier->setDebugLoc(DLoc);
4511 
4512     if (WorkFnAI->getType()->getPointerAddressSpace() !=
4513         (unsigned int)AddressSpace::Generic) {
4514       WorkFnAI = new AddrSpaceCastInst(
4515           WorkFnAI, PointerType::get(Ctx, (unsigned int)AddressSpace::Generic),
4516           WorkFnAI->getName() + ".generic", StateMachineBeginBB);
4517       WorkFnAI->setDebugLoc(DLoc);
4518     }
4519 
4520     FunctionCallee KernelParallelFn =
4521         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4522             M, OMPRTL___kmpc_kernel_parallel);
4523     CallInst *IsActiveWorker = CallInst::Create(
4524         KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
4525     OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
4526     IsActiveWorker->setDebugLoc(DLoc);
4527     Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
4528                                        StateMachineBeginBB);
4529     WorkFn->setDebugLoc(DLoc);
4530 
4531     FunctionType *ParallelRegionFnTy = FunctionType::get(
4532         Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
4533         false);
4534 
4535     Instruction *IsDone =
4536         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
4537                          Constant::getNullValue(VoidPtrTy), "worker.is_done",
4538                          StateMachineBeginBB);
4539     IsDone->setDebugLoc(DLoc);
4540     BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
4541                        IsDone, StateMachineBeginBB)
4542         ->setDebugLoc(DLoc);
4543 
4544     BranchInst::Create(StateMachineIfCascadeCurrentBB,
4545                        StateMachineDoneBarrierBB, IsActiveWorker,
4546                        StateMachineIsActiveCheckBB)
4547         ->setDebugLoc(DLoc);
4548 
4549     Value *ZeroArg =
4550         Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
4551 
4552     const unsigned int WrapperFunctionArgNo = 6;
4553 
4554     // Now that we have most of the CFG skeleton it is time for the if-cascade
4555     // that checks the function pointer we got from the runtime against the
4556     // parallel regions we expect, if there are any.
4557     for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
4558       auto *CB = ReachedKnownParallelRegions[I];
4559       auto *ParallelRegion = dyn_cast<Function>(
4560           CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());
4561       BasicBlock *PRExecuteBB = BasicBlock::Create(
4562           Ctx, "worker_state_machine.parallel_region.execute", Kernel,
4563           StateMachineEndParallelBB);
4564       CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
4565           ->setDebugLoc(DLoc);
4566       BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
4567           ->setDebugLoc(DLoc);
4568 
4569       BasicBlock *PRNextBB =
4570           BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4571                              Kernel, StateMachineEndParallelBB);
4572       A.registerManifestAddedBasicBlock(*PRExecuteBB);
4573       A.registerManifestAddedBasicBlock(*PRNextBB);
4574 
4575       // Check if we need to compare the pointer at all or if we can just
4576       // call the parallel region function.
4577       Value *IsPR;
4578       if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
4579         Instruction *CmpI = ICmpInst::Create(
4580             ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, ParallelRegion,
4581             "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
4582         CmpI->setDebugLoc(DLoc);
4583         IsPR = CmpI;
4584       } else {
4585         IsPR = ConstantInt::getTrue(Ctx);
4586       }
4587 
4588       BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
4589                          StateMachineIfCascadeCurrentBB)
4590           ->setDebugLoc(DLoc);
4591       StateMachineIfCascadeCurrentBB = PRNextBB;
4592     }
4593 
4594     // At the end of the if-cascade we place the indirect function pointer call
4595     // in case we might need it, that is if there can be parallel regions we
4596     // have not handled in the if-cascade above.
4597     if (!ReachedUnknownParallelRegions.empty()) {
4598       StateMachineIfCascadeCurrentBB->setName(
4599           "worker_state_machine.parallel_region.fallback.execute");
4600       CallInst::Create(ParallelRegionFnTy, WorkFn, {ZeroArg, GTid}, "",
4601                        StateMachineIfCascadeCurrentBB)
4602           ->setDebugLoc(DLoc);
4603     }
4604     BranchInst::Create(StateMachineEndParallelBB,
4605                        StateMachineIfCascadeCurrentBB)
4606         ->setDebugLoc(DLoc);
4607 
4608     FunctionCallee EndParallelFn =
4609         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4610             M, OMPRTL___kmpc_kernel_end_parallel);
4611     CallInst *EndParallel =
4612         CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
4613     OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4614     EndParallel->setDebugLoc(DLoc);
4615     BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
4616         ->setDebugLoc(DLoc);
4617 
4618     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
4619         ->setDebugLoc(DLoc);
4620     BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
4621         ->setDebugLoc(DLoc);
4622 
4623     return true;
4624   }
4625 
4626   /// Fixpoint iteration update function. Will be called every time a dependence
4627   /// changed its state (and in the beginning).
       ///
       /// Records in SPMDCompatibilityTracker the non-call instructions that may
       /// write memory, except stores whose underlying objects are all assumed
       /// thread local or assumed to be moved by AAHeapToStack. Combines the
       /// AAKernelInfo state of every call-like instruction into this state and,
       /// for functions that are not kernel entries, refreshes ParallelLevels and
       /// ReachingKernelEntries from the call sites.
       ///
       /// \returns ChangeStatus::CHANGED if the state differs from the state
       /// observed on entry and ChangeStatus::UNCHANGED otherwise. If not all
       /// call-like instructions could be visited, the result of
       /// indicatePessimisticFixpoint() is returned instead.
4628   ChangeStatus updateImpl(Attributor &A) override {
         // Snapshot of the state on entry; compared against the final state below.
4629     KernelInfoState StateBefore = getState();
4630
4631     // When we leave this function this RAII will make sure the member
4632     // KernelEnvC is updated properly depending on the state. That member is
4633     // used for simplification of values and needs to be up to date at all
4634     // times.
4635     struct UpdateKernelEnvCRAII {
4636       AAKernelInfoFunction &AA;
4637
4638       UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {}
4639
4640       ~UpdateKernelEnvCRAII() {
             // Without a kernel environment constant there is nothing to update.
4641         if (!AA.KernelEnvC)
4642           return;
4643
             // The environment as obtained from the kernel init call.
4644         ConstantStruct *ExistingKernelEnvC =
4645             KernelInfo::getKernelEnvironementFromKernelInitCB(AA.KernelInitCB);
4646
             // If the overall state is invalid, reset KernelEnvC to that existing
             // environment and do not touch anything else.
4647         if (!AA.isValidState()) {
4648           AA.KernelEnvC = ExistingKernelEnvC;
4649           return;
4650         }
4651
             // If the known parallel regions are invalid, take the
             // use-generic-state-machine value from the existing environment.
4652         if (!AA.ReachedKnownParallelRegions.isValidState())
4653           AA.setUseGenericStateMachineOfKernelEnvironment(
4654               KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4655                   ExistingKernelEnvC));
4656
             // If the SPMD tracker is invalid, take the execution mode from the
             // existing environment.
4657         if (!AA.SPMDCompatibilityTracker.isValidState())
4658           AA.setExecModeOfKernelEnvironment(
4659               KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC));
4660
             // Rebuild the may-use-nested-parallelism constant from the current
             // NestedParallelism value, reusing the integer type of the constant
             // currently stored in KernelEnvC.
4661         ConstantInt *MayUseNestedParallelismC =
4662             KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
4663                 AA.KernelEnvC);
4664         ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
4665             MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);
4666         AA.setMayUseNestedParallelismOfKernelEnvironment(
4667             NewMayUseNestedParallelismC);
4668       }
4669     } RAII(*this);
4670
4671     // Callback to check a read/write instruction.
         // It always returns true; instructions that need guarding are recorded
         // in the SPMDCompatibilityTracker instead of aborting the traversal.
4672     auto CheckRWInst = [&](Instruction &I) {
4673       // We handle calls later.
4674       if (isa<CallBase>(I))
4675         return true;
4676       // We only care about write effects.
4677       if (!I.mayWriteToMemory())
4678         return true;
           // A store needs no guarding if every underlying object of its pointer
           // operand is assumed thread local or is a call that AAHeapToStack
           // assumes to convert.
4679       if (auto *SI = dyn_cast<StoreInst>(&I)) {
4680         const auto *UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
4681             *this, IRPosition::value(*SI->getPointerOperand()),
4682             DepClassTy::OPTIONAL);
4683         auto *HS = A.getAAFor<AAHeapToStack>(
4684             *this, IRPosition::function(*I.getFunction()),
4685             DepClassTy::OPTIONAL);
4686         if (UnderlyingObjsAA &&
4687             UnderlyingObjsAA->forallUnderlyingObjects([&](Value &Obj) {
4688               if (AA::isAssumedThreadLocalObject(A, Obj, *this))
4689                 return true;
4690               // Check for AAHeapToStack moved objects which must not be
4691               // guarded.
4692               auto *CB = dyn_cast<CallBase>(&Obj);
4693               return CB && HS && HS->isAssumedHeapToStack(*CB);
4694             }))
4695           return true;
4696       }
4697
4698       // Insert instruction that needs guarding.
4699       SPMDCompatibilityTracker.insert(&I);
4700       return true;
4701     };
4702
         // Only scan the read/write instructions while the SPMD tracker is not
         // yet at a fixpoint; a failed traversal makes it pessimistic.
4703     bool UsedAssumedInformationInCheckRWInst = false;
4704     if (!SPMDCompatibilityTracker.isAtFixpoint())
4705       if (!A.checkForAllReadWriteInstructions(
4706               CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4707         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4708
         // Functions that are not kernel entries additionally track their
         // parallel levels and the kernel entries that reach them.
4709     bool UsedAssumedInformationFromReachingKernels = false;
4710     if (!IsKernelEntry) {
4711       updateParallelLevels(A);
4712
4713       bool AllReachingKernelsKnown = true;
4714       updateReachingKernelEntries(A, AllReachingKernelsKnown);
4715       UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4716
           // The checks below only matter if something was recorded in the
           // SPMD tracker.
4717       if (!SPMDCompatibilityTracker.empty()) {
4718         if (!ParallelLevels.isValidState())
4719           SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4720         else if (!ReachingKernelEntries.isValidState())
4721           SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4722         else {
4723           // Check if all reaching kernels agree on the mode as we can otherwise
4724           // not guard instructions. We might not be sure about the mode so we
4725           // cannot fix the internal spmd-zation state either.
               // A kernel counts as SPMD if its AAKernelInfo exists and its SPMD
               // tracker is valid and assumed; every other kernel counts as
               // generic.
4726           int SPMD = 0, Generic = 0;
4727           for (auto *Kernel : ReachingKernelEntries) {
4728             auto *CBAA = A.getAAFor<AAKernelInfo>(
4729                 *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
4730             if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
4731                 CBAA->SPMDCompatibilityTracker.isAssumed())
4732               ++SPMD;
4733             else
4734               ++Generic;
                 // A missing or not yet fixed kernel state means the counts rely
                 // on assumed information.
4735             if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint())
4736               UsedAssumedInformationFromReachingKernels = true;
4737           }
               // Mixed modes among the reaching kernels: give up on SPMD.
4738           if (SPMD != 0 && Generic != 0)
4739             SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4740         }
4741       }
4742     }
4743
4744     // Callback to check a call instruction.
         // Combines the call site state into this state (KernelInfoState's
         // operator^=) and records whether the SPMD and parallel region parts of
         // every visited call site state were already at a fixpoint. Returns
         // false if no AAKernelInfo is available for the call site.
4745     bool AllParallelRegionStatesWereFixed = true;
4746     bool AllSPMDStatesWereFixed = true;
4747     auto CheckCallInst = [&](Instruction &I) {
4748       auto &CB = cast<CallBase>(I);
4749       auto *CBAA = A.getAAFor<AAKernelInfo>(
4750           *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4751       if (!CBAA)
4752         return false;
4753       getState() ^= CBAA->getState();
4754       AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint();
4755       AllParallelRegionStatesWereFixed &=
4756           CBAA->ReachedKnownParallelRegions.isAtFixpoint();
4757       AllParallelRegionStatesWereFixed &=
4758           CBAA->ReachedUnknownParallelRegions.isAtFixpoint();
4759       return true;
4760     };
4761
4762     bool UsedAssumedInformationInCheckCallInst = false;
4763     if (!A.checkForAllCallLikeInstructions(
4764             CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4765       LLVM_DEBUG(dbgs() << TAG
4766                         << "Failed to visit all call-like instructions!\n";);
4767       return indicatePessimisticFixpoint();
4768     }
4769
4770     // If we haven't used any assumed information for the reached parallel
4771     // region states we can fix it.
4772     if (!UsedAssumedInformationInCheckCallInst &&
4773         AllParallelRegionStatesWereFixed) {
4774       ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4775       ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4776     }
4777
4778     // If we haven't used any assumed information for the SPMD state we can fix
4779     // it.
4780     if (!UsedAssumedInformationInCheckRWInst &&
4781         !UsedAssumedInformationInCheckCallInst &&
4782         !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4783       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4784
         // Report whether anything changed compared to the snapshot taken on entry.
4785     return StateBefore == getState() ? ChangeStatus::UNCHANGED
4786                                      : ChangeStatus::CHANGED;
4787   }
4788 
4789 private:
4790   /// Update info regarding reaching kernels.
4791   void updateReachingKernelEntries(Attributor &A,
4792                                    bool &AllReachingKernelsKnown) {
4793     auto PredCallSite = [&](AbstractCallSite ACS) {
4794       Function *Caller = ACS.getInstruction()->getFunction();
4795 
4796       assert(Caller && "Caller is nullptr");
4797 
4798       auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
4799           IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
4800       if (CAA && CAA->ReachingKernelEntries.isValidState()) {
4801         ReachingKernelEntries ^= CAA->ReachingKernelEntries;
4802         return true;
4803       }
4804 
4805       // We lost track of the caller of the associated function, any kernel
4806       // could reach now.
4807       ReachingKernelEntries.indicatePessimisticFixpoint();
4808 
4809       return true;
4810     };
4811 
4812     if (!A.checkForAllCallSites(PredCallSite, *this,
4813                                 true /* RequireAllCallSites */,
4814                                 AllReachingKernelsKnown))
4815       ReachingKernelEntries.indicatePessimisticFixpoint();
4816   }
4817 
4818   /// Update info regarding parallel levels.
4819   void updateParallelLevels(Attributor &A) {
4820     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4821     OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
4822         OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
4823 
4824     auto PredCallSite = [&](AbstractCallSite ACS) {
4825       Function *Caller = ACS.getInstruction()->getFunction();
4826 
4827       assert(Caller && "Caller is nullptr");
4828 
4829       auto *CAA =
4830           A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
4831       if (CAA && CAA->ParallelLevels.isValidState()) {
4832         // Any function that is called by `__kmpc_parallel_51` will not be
4833         // folded as the parallel level in the function is updated. In order to
4834         // get it right, all the analysis would depend on the implentation. That
4835         // said, if in the future any change to the implementation, the analysis
4836         // could be wrong. As a consequence, we are just conservative here.
4837         if (Caller == Parallel51RFI.Declaration) {
4838           ParallelLevels.indicatePessimisticFixpoint();
4839           return true;
4840         }
4841 
4842         ParallelLevels ^= CAA->ParallelLevels;
4843 
4844         return true;
4845       }
4846 
4847       // We lost track of the caller of the associated function, any kernel
4848       // could reach now.
4849       ParallelLevels.indicatePessimisticFixpoint();
4850 
4851       return true;
4852     };
4853 
4854     bool AllCallSitesKnown = true;
4855     if (!A.checkForAllCallSites(PredCallSite, *this,
4856                                 true /* RequireAllCallSites */,
4857                                 AllCallSitesKnown))
4858       ParallelLevels.indicatePessimisticFixpoint();
4859   }
4860 };
4861 
4862 /// The call site kernel info abstract attribute, basically, what can we say
4863 /// about a call site with regards to the KernelInfoState. For now this simply
4864 /// forwards the information from the callee.
     ///
     /// Known OpenMP runtime functions are modeled explicitly in initialize(...);
     /// `__kmpc_parallel_51`, `__kmpc_alloc_shared` and `__kmpc_free_shared` are
     /// (re)evaluated in updateImpl(...). For callees that are not runtime
     /// functions, updateImpl(...) copies the AAKernelInfo state of the callee.
4865 struct AAKernelInfoCallSite : AAKernelInfo {
4866   AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
4867       : AAKernelInfo(IRP, A) {}
4868
4869   /// See AbstractAttribute::initialize(...).
4870   void initialize(Attributor &A) override {
4871     AAKernelInfo::initialize(A);
4872
4873     CallBase &CB = cast<CallBase>(getAssociatedValue());
4874     auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
4875         *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4876
4877     // Check for SPMD-mode assumptions.
4878     if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) {
4879       indicateOptimisticFixpoint();
4880       return;
4881     }
4882
4883     // First weed out calls we do not care about, that is calls that do not
4884     // write memory and intrinsics. Neither of these can reach a
4885     // parallel region or anything else we are looking for.
         // (The "omp_no_openmp" / "omp_no_parallelism" assumptions are taken into
         // account for unknown callees in CheckCallee below.)
4886     if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
4887       indicateOptimisticFixpoint();
4888       return;
4889     }
4890
4891     // Next we check if we know the callee. If it is a known OpenMP function
4892     // we will handle them explicitly in the switch below. If it is not, we
4893     // will use an AAKernelInfo object on the callee to gather information and
4894     // merge that into the current state. The latter happens in the updateImpl.
         //
         // \p Callee may be null. \p NumCallees is the number of potential callees
         // of this call site; a known runtime function among more than one
         // potential callee results in a pessimistic fixpoint.
4895     auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
4896       auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4897       const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4898       if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4899         // Unknown callees or functions that are not IPO amendable are not
             // analyzable, we give up.
4900         if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
4901
4902           // Unknown callees might contain parallel regions, except if they have
4903           // an appropriate assumption attached.
4904           if (!AssumptionAA ||
4905               !(AssumptionAA->hasAssumption("omp_no_openmp") ||
4906                 AssumptionAA->hasAssumption("omp_no_parallelism")))
4907             ReachedUnknownParallelRegions.insert(&CB);
4908
4909           // If SPMDCompatibilityTracker is not fixed, we need to give up on the
4910           // idea we can run something unknown in SPMD-mode.
4911           if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4912             SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4913             SPMDCompatibilityTracker.insert(&CB);
4914           }
4915
4916           // We have updated the state for this unknown call properly, there
4917           // won't be any change so we indicate a fixpoint.
4918           indicateOptimisticFixpoint();
4919         }
4920         // If the callee is known and can be used in IPO, we will update the
4921         // state based on the callee state in updateImpl.
4922         return;
4923       }
           // A runtime function as one of several potential callees is not
           // modeled, give up.
4924       if (NumCallees > 1) {
4925         indicatePessimisticFixpoint();
4926         return;
4927       }
4928
4929       RuntimeFunction RF = It->getSecond();
4930       switch (RF) {
4931       // All the functions we know are compatible with SPMD mode.
4932       case OMPRTL___kmpc_is_spmd_exec_mode:
4933       case OMPRTL___kmpc_distribute_static_fini:
4934       case OMPRTL___kmpc_for_static_fini:
4935       case OMPRTL___kmpc_global_thread_num:
4936       case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4937       case OMPRTL___kmpc_get_hardware_num_blocks:
4938       case OMPRTL___kmpc_single:
4939       case OMPRTL___kmpc_end_single:
4940       case OMPRTL___kmpc_master:
4941       case OMPRTL___kmpc_end_master:
4942       case OMPRTL___kmpc_barrier:
4943       case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4944       case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4945       case OMPRTL___kmpc_error:
4946       case OMPRTL___kmpc_flush:
4947       case OMPRTL___kmpc_get_hardware_thread_id_in_block:
4948       case OMPRTL___kmpc_get_warp_size:
4949       case OMPRTL_omp_get_thread_num:
4950       case OMPRTL_omp_get_num_threads:
4951       case OMPRTL_omp_get_max_threads:
4952       case OMPRTL_omp_in_parallel:
4953       case OMPRTL_omp_get_dynamic:
4954       case OMPRTL_omp_get_cancellation:
4955       case OMPRTL_omp_get_nested:
4956       case OMPRTL_omp_get_schedule:
4957       case OMPRTL_omp_get_thread_limit:
4958       case OMPRTL_omp_get_supported_active_levels:
4959       case OMPRTL_omp_get_max_active_levels:
4960       case OMPRTL_omp_get_level:
4961       case OMPRTL_omp_get_ancestor_thread_num:
4962       case OMPRTL_omp_get_team_size:
4963       case OMPRTL_omp_get_active_level:
4964       case OMPRTL_omp_in_final:
4965       case OMPRTL_omp_get_proc_bind:
4966       case OMPRTL_omp_get_num_places:
4967       case OMPRTL_omp_get_num_procs:
4968       case OMPRTL_omp_get_place_proc_ids:
4969       case OMPRTL_omp_get_place_num:
4970       case OMPRTL_omp_get_partition_num_places:
4971       case OMPRTL_omp_get_partition_place_nums:
4972       case OMPRTL_omp_get_wtime:
4973         break;
4974       case OMPRTL___kmpc_distribute_static_init_4:
4975       case OMPRTL___kmpc_distribute_static_init_4u:
4976       case OMPRTL___kmpc_distribute_static_init_8:
4977       case OMPRTL___kmpc_distribute_static_init_8u:
4978       case OMPRTL___kmpc_for_static_init_4:
4979       case OMPRTL___kmpc_for_static_init_4u:
4980       case OMPRTL___kmpc_for_static_init_8:
4981       case OMPRTL___kmpc_for_static_init_8u: {
4982         // Check the schedule and allow static schedule in SPMD mode.
             // The schedule is read from call operand 2; a non-constant operand
             // is treated as schedule type value 0.
4983         unsigned ScheduleArgOpNo = 2;
4984         auto *ScheduleTypeCI =
4985             dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
4986         unsigned ScheduleTypeVal =
4987             ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4988         switch (OMPScheduleType(ScheduleTypeVal)) {
4989         case OMPScheduleType::UnorderedStatic:
4990         case OMPScheduleType::UnorderedStaticChunked:
4991         case OMPScheduleType::OrderedDistribute:
4992         case OMPScheduleType::OrderedDistributeChunked:
4993           break;
4994         default:
               // Any other schedule makes the call SPMD incompatible.
4995           SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4996           SPMDCompatibilityTracker.insert(&CB);
4997           break;
4998         };
4999       } break;
5000       case OMPRTL___kmpc_target_init:
             // Remember this call as the kernel init call.
5001         KernelInitCB = &CB;
5002         break;
5003       case OMPRTL___kmpc_target_deinit:
             // Remember this call as the kernel deinit call.
5004         KernelDeinitCB = &CB;
5005         break;
5006       case OMPRTL___kmpc_parallel_51:
             // Return without indicating an optimistic fixpoint; the call is
             // handled again in updateImpl.
5007         if (!handleParallel51(A, CB))
5008           indicatePessimisticFixpoint();
5009         return;
5010       case OMPRTL___kmpc_omp_task:
5011         // We do not look into tasks right now, just give up.
5012         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5013         SPMDCompatibilityTracker.insert(&CB);
5014         ReachedUnknownParallelRegions.insert(&CB);
5015         break;
5016       case OMPRTL___kmpc_alloc_shared:
5017       case OMPRTL___kmpc_free_shared:
5018         // Return without setting a fixpoint, to be resolved in updateImpl.
5019         return;
5020       default:
5021         // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
5022         // generally. However, they do not hide parallel regions.
5023         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5024         SPMDCompatibilityTracker.insert(&CB);
5025         break;
5026       }
5027       // All other OpenMP runtime calls will not reach parallel regions so they
5028       // can be safely ignored for now. Since it is a known OpenMP runtime call
5029       // we have now modeled all effects and there is no need for any update.
5030       indicateOptimisticFixpoint();
5031     };
5032
         // Without valid call edge information, or with an unknown callee, only
         // the associated function (possibly null) is checked. Otherwise every
         // optimistic call edge is checked until a fixpoint is reached.
5033     const auto *AACE =
5034         A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5035     if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5036       CheckCallee(getAssociatedFunction(), 1);
5037       return;
5038     }
5039     const auto &OptimisticEdges = AACE->getOptimisticEdges();
5040     for (auto *Callee : OptimisticEdges) {
5041       CheckCallee(Callee, OptimisticEdges.size());
5042       if (isAtFixpoint())
5043         break;
5044     }
5045   }
5046
       /// See AbstractAttribute::updateImpl(...).
       ///
       /// \returns ChangeStatus::CHANGED if the state differs from the state
       /// observed on entry, ChangeStatus::UNCHANGED otherwise.
5047   ChangeStatus updateImpl(Attributor &A) override {
5048     // TODO: Once we have call site specific value information we can provide
5049     //       call site specific liveness information and then it makes
5050     //       sense to specialize attributes for call sites arguments instead of
5051     //       redirecting requests to the callee argument.
5052     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5053     KernelInfoState StateBefore = getState();
5054
         // Update the state for a single potential callee \p F out of
         // \p NumCallees. The returned ChangeStatus is not inspected by the
         // callers below; the result of updateImpl is derived from comparing the
         // state against StateBefore.
5055     auto CheckCallee = [&](Function *F, int NumCallees) {
5056       const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
5057
5058       // If F is not a runtime function, propagate the AAKernelInfo of the
5059       // callee.
5060       if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
5061         const IRPosition &FnPos = IRPosition::function(*F);
5062         auto *FnAA =
5063             A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
5064         if (!FnAA)
5065           return indicatePessimisticFixpoint();
5066         if (getState() == FnAA->getState())
5067           return ChangeStatus::UNCHANGED;
             // Take over the callee state as a whole.
5068         getState() = FnAA->getState();
5069         return ChangeStatus::CHANGED;
5070       }
           // A runtime function as one of several potential callees is not
           // modeled, give up.
5071       if (NumCallees > 1)
5072         return indicatePessimisticFixpoint();
5073
5074       CallBase &CB = cast<CallBase>(getAssociatedValue());
5075       if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
5076         if (!handleParallel51(A, CB))
5077           return indicatePessimisticFixpoint();
5078         return StateBefore == getState() ? ChangeStatus::UNCHANGED
5079                                          : ChangeStatus::CHANGED;
5080       }
5081
5082       // F is a runtime function that allocates or frees memory, check
5083       // AAHeapToStack and AAHeapToShared.
5084       assert(
5085           (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
5086            It->getSecond() == OMPRTL___kmpc_free_shared) &&
5087           "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
5088
           // Both attributes are queried for the function containing the call.
5089       auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
5090           *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5091       auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
5092           *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5093
5094       RuntimeFunction RF = It->getSecond();
5095
5096       switch (RF) {
5097       // If neither HeapToStack nor HeapToShared assume the call is removed,
5098       // assume SPMD incompatibility.
5099       case OMPRTL___kmpc_alloc_shared:
5100         if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
5101             (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
5102           SPMDCompatibilityTracker.insert(&CB);
5103         break;
5104       case OMPRTL___kmpc_free_shared:
5105         if ((!HeapToStackAA ||
5106              !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
5107             (!HeapToSharedAA ||
5108              !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
5109           SPMDCompatibilityTracker.insert(&CB);
5110         break;
5111       default:
             // Any other runtime function here is treated as SPMD incompatible.
5112         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5113         SPMDCompatibilityTracker.insert(&CB);
5114       }
5115       return ChangeStatus::CHANGED;
5116     };
5117
         // Without valid call edge information, or with an unknown callee, only
         // the associated function is checked, and only if it is non-null.
         // Otherwise every optimistic call edge is checked until a fixpoint is
         // reached.
5118     const auto *AACE =
5119         A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5120     if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5121       if (Function *F = getAssociatedFunction())
5122         CheckCallee(F, /*NumCallees=*/1);
5123     } else {
5124       const auto &OptimisticEdges = AACE->getOptimisticEdges();
5125       for (auto *Callee : OptimisticEdges) {
5126         CheckCallee(Callee, OptimisticEdges.size());
5127         if (isAtFixpoint())
5128           break;
5129       }
5130     }
5131
5132     return StateBefore == getState() ? ChangeStatus::UNCHANGED
5133                                      : ChangeStatus::CHANGED;
5134   }
5135
5136   /// Deal with a __kmpc_parallel_51 call (\p CB). Returns true if the call was
5137   /// handled, if a problem occurred, false is returned.
       ///
       /// The parallel region is read from call operand 5 while SPMD compatibility
       /// is assumed and from call operand 6 otherwise. The call is not handled
       /// (false is returned) if that operand, after stripping pointer casts, is
       /// not a Function. On success \p CB is added to ReachedKnownParallelRegions
       /// and NestedParallelism is updated.
5138   bool handleParallel51(Attributor &A, CallBase &CB) {
5139     const unsigned int NonWrapperFunctionArgNo = 5;
5140     const unsigned int WrapperFunctionArgNo = 6;
5141     auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
5142                                      ? NonWrapperFunctionArgNo
5143                                      : WrapperFunctionArgNo;
5144
5145     auto *ParallelRegion = dyn_cast<Function>(
5146         CB.getArgOperand(ParallelRegionOpArgNo)->stripPointerCasts());
5147     if (!ParallelRegion)
5148       return false;
5149
5150     ReachedKnownParallelRegions.insert(&CB);
5151     // Check nested parallelism: it is assumed unless the AAKernelInfo of the
         // parallel region is available and valid, and both of its parallel
         // region sets are valid and empty.
5152     auto *FnAA = A.getAAFor<AAKernelInfo>(
5153         *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
5154     NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
5155                          !FnAA->ReachedKnownParallelRegions.empty() ||
5156                          !FnAA->ReachedKnownParallelRegions.isValidState() ||
5157                          !FnAA->ReachedUnknownParallelRegions.isValidState() ||
5158                          !FnAA->ReachedUnknownParallelRegions.empty();
5159     return true;
5160   }
5161 };
5162 
5163 struct AAFoldRuntimeCall
5164     : public StateWrapper<BooleanState, AbstractAttribute> {
5165   using Base = StateWrapper<BooleanState, AbstractAttribute>;
5166 
5167   AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
5168 
5169   /// Statistics are tracked as part of manifest for now.
5170   void trackStatistics() const override {}
5171 
5172   /// Create an abstract attribute biew for the position \p IRP.
5173   static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
5174                                               Attributor &A);
5175 
5176   /// See AbstractAttribute::getName()
5177   const std::string getName() const override { return "AAFoldRuntimeCall"; }
5178 
5179   /// See AbstractAttribute::getIdAddr()
5180   const char *getIdAddr() const override { return &ID; }
5181 
5182   /// This function should return true if the type of the \p AA is
5183   /// AAFoldRuntimeCall
5184   static bool classof(const AbstractAttribute *AA) {
5185     return (AA->getIdAddr() == &ID);
5186   }
5187 
5188   static const char ID;
5189 };
5190 
5191 struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
5192   AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
5193       : AAFoldRuntimeCall(IRP, A) {}
5194 
5195   /// See AbstractAttribute::getAsStr()
5196   const std::string getAsStr(Attributor *) const override {
5197     if (!isValidState())
5198       return "<invalid>";
5199 
5200     std::string Str("simplified value: ");
5201 
5202     if (!SimplifiedValue)
5203       return Str + std::string("none");
5204 
5205     if (!*SimplifiedValue)
5206       return Str + std::string("nullptr");
5207 
5208     if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
5209       return Str + std::to_string(CI->getSExtValue());
5210 
5211     return Str + std::string("unknown");
5212   }
5213 
5214   void initialize(Attributor &A) override {
5215     if (DisableOpenMPOptFolding)
5216       indicatePessimisticFixpoint();
5217 
5218     Function *Callee = getAssociatedFunction();
5219 
5220     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5221     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
5222     assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
5223            "Expected a known OpenMP runtime function");
5224 
5225     RFKind = It->getSecond();
5226 
5227     CallBase &CB = cast<CallBase>(getAssociatedValue());
5228     A.registerSimplificationCallback(
5229         IRPosition::callsite_returned(CB),
5230         [&](const IRPosition &IRP, const AbstractAttribute *AA,
5231             bool &UsedAssumedInformation) -> std::optional<Value *> {
5232           assert((isValidState() ||
5233                   (SimplifiedValue && *SimplifiedValue == nullptr)) &&
5234                  "Unexpected invalid state!");
5235 
5236           if (!isAtFixpoint()) {
5237             UsedAssumedInformation = true;
5238             if (AA)
5239               A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
5240           }
5241           return SimplifiedValue;
5242         });
5243   }
5244 
5245   ChangeStatus updateImpl(Attributor &A) override {
5246     ChangeStatus Changed = ChangeStatus::UNCHANGED;
5247     switch (RFKind) {
5248     case OMPRTL___kmpc_is_spmd_exec_mode:
5249       Changed |= foldIsSPMDExecMode(A);
5250       break;
5251     case OMPRTL___kmpc_parallel_level:
5252       Changed |= foldParallelLevel(A);
5253       break;
5254     case OMPRTL___kmpc_get_hardware_num_threads_in_block:
5255       Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
5256       break;
5257     case OMPRTL___kmpc_get_hardware_num_blocks:
5258       Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
5259       break;
5260     default:
5261       llvm_unreachable("Unhandled OpenMP runtime function!");
5262     }
5263 
5264     return Changed;
5265   }
5266 
5267   ChangeStatus manifest(Attributor &A) override {
5268     ChangeStatus Changed = ChangeStatus::UNCHANGED;
5269 
5270     if (SimplifiedValue && *SimplifiedValue) {
5271       Instruction &I = *getCtxI();
5272       A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
5273       A.deleteAfterManifest(I);
5274 
5275       CallBase *CB = dyn_cast<CallBase>(&I);
5276       auto Remark = [&](OptimizationRemark OR) {
5277         if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
5278           return OR << "Replacing OpenMP runtime call "
5279                     << CB->getCalledFunction()->getName() << " with "
5280                     << ore::NV("FoldedValue", C->getZExtValue()) << ".";
5281         return OR << "Replacing OpenMP runtime call "
5282                   << CB->getCalledFunction()->getName() << ".";
5283       };
5284 
5285       if (CB && EnableVerboseRemarks)
5286         A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
5287 
5288       LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
5289                         << **SimplifiedValue << "\n");
5290 
5291       Changed = ChangeStatus::CHANGED;
5292     }
5293 
5294     return Changed;
5295   }
5296 
5297   ChangeStatus indicatePessimisticFixpoint() override {
5298     SimplifiedValue = nullptr;
5299     return AAFoldRuntimeCall::indicatePessimisticFixpoint();
5300   }
5301 
5302 private:
5303   /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
5304   ChangeStatus foldIsSPMDExecMode(Attributor &A) {
5305     std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5306 
5307     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5308     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5309     auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5310         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5311 
5312     if (!CallerKernelInfoAA ||
5313         !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5314       return indicatePessimisticFixpoint();
5315 
5316     for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5317       auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5318                                           DepClassTy::REQUIRED);
5319 
5320       if (!AA || !AA->isValidState()) {
5321         SimplifiedValue = nullptr;
5322         return indicatePessimisticFixpoint();
5323       }
5324 
5325       if (AA->SPMDCompatibilityTracker.isAssumed()) {
5326         if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5327           ++KnownSPMDCount;
5328         else
5329           ++AssumedSPMDCount;
5330       } else {
5331         if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5332           ++KnownNonSPMDCount;
5333         else
5334           ++AssumedNonSPMDCount;
5335       }
5336     }
5337 
5338     if ((AssumedSPMDCount + KnownSPMDCount) &&
5339         (AssumedNonSPMDCount + KnownNonSPMDCount))
5340       return indicatePessimisticFixpoint();
5341 
5342     auto &Ctx = getAnchorValue().getContext();
5343     if (KnownSPMDCount || AssumedSPMDCount) {
5344       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5345              "Expected only SPMD kernels!");
5346       // All reaching kernels are in SPMD mode. Update all function calls to
5347       // __kmpc_is_spmd_exec_mode to 1.
5348       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
5349     } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
5350       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5351              "Expected only non-SPMD kernels!");
5352       // All reaching kernels are in non-SPMD mode. Update all function
5353       // calls to __kmpc_is_spmd_exec_mode to 0.
5354       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
5355     } else {
5356       // We have empty reaching kernels, therefore we cannot tell if the
5357       // associated call site can be folded. At this moment, SimplifiedValue
5358       // must be none.
5359       assert(!SimplifiedValue && "SimplifiedValue should be none");
5360     }
5361 
5362     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5363                                                     : ChangeStatus::CHANGED;
5364   }
5365 
5366   /// Fold __kmpc_parallel_level into a constant if possible.
5367   ChangeStatus foldParallelLevel(Attributor &A) {
5368     std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5369 
5370     auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5371         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5372 
5373     if (!CallerKernelInfoAA ||
5374         !CallerKernelInfoAA->ParallelLevels.isValidState())
5375       return indicatePessimisticFixpoint();
5376 
5377     if (!CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5378       return indicatePessimisticFixpoint();
5379 
5380     if (CallerKernelInfoAA->ReachingKernelEntries.empty()) {
5381       assert(!SimplifiedValue &&
5382              "SimplifiedValue should keep none at this point");
5383       return ChangeStatus::UNCHANGED;
5384     }
5385 
5386     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5387     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5388     for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5389       auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5390                                           DepClassTy::REQUIRED);
5391       if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
5392         return indicatePessimisticFixpoint();
5393 
5394       if (AA->SPMDCompatibilityTracker.isAssumed()) {
5395         if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5396           ++KnownSPMDCount;
5397         else
5398           ++AssumedSPMDCount;
5399       } else {
5400         if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5401           ++KnownNonSPMDCount;
5402         else
5403           ++AssumedNonSPMDCount;
5404       }
5405     }
5406 
5407     if ((AssumedSPMDCount + KnownSPMDCount) &&
5408         (AssumedNonSPMDCount + KnownNonSPMDCount))
5409       return indicatePessimisticFixpoint();
5410 
5411     auto &Ctx = getAnchorValue().getContext();
5412     // If the caller can only be reached by SPMD kernel entries, the parallel
5413     // level is 1. Similarly, if the caller can only be reached by non-SPMD
5414     // kernel entries, it is 0.
5415     if (AssumedSPMDCount || KnownSPMDCount) {
5416       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5417              "Expected only SPMD kernels!");
5418       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
5419     } else {
5420       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5421              "Expected only non-SPMD kernels!");
5422       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
5423     }
5424     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5425                                                     : ChangeStatus::CHANGED;
5426   }
5427 
5428   ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
5429     // Specialize only if all the calls agree with the attribute constant value
5430     int32_t CurrentAttrValue = -1;
5431     std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5432 
5433     auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5434         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5435 
5436     if (!CallerKernelInfoAA ||
5437         !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5438       return indicatePessimisticFixpoint();
5439 
5440     // Iterate over the kernels that reach this function
5441     for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5442       int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Attr, -1);
5443 
5444       if (NextAttrVal == -1 ||
5445           (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
5446         return indicatePessimisticFixpoint();
5447       CurrentAttrValue = NextAttrVal;
5448     }
5449 
5450     if (CurrentAttrValue != -1) {
5451       auto &Ctx = getAnchorValue().getContext();
5452       SimplifiedValue =
5453           ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
5454     }
5455     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5456                                                     : ChangeStatus::CHANGED;
5457   }
5458 
5459   /// An optional value the associated value is assumed to fold to. That is, we
5460   /// assume the associated value (which is a call) can be replaced by this
5461   /// simplified value.
5462   std::optional<Value *> SimplifiedValue;
5463 
5464   /// The runtime function kind of the callee of the associated call site.
5465   RuntimeFunction RFKind;
5466 };
5467 
5468 } // namespace
5469 
5470 /// Register folding callsite
5471 void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
5472   auto &RFI = OMPInfoCache.RFIs[RF];
5473   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
5474     CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
5475     if (!CI)
5476       return false;
5477     A.getOrCreateAAFor<AAFoldRuntimeCall>(
5478         IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
5479         DepClassTy::NONE, /* ForceUpdate */ false,
5480         /* UpdateAfterInit */ false);
5481     return false;
5482   });
5483 }
5484 
5485 void OpenMPOpt::registerAAs(bool IsModulePass) {
5486   if (SCC.empty())
5487     return;
5488 
5489   if (IsModulePass) {
5490     // Ensure we create the AAKernelInfo AAs first and without triggering an
5491     // update. This will make sure we register all value simplification
5492     // callbacks before any other AA has the chance to create an AAValueSimplify
5493     // or similar.
5494     auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
5495       A.getOrCreateAAFor<AAKernelInfo>(
5496           IRPosition::function(Kernel), /* QueryingAA */ nullptr,
5497           DepClassTy::NONE, /* ForceUpdate */ false,
5498           /* UpdateAfterInit */ false);
5499       return false;
5500     };
5501     OMPInformationCache::RuntimeFunctionInfo &InitRFI =
5502         OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
5503     InitRFI.foreachUse(SCC, CreateKernelInfoCB);
5504 
5505     registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
5506     registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
5507     registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
5508     registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
5509   }
5510 
5511   // Create CallSite AA for all Getters.
5512   if (DeduceICVValues) {
5513     for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
5514       auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
5515 
5516       auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
5517 
5518       auto CreateAA = [&](Use &U, Function &Caller) {
5519         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
5520         if (!CI)
5521           return false;
5522 
5523         auto &CB = cast<CallBase>(*CI);
5524 
5525         IRPosition CBPos = IRPosition::callsite_function(CB);
5526         A.getOrCreateAAFor<AAICVTracker>(CBPos);
5527         return false;
5528       };
5529 
5530       GetterRFI.foreachUse(SCC, CreateAA);
5531     }
5532   }
5533 
5534   // Create an ExecutionDomain AA for every function and a HeapToStack AA for
5535   // every function if there is a device kernel.
5536   if (!isOpenMPDevice(M))
5537     return;
5538 
5539   for (auto *F : SCC) {
5540     if (F->isDeclaration())
5541       continue;
5542 
5543     // We look at internal functions only on-demand but if any use is not a
5544     // direct call or outside the current set of analyzed functions, we have
5545     // to do it eagerly.
5546     if (F->hasLocalLinkage()) {
5547       if (llvm::all_of(F->uses(), [this](const Use &U) {
5548             const auto *CB = dyn_cast<CallBase>(U.getUser());
5549             return CB && CB->isCallee(&U) &&
5550                    A.isRunOn(const_cast<Function *>(CB->getCaller()));
5551           }))
5552         continue;
5553     }
5554     registerAAsForFunction(A, *F);
5555   }
5556 }
5557 
5558 void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
5559   if (!DisableOpenMPOptDeglobalization)
5560     A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
5561   A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
5562   if (!DisableOpenMPOptDeglobalization)
5563     A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(F));
5564   if (F.hasFnAttribute(Attribute::Convergent))
5565     A.getOrCreateAAFor<AANonConvergent>(IRPosition::function(F));
5566 
5567   for (auto &I : instructions(F)) {
5568     if (auto *LI = dyn_cast<LoadInst>(&I)) {
5569       bool UsedAssumedInformation = false;
5570       A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
5571                              UsedAssumedInformation, AA::Interprocedural);
5572       continue;
5573     }
5574     if (auto *CI = dyn_cast<CallBase>(&I)) {
5575       if (CI->isIndirectCall())
5576         A.getOrCreateAAFor<AAIndirectCallInfo>(
5577             IRPosition::callsite_function(*CI));
5578     }
5579     if (auto *SI = dyn_cast<StoreInst>(&I)) {
5580       A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
5581       continue;
5582     }
5583     if (auto *FI = dyn_cast<FenceInst>(&I)) {
5584       A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*FI));
5585       continue;
5586     }
5587     if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
5588       if (II->getIntrinsicID() == Intrinsic::assume) {
5589         A.getOrCreateAAFor<AAPotentialValues>(
5590             IRPosition::value(*II->getArgOperand(0)));
5591         continue;
5592       }
5593     }
5594   }
5595 }
5596 
5597 const char AAICVTracker::ID = 0;
5598 const char AAKernelInfo::ID = 0;
5599 const char AAExecutionDomain::ID = 0;
5600 const char AAHeapToShared::ID = 0;
5601 const char AAFoldRuntimeCall::ID = 0;
5602 
5603 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
5604                                               Attributor &A) {
5605   AAICVTracker *AA = nullptr;
5606   switch (IRP.getPositionKind()) {
5607   case IRPosition::IRP_INVALID:
5608   case IRPosition::IRP_FLOAT:
5609   case IRPosition::IRP_ARGUMENT:
5610   case IRPosition::IRP_CALL_SITE_ARGUMENT:
5611     llvm_unreachable("ICVTracker can only be created for function position!");
5612   case IRPosition::IRP_RETURNED:
5613     AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
5614     break;
5615   case IRPosition::IRP_CALL_SITE_RETURNED:
5616     AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
5617     break;
5618   case IRPosition::IRP_CALL_SITE:
5619     AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
5620     break;
5621   case IRPosition::IRP_FUNCTION:
5622     AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
5623     break;
5624   }
5625 
5626   return *AA;
5627 }
5628 
5629 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
5630                                                         Attributor &A) {
5631   AAExecutionDomainFunction *AA = nullptr;
5632   switch (IRP.getPositionKind()) {
5633   case IRPosition::IRP_INVALID:
5634   case IRPosition::IRP_FLOAT:
5635   case IRPosition::IRP_ARGUMENT:
5636   case IRPosition::IRP_CALL_SITE_ARGUMENT:
5637   case IRPosition::IRP_RETURNED:
5638   case IRPosition::IRP_CALL_SITE_RETURNED:
5639   case IRPosition::IRP_CALL_SITE:
5640     llvm_unreachable(
5641         "AAExecutionDomain can only be created for function position!");
5642   case IRPosition::IRP_FUNCTION:
5643     AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
5644     break;
5645   }
5646 
5647   return *AA;
5648 }
5649 
5650 AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
5651                                                   Attributor &A) {
5652   AAHeapToSharedFunction *AA = nullptr;
5653   switch (IRP.getPositionKind()) {
5654   case IRPosition::IRP_INVALID:
5655   case IRPosition::IRP_FLOAT:
5656   case IRPosition::IRP_ARGUMENT:
5657   case IRPosition::IRP_CALL_SITE_ARGUMENT:
5658   case IRPosition::IRP_RETURNED:
5659   case IRPosition::IRP_CALL_SITE_RETURNED:
5660   case IRPosition::IRP_CALL_SITE:
5661     llvm_unreachable(
5662         "AAHeapToShared can only be created for function position!");
5663   case IRPosition::IRP_FUNCTION:
5664     AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
5665     break;
5666   }
5667 
5668   return *AA;
5669 }
5670 
5671 AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
5672                                               Attributor &A) {
5673   AAKernelInfo *AA = nullptr;
5674   switch (IRP.getPositionKind()) {
5675   case IRPosition::IRP_INVALID:
5676   case IRPosition::IRP_FLOAT:
5677   case IRPosition::IRP_ARGUMENT:
5678   case IRPosition::IRP_RETURNED:
5679   case IRPosition::IRP_CALL_SITE_RETURNED:
5680   case IRPosition::IRP_CALL_SITE_ARGUMENT:
5681     llvm_unreachable("KernelInfo can only be created for function position!");
5682   case IRPosition::IRP_CALL_SITE:
5683     AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
5684     break;
5685   case IRPosition::IRP_FUNCTION:
5686     AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
5687     break;
5688   }
5689 
5690   return *AA;
5691 }
5692 
5693 AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
5694                                                         Attributor &A) {
5695   AAFoldRuntimeCall *AA = nullptr;
5696   switch (IRP.getPositionKind()) {
5697   case IRPosition::IRP_INVALID:
5698   case IRPosition::IRP_FLOAT:
5699   case IRPosition::IRP_ARGUMENT:
5700   case IRPosition::IRP_RETURNED:
5701   case IRPosition::IRP_FUNCTION:
5702   case IRPosition::IRP_CALL_SITE:
5703   case IRPosition::IRP_CALL_SITE_ARGUMENT:
5704     llvm_unreachable("KernelInfo can only be created for call site position!");
5705   case IRPosition::IRP_CALL_SITE_RETURNED:
5706     AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
5707     break;
5708   }
5709 
5710   return *AA;
5711 }
5712 
5713 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
5714   if (!containsOpenMP(M))
5715     return PreservedAnalyses::all();
5716   if (DisableOpenMPOptimizations)
5717     return PreservedAnalyses::all();
5718 
5719   FunctionAnalysisManager &FAM =
5720       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
5721   KernelSet Kernels = getDeviceKernels(M);
5722 
5723   if (PrintModuleBeforeOptimizations)
5724     LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);
5725 
5726   auto IsCalled = [&](Function &F) {
5727     if (Kernels.contains(&F))
5728       return true;
5729     for (const User *U : F.users())
5730       if (!isa<BlockAddress>(U))
5731         return true;
5732     return false;
5733   };
5734 
5735   auto EmitRemark = [&](Function &F) {
5736     auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
5737     ORE.emit([&]() {
5738       OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
5739       return ORA << "Could not internalize function. "
5740                  << "Some optimizations may not be possible. [OMP140]";
5741     });
5742   };
5743 
5744   bool Changed = false;
5745 
5746   // Create internal copies of each function if this is a kernel Module. This
5747   // allows iterprocedural passes to see every call edge.
5748   DenseMap<Function *, Function *> InternalizedMap;
5749   if (isOpenMPDevice(M)) {
5750     SmallPtrSet<Function *, 16> InternalizeFns;
5751     for (Function &F : M)
5752       if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
5753           !DisableInternalization) {
5754         if (Attributor::isInternalizable(F)) {
5755           InternalizeFns.insert(&F);
5756         } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
5757           EmitRemark(F);
5758         }
5759       }
5760 
5761     Changed |=
5762         Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
5763   }
5764 
5765   // Look at every function in the Module unless it was internalized.
5766   SetVector<Function *> Functions;
5767   SmallVector<Function *, 16> SCC;
5768   for (Function &F : M)
5769     if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) {
5770       SCC.push_back(&F);
5771       Functions.insert(&F);
5772     }
5773 
5774   if (SCC.empty())
5775     return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
5776 
5777   AnalysisGetter AG(FAM);
5778 
5779   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5780     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5781   };
5782 
5783   BumpPtrAllocator Allocator;
5784   CallGraphUpdater CGUpdater;
5785 
5786   bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5787                   LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;
5788   OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, PostLink);
5789 
5790   unsigned MaxFixpointIterations =
5791       (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5792 
5793   AttributorConfig AC(CGUpdater);
5794   AC.DefaultInitializeLiveInternals = false;
5795   AC.IsModulePass = true;
5796   AC.RewriteSignatures = false;
5797   AC.MaxFixpointIterations = MaxFixpointIterations;
5798   AC.OREGetter = OREGetter;
5799   AC.PassName = DEBUG_TYPE;
5800   AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5801   AC.IPOAmendableCB = [](const Function &F) {
5802     return F.hasFnAttribute("kernel");
5803   };
5804 
5805   Attributor A(Functions, InfoCache, AC);
5806 
5807   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5808   Changed |= OMPOpt.run(true);
5809 
5810   // Optionally inline device functions for potentially better performance.
5811   if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
5812     for (Function &F : M)
5813       if (!F.isDeclaration() && !Kernels.contains(&F) &&
5814           !F.hasFnAttribute(Attribute::NoInline))
5815         F.addFnAttr(Attribute::AlwaysInline);
5816 
5817   if (PrintModuleAfterOptimizations)
5818     LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
5819 
5820   if (Changed)
5821     return PreservedAnalyses::none();
5822 
5823   return PreservedAnalyses::all();
5824 }
5825 
5826 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
5827                                           CGSCCAnalysisManager &AM,
5828                                           LazyCallGraph &CG,
5829                                           CGSCCUpdateResult &UR) {
5830   if (!containsOpenMP(*C.begin()->getFunction().getParent()))
5831     return PreservedAnalyses::all();
5832   if (DisableOpenMPOptimizations)
5833     return PreservedAnalyses::all();
5834 
5835   SmallVector<Function *, 16> SCC;
5836   // If there are kernels in the module, we have to run on all SCC's.
5837   for (LazyCallGraph::Node &N : C) {
5838     Function *Fn = &N.getFunction();
5839     SCC.push_back(Fn);
5840   }
5841 
5842   if (SCC.empty())
5843     return PreservedAnalyses::all();
5844 
5845   Module &M = *C.begin()->getFunction().getParent();
5846 
5847   if (PrintModuleBeforeOptimizations)
5848     LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M);
5849 
5850   KernelSet Kernels = getDeviceKernels(M);
5851 
5852   FunctionAnalysisManager &FAM =
5853       AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
5854 
5855   AnalysisGetter AG(FAM);
5856 
5857   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5858     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5859   };
5860 
5861   BumpPtrAllocator Allocator;
5862   CallGraphUpdater CGUpdater;
5863   CGUpdater.initialize(CG, C, AM, UR);
5864 
5865   bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5866                   LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;
5867   SetVector<Function *> Functions(SCC.begin(), SCC.end());
5868   OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
5869                                 /*CGSCC*/ &Functions, PostLink);
5870 
5871   unsigned MaxFixpointIterations =
5872       (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5873 
5874   AttributorConfig AC(CGUpdater);
5875   AC.DefaultInitializeLiveInternals = false;
5876   AC.IsModulePass = false;
5877   AC.RewriteSignatures = false;
5878   AC.MaxFixpointIterations = MaxFixpointIterations;
5879   AC.OREGetter = OREGetter;
5880   AC.PassName = DEBUG_TYPE;
5881   AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5882 
5883   Attributor A(Functions, InfoCache, AC);
5884 
5885   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5886   bool Changed = OMPOpt.run(false);
5887 
5888   if (PrintModuleAfterOptimizations)
5889     LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5890 
5891   if (Changed)
5892     return PreservedAnalyses::none();
5893 
5894   return PreservedAnalyses::all();
5895 }
5896 
5897 bool llvm::omp::isOpenMPKernel(Function &Fn) {
5898   return Fn.hasFnAttribute("kernel");
5899 }
5900 
5901 KernelSet llvm::omp::getDeviceKernels(Module &M) {
5902   // TODO: Create a more cross-platform way of determining device kernels.
5903   NamedMDNode *MD = M.getNamedMetadata("nvvm.annotations");
5904   KernelSet Kernels;
5905 
5906   if (!MD)
5907     return Kernels;
5908 
5909   for (auto *Op : MD->operands()) {
5910     if (Op->getNumOperands() < 2)
5911       continue;
5912     MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
5913     if (!KindID || KindID->getString() != "kernel")
5914       continue;
5915 
5916     Function *KernelFn =
5917         mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
5918     if (!KernelFn)
5919       continue;
5920 
5921     // We are only interested in OpenMP target regions. Others, such as kernels
5922     // generated by CUDA but linked together, are not interesting to this pass.
5923     if (isOpenMPKernel(*KernelFn)) {
5924       ++NumOpenMPTargetRegionKernels;
5925       Kernels.insert(KernelFn);
5926     } else
5927       ++NumNonOpenMPTargetRegionKernels;
5928   }
5929 
5930   return Kernels;
5931 }
5932 
5933 bool llvm::omp::containsOpenMP(Module &M) {
5934   Metadata *MD = M.getModuleFlag("openmp");
5935   if (!MD)
5936     return false;
5937 
5938   return true;
5939 }
5940 
5941 bool llvm::omp::isOpenMPDevice(Module &M) {
5942   Metadata *MD = M.getModuleFlag("openmp-device");
5943   if (!MD)
5944     return false;
5945 
5946   return true;
5947 }
5948