1  //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2  //
3  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4  // See https://llvm.org/LICENSE.txt for license information.
5  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6  //
7  //===----------------------------------------------------------------------===//
8  //
9  // OpenMP specific optimizations:
10  //
11  // - Deduplication of runtime calls, e.g., omp_get_thread_num.
12  // - Replacing globalized device memory with stack memory.
13  // - Replacing globalized device memory with shared memory.
14  // - Parallel region merging.
15  // - Transforming generic-mode device kernels to SPMD mode.
16  // - Specializing the state machine for generic-mode device kernels.
17  //
18  //===----------------------------------------------------------------------===//
19  
20  #include "llvm/Transforms/IPO/OpenMPOpt.h"
21  
22  #include "llvm/ADT/EnumeratedArray.h"
23  #include "llvm/ADT/PostOrderIterator.h"
24  #include "llvm/ADT/SetVector.h"
25  #include "llvm/ADT/SmallPtrSet.h"
26  #include "llvm/ADT/SmallVector.h"
27  #include "llvm/ADT/Statistic.h"
28  #include "llvm/ADT/StringExtras.h"
29  #include "llvm/ADT/StringRef.h"
30  #include "llvm/Analysis/CallGraph.h"
31  #include "llvm/Analysis/CallGraphSCCPass.h"
32  #include "llvm/Analysis/MemoryLocation.h"
33  #include "llvm/Analysis/OptimizationRemarkEmitter.h"
34  #include "llvm/Analysis/ValueTracking.h"
35  #include "llvm/Frontend/OpenMP/OMPConstants.h"
36  #include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
37  #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
38  #include "llvm/IR/Assumptions.h"
39  #include "llvm/IR/BasicBlock.h"
40  #include "llvm/IR/Constants.h"
41  #include "llvm/IR/DiagnosticInfo.h"
42  #include "llvm/IR/Dominators.h"
43  #include "llvm/IR/Function.h"
44  #include "llvm/IR/GlobalValue.h"
45  #include "llvm/IR/GlobalVariable.h"
46  #include "llvm/IR/InstrTypes.h"
47  #include "llvm/IR/Instruction.h"
48  #include "llvm/IR/Instructions.h"
49  #include "llvm/IR/IntrinsicInst.h"
50  #include "llvm/IR/IntrinsicsAMDGPU.h"
51  #include "llvm/IR/IntrinsicsNVPTX.h"
52  #include "llvm/IR/LLVMContext.h"
53  #include "llvm/Support/Casting.h"
54  #include "llvm/Support/CommandLine.h"
55  #include "llvm/Support/Debug.h"
56  #include "llvm/Transforms/IPO/Attributor.h"
57  #include "llvm/Transforms/Utils/BasicBlockUtils.h"
58  #include "llvm/Transforms/Utils/CallGraphUpdater.h"
59  
60  #include <algorithm>
61  #include <optional>
62  #include <string>
63  
64  using namespace llvm;
65  using namespace omp;
66  
67  #define DEBUG_TYPE "openmp-opt"
68  
69  static cl::opt<bool> DisableOpenMPOptimizations(
70      "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
71      cl::Hidden, cl::init(false));
72  
73  static cl::opt<bool> EnableParallelRegionMerging(
74      "openmp-opt-enable-merging",
75      cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
76      cl::init(false));
77  
78  static cl::opt<bool>
79      DisableInternalization("openmp-opt-disable-internalization",
80                             cl::desc("Disable function internalization."),
81                             cl::Hidden, cl::init(false));
82  
83  static cl::opt<bool> DeduceICVValues("openmp-deduce-icv-values",
84                                       cl::init(false), cl::Hidden);
85  static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
86                                      cl::Hidden);
87  static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
88                                          cl::init(false), cl::Hidden);
89  
90  static cl::opt<bool> HideMemoryTransferLatency(
91      "openmp-hide-memory-transfer-latency",
92      cl::desc("[WIP] Tries to hide the latency of host to device memory"
93               " transfers"),
94      cl::Hidden, cl::init(false));
95  
96  static cl::opt<bool> DisableOpenMPOptDeglobalization(
97      "openmp-opt-disable-deglobalization",
98      cl::desc("Disable OpenMP optimizations involving deglobalization."),
99      cl::Hidden, cl::init(false));
100  
101  static cl::opt<bool> DisableOpenMPOptSPMDization(
102      "openmp-opt-disable-spmdization",
103      cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
104      cl::Hidden, cl::init(false));
105  
106  static cl::opt<bool> DisableOpenMPOptFolding(
107      "openmp-opt-disable-folding",
108      cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
109      cl::init(false));
110  
111  static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
112      "openmp-opt-disable-state-machine-rewrite",
113      cl::desc("Disable OpenMP optimizations that replace the state machine."),
114      cl::Hidden, cl::init(false));
115  
116  static cl::opt<bool> DisableOpenMPOptBarrierElimination(
117      "openmp-opt-disable-barrier-elimination",
118      cl::desc("Disable OpenMP optimizations that eliminate barriers."),
119      cl::Hidden, cl::init(false));
120  
121  static cl::opt<bool> PrintModuleAfterOptimizations(
122      "openmp-opt-print-module-after",
123      cl::desc("Print the current module after OpenMP optimizations."),
124      cl::Hidden, cl::init(false));
125  
126  static cl::opt<bool> PrintModuleBeforeOptimizations(
127      "openmp-opt-print-module-before",
128      cl::desc("Print the current module before OpenMP optimizations."),
129      cl::Hidden, cl::init(false));
130  
131  static cl::opt<bool> AlwaysInlineDeviceFunctions(
132      "openmp-opt-inline-device",
133      cl::desc("Inline all applicable functions on the device."), cl::Hidden,
134      cl::init(false));
135  
136  static cl::opt<bool>
137      EnableVerboseRemarks("openmp-opt-verbose-remarks",
138                           cl::desc("Enables more verbose remarks."), cl::Hidden,
139                           cl::init(false));
140  
141  static cl::opt<unsigned>
142      SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
143                            cl::desc("Maximal number of attributor iterations."),
144                            cl::init(256));
145  
146  static cl::opt<unsigned>
147      SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
148                        cl::desc("Maximum amount of shared memory to use."),
149                        cl::init(std::numeric_limits<unsigned>::max()));
150  
151  STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
152            "Number of OpenMP runtime calls deduplicated");
153  STATISTIC(NumOpenMPParallelRegionsDeleted,
154            "Number of OpenMP parallel regions deleted");
155  STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
156            "Number of OpenMP runtime functions identified");
157  STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
158            "Number of OpenMP runtime function uses identified");
159  STATISTIC(NumOpenMPTargetRegionKernels,
160            "Number of OpenMP target region entry points (=kernels) identified");
161  STATISTIC(NumNonOpenMPTargetRegionKernels,
162            "Number of non-OpenMP target region kernels identified");
163  STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
164            "Number of OpenMP target region entry points (=kernels) executed in "
165            "SPMD-mode instead of generic-mode");
166  STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
167            "Number of OpenMP target region entry points (=kernels) executed in "
168            "generic-mode without a state machine");
169  STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
170            "Number of OpenMP target region entry points (=kernels) executed in "
171            "generic-mode with customized state machines with fallback");
172  STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
173            "Number of OpenMP target region entry points (=kernels) executed in "
174            "generic-mode with customized state machines without fallback");
175  STATISTIC(
176      NumOpenMPParallelRegionsReplacedInGPUStateMachine,
177      "Number of OpenMP parallel regions replaced with ID in GPU state machines");
178  STATISTIC(NumOpenMPParallelRegionsMerged,
179            "Number of OpenMP parallel regions merged");
180  STATISTIC(NumBytesMovedToSharedMemory,
181            "Amount of memory pushed to shared memory");
182  STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");
183  
184  #if !defined(NDEBUG)
185  static constexpr auto TAG = "[" DEBUG_TYPE "]";
186  #endif
187  
188  namespace KernelInfo {
189  
190  // struct ConfigurationEnvironmentTy {
191  //   uint8_t UseGenericStateMachine;
192  //   uint8_t MayUseNestedParallelism;
193  //   llvm::omp::OMPTgtExecModeFlags ExecMode;
194  //   int32_t MinThreads;
195  //   int32_t MaxThreads;
196  //   int32_t MinTeams;
197  //   int32_t MaxTeams;
198  // };
199  
200  // struct DynamicEnvironmentTy {
201  //   uint16_t DebugIndentionLevel;
202  // };
203  
204  // struct KernelEnvironmentTy {
205  //   ConfigurationEnvironmentTy Configuration;
206  //   IdentTy *Ident;
207  //   DynamicEnvironmentTy *DynamicEnv;
208  // };
209  
210  #define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX)                                    \
211    constexpr const unsigned MEMBER##Idx = IDX;
212  
213  KERNEL_ENVIRONMENT_IDX(Configuration, 0)
214  KERNEL_ENVIRONMENT_IDX(Ident, 1)
215  
216  #undef KERNEL_ENVIRONMENT_IDX
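// For illustration, each invocation above expands to a plain index constant;
// e.g., KERNEL_ENVIRONMENT_IDX(Configuration, 0) becomes:
//
//   constexpr const unsigned ConfigurationIdx = 0;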
217  
218  #define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX)                      \
219    constexpr const unsigned MEMBER##Idx = IDX;
220  
221  KERNEL_ENVIRONMENT_CONFIGURATION_IDX(UseGenericStateMachine, 0)
222  KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MayUseNestedParallelism, 1)
223  KERNEL_ENVIRONMENT_CONFIGURATION_IDX(ExecMode, 2)
224  KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinThreads, 3)
225  KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxThreads, 4)
226  KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinTeams, 5)
227  KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxTeams, 6)
228  
229  #undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX
230  
231  #define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE)                          \
232    RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \
233      return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx));     \
234    }
235  
236  KERNEL_ENVIRONMENT_GETTER(Ident, Constant)
237  KERNEL_ENVIRONMENT_GETTER(Configuration, ConstantStruct)
238  
239  #undef KERNEL_ENVIRONMENT_GETTER
240  
241  #define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER)                        \
242    ConstantInt *get##MEMBER##FromKernelEnvironment(                             \
243        ConstantStruct *KernelEnvC) {                                            \
244      ConstantStruct *ConfigC =                                                  \
245          getConfigurationFromKernelEnvironment(KernelEnvC);                     \
246      return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx));   \
247    }
248  
249  KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine)
250  KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MayUseNestedParallelism)
251  KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(ExecMode)
252  KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinThreads)
253  KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxThreads)
254  KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinTeams)
255  KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxTeams)
256  
257  #undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER
258  
259  GlobalVariable *
260  getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB) {
261    constexpr const int InitKernelEnvironmentArgNo = 0;
262    return cast<GlobalVariable>(
263        KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo)
264            ->stripPointerCasts());
265  }
266  
267  ConstantStruct *getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB) {
268    GlobalVariable *KernelEnvGV =
269        getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
270    return cast<ConstantStruct>(KernelEnvGV->getInitializer());
271  }
272  } // namespace KernelInfo
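// For illustration (a hypothetical use of the accessors above, not code taken
// from this pass): given the __kmpc_target_init call of a kernel, the constant
// execution mode can be read straight from the kernel environment:
//
//   ConstantStruct *KernelEnvC =
//       KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
//   ConstantInt *ExecModeC =
//       KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
//   bool IsSPMD = ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD;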
273  
274  namespace {
275  
276  struct AAHeapToShared;
277  
278  struct AAICVTracker;
279  
280  /// OpenMP specific information. For now, stores RFIs and ICVs also needed for
281  /// Attributor runs.
282  struct OMPInformationCache : public InformationCache {
283    OMPInformationCache(Module &M, AnalysisGetter &AG,
284                        BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
285                        bool OpenMPPostLink)
286        : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
287          OpenMPPostLink(OpenMPPostLink) {
288  
289      OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M);
290      OMPBuilder.initialize();
291      initializeRuntimeFunctions(M);
292      initializeInternalControlVars();
293    }
294  
295    /// Generic information that describes an internal control variable.
296    struct InternalControlVarInfo {
297      /// The kind, as described by InternalControlVar enum.
298      InternalControlVar Kind;
299  
300      /// The name of the ICV.
301      StringRef Name;
302  
303      /// Environment variable associated with this ICV.
304      StringRef EnvVarName;
305  
306      /// Initial value kind.
307      ICVInitValue InitKind;
308  
309      /// Initial value.
310      ConstantInt *InitValue;
311  
312      /// Setter RTL function associated with this ICV.
313      RuntimeFunction Setter;
314  
315      /// Getter RTL function associated with this ICV.
316      RuntimeFunction Getter;
317  
318      /// RTL Function corresponding to the override clause of this ICV
319      RuntimeFunction Clause;
320    };
321  
322    /// Generic information that describes a runtime function
323    struct RuntimeFunctionInfo {
324  
325      /// The kind, as described by the RuntimeFunction enum.
326      RuntimeFunction Kind;
327  
328      /// The name of the function.
329      StringRef Name;
330  
331      /// Flag to indicate a variadic function.
332      bool IsVarArg;
333  
334      /// The return type of the function.
335      Type *ReturnType;
336  
337      /// The argument types of the function.
338      SmallVector<Type *, 8> ArgumentTypes;
339  
340      /// The declaration if available.
341      Function *Declaration = nullptr;
342  
343      /// Uses of this runtime function per function containing the use.
344      using UseVector = SmallVector<Use *, 16>;
345  
346      /// Clear UsesMap for runtime function.
347      void clearUsesMap() { UsesMap.clear(); }
348  
349      /// Boolean conversion that is true if the runtime function was found.
350      operator bool() const { return Declaration; }
351  
352      /// Return the vector of uses in function \p F.
353      UseVector &getOrCreateUseVector(Function *F) {
354        std::shared_ptr<UseVector> &UV = UsesMap[F];
355        if (!UV)
356          UV = std::make_shared<UseVector>();
357        return *UV;
358      }
359  
360      /// Return the vector of uses in function \p F or `nullptr` if there are
361      /// none.
362      const UseVector *getUseVector(Function &F) const {
363        auto I = UsesMap.find(&F);
364        if (I != UsesMap.end())
365          return I->second.get();
366        return nullptr;
367      }
368  
369      /// Return how many functions contain uses of this runtime function.
370      size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
371  
372      /// Return the number of arguments (or the minimal number for variadic
373      /// functions).
374      size_t getNumArgs() const { return ArgumentTypes.size(); }
375  
376      /// Run the callback \p CB on each use and forget the use if the result is
377      /// true. The callback will be fed the function in which the use was
378      /// encountered as second argument.
379      void foreachUse(SmallVectorImpl<Function *> &SCC,
380                      function_ref<bool(Use &, Function &)> CB) {
381        for (Function *F : SCC)
382          foreachUse(CB, F);
383      }
384  
385      /// Run the callback \p CB on each use within the function \p F and forget
386      /// the use if the result is true.
387      void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
388        SmallVector<unsigned, 8> ToBeDeleted;
389        ToBeDeleted.clear();
390  
391        unsigned Idx = 0;
392        UseVector &UV = getOrCreateUseVector(F);
393  
394        for (Use *U : UV) {
395          if (CB(*U, *F))
396            ToBeDeleted.push_back(Idx);
397          ++Idx;
398        }
399  
400        // Remove the to-be-deleted indices in reverse order as prior
401        // modifications will not modify the smaller indices.
402        while (!ToBeDeleted.empty()) {
403          unsigned Idx = ToBeDeleted.pop_back_val();
404          UV[Idx] = UV.back();
405          UV.pop_back();
406        }
407      }
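    // For illustration, a hypothetical callback filters the uses down to
    // direct calls and returns true exactly for the uses it consumed:
    //
    //   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
    //     CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
    //     if (!CI)
    //       return false; // Not a direct call; keep the use.
    //     // ... rewrite or erase CI ...
    //     return true;    // Forget the use.
    //   });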
408  
409    private:
410      /// Map from functions to all uses of this runtime function contained in
411      /// them.
412      DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
413  
414    public:
415      /// Iterators for the uses of this runtime function.
416      decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
417      decltype(UsesMap)::iterator end() { return UsesMap.end(); }
418    };
419  
420    /// An OpenMP-IR-Builder instance
421    OpenMPIRBuilder OMPBuilder;
422  
423    /// Map from runtime function kind to the runtime function description.
424    EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
425                    RuntimeFunction::OMPRTL___last>
426        RFIs;
427  
428    /// Map from function declarations/definitions to their runtime enum type.
429    DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
430  
431    /// Map from ICV kind to the ICV description.
432    EnumeratedArray<InternalControlVarInfo, InternalControlVar,
433                    InternalControlVar::ICV___last>
434        ICVs;
435  
436    /// Helper to initialize all internal control variable information for those
437    /// defined in OMPKinds.def.
438    void initializeInternalControlVars() {
439  #define ICV_RT_SET(_Name, RTL)                                                 \
440    {                                                                            \
441      auto &ICV = ICVs[_Name];                                                   \
442      ICV.Setter = RTL;                                                          \
443    }
444  #define ICV_RT_GET(Name, RTL)                                                  \
445    {                                                                            \
446      auto &ICV = ICVs[Name];                                                    \
447      ICV.Getter = RTL;                                                          \
448    }
449  #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
450    {                                                                            \
451      auto &ICV = ICVs[Enum];                                                    \
452      ICV.Name = _Name;                                                          \
453      ICV.Kind = Enum;                                                           \
454      ICV.InitKind = Init;                                                       \
455      ICV.EnvVarName = _EnvVarName;                                              \
456      switch (ICV.InitKind) {                                                    \
457      case ICV_IMPLEMENTATION_DEFINED:                                           \
458        ICV.InitValue = nullptr;                                                 \
459        break;                                                                   \
460      case ICV_ZERO:                                                             \
461        ICV.InitValue = ConstantInt::get(                                        \
462            Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
463        break;                                                                   \
464      case ICV_FALSE:                                                            \
465        ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
466        break;                                                                   \
467      case ICV_LAST:                                                             \
468        break;                                                                   \
469      }                                                                          \
470    }
471  #include "llvm/Frontend/OpenMP/OMPKinds.def"
472    }
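  // For illustration, a representative ICV_DATA_ENV entry pulled in from
  // OMPKinds.def looks roughly like this (the authoritative list lives in the
  // .def file):
  //
  //   ICV_DATA_ENV(ICV_nthreads, "nthreads", "OMP_NUM_THREADS",
  //                ICV_IMPLEMENTATION_DEFINED)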
473  
474    /// Returns true if the function declaration \p F matches the runtime
475    /// function types, that is, return type \p RTFRetType, and argument types
476    /// \p RTFArgTypes.
477    static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
478                                    SmallVector<Type *, 8> &RTFArgTypes) {
479      // TODO: We should output information to the user (under debug output
480      //       and via remarks).
481  
482      if (!F)
483        return false;
484      if (F->getReturnType() != RTFRetType)
485        return false;
486      if (F->arg_size() != RTFArgTypes.size())
487        return false;
488  
489      auto *RTFTyIt = RTFArgTypes.begin();
490      for (Argument &Arg : F->args()) {
491        if (Arg.getType() != *RTFTyIt)
492          return false;
493  
494        ++RTFTyIt;
495      }
496  
497      return true;
498    }
499  
500    // Helper to collect all uses of the declaration in the UsesMap.
501    unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
502      unsigned NumUses = 0;
503      if (!RFI.Declaration)
504        return NumUses;
505      OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
506  
507      if (CollectStats) {
508        NumOpenMPRuntimeFunctionsIdentified += 1;
509        NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
510      }
511  
512      // TODO: We directly convert uses into proper calls and unknown uses.
513      for (Use &U : RFI.Declaration->uses()) {
514        if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
515          if (!CGSCC || CGSCC->empty() || CGSCC->contains(UserI->getFunction())) {
516            RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
517            ++NumUses;
518          }
519        } else {
520          RFI.getOrCreateUseVector(nullptr).push_back(&U);
521          ++NumUses;
522        }
523      }
524      return NumUses;
525    }
526  
527    // Helper function to recollect uses of a runtime function.
528    void recollectUsesForFunction(RuntimeFunction RTF) {
529      auto &RFI = RFIs[RTF];
530      RFI.clearUsesMap();
531      collectUses(RFI, /*CollectStats*/ false);
532    }
533  
534    // Helper function to recollect uses of all runtime functions.
535    void recollectUses() {
536      for (int Idx = 0; Idx < RFIs.size(); ++Idx)
537        recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
538    }
539  
540    // Helper function to inherit the calling convention of the function callee.
541    void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
542      if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
543        CI->setCallingConv(Fn->getCallingConv());
544    }
545  
546    // Helper function to determine if it's legal to create a call to the runtime
547    // functions.
548    bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
549      // We can always emit calls if we haven't yet linked in the runtime.
550      if (!OpenMPPostLink)
551        return true;
552  
553      // Once the runtime has already been linked in, we cannot emit calls to
554      // any undefined functions.
555      for (RuntimeFunction Fn : Fns) {
556        RuntimeFunctionInfo &RFI = RFIs[Fn];
557  
558        if (RFI.Declaration && RFI.Declaration->isDeclaration())
559          return false;
560      }
561      return true;
562    }
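  // For illustration, a transformation that wants to emit new
  // __kmpc_alloc_shared / __kmpc_free_shared pairs would first guard on their
  // availability (hypothetical call site):
  //
  //   if (!OMPInfoCache.runtimeFnsAvailable(
  //           {OMPRTL___kmpc_alloc_shared, OMPRTL___kmpc_free_shared}))
  //     return false;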
563  
564    /// Helper to initialize all runtime function information for those defined
565    /// in OpenMPKinds.def.
566    void initializeRuntimeFunctions(Module &M) {
567  
568      // Helper macros for handling __VA_ARGS__ in OMP_RTL
569  #define OMP_TYPE(VarName, ...)                                                 \
570    Type *VarName = OMPBuilder.VarName;                                          \
571    (void)VarName;
572  
573  #define OMP_ARRAY_TYPE(VarName, ...)                                           \
574    ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
575    (void)VarName##Ty;                                                           \
576    PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
577    (void)VarName##PtrTy;
578  
579  #define OMP_FUNCTION_TYPE(VarName, ...)                                        \
580    FunctionType *VarName = OMPBuilder.VarName;                                  \
581    (void)VarName;                                                               \
582    PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
583    (void)VarName##Ptr;
584  
585  #define OMP_STRUCT_TYPE(VarName, ...)                                          \
586    StructType *VarName = OMPBuilder.VarName;                                    \
587    (void)VarName;                                                               \
588    PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
589    (void)VarName##Ptr;
590  
591  #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
592    {                                                                            \
593      SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
594      Function *F = M.getFunction(_Name);                                        \
595      RTLFunctions.insert(F);                                                    \
596      if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
597        RuntimeFunctionIDMap[F] = _Enum;                                         \
598        auto &RFI = RFIs[_Enum];                                                 \
599        RFI.Kind = _Enum;                                                        \
600        RFI.Name = _Name;                                                        \
601        RFI.IsVarArg = _IsVarArg;                                                \
602        RFI.ReturnType = OMPBuilder._ReturnType;                                 \
603        RFI.ArgumentTypes = std::move(ArgsTypes);                                \
604        RFI.Declaration = F;                                                     \
605        unsigned NumUses = collectUses(RFI);                                     \
606        (void)NumUses;                                                           \
607        LLVM_DEBUG({                                                             \
608          dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
609                 << " found\n";                                                  \
610          if (RFI.Declaration)                                                   \
611            dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
612                   << RFI.getNumFunctionsWithUses()                              \
613                   << " different functions.\n";                                 \
614        });                                                                      \
615      }                                                                          \
616    }
617  #include "llvm/Frontend/OpenMP/OMPKinds.def"
618  
619      // Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_`
620      // functions, except if `optnone` is present.
621      if (isOpenMPDevice(M)) {
622        for (Function &F : M) {
623          for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
624            if (F.hasFnAttribute(Attribute::NoInline) &&
625                F.getName().starts_with(Prefix) &&
626                !F.hasFnAttribute(Attribute::OptimizeNone))
627              F.removeFnAttr(Attribute::NoInline);
628        }
629      }
630  
631      // TODO: We should attach the attributes defined in OMPKinds.def.
632    }
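  // For illustration, each OMP_RTL entry pulled in from OMPKinds.def above
  // describes one runtime function; a representative entry looks roughly like:
  //
  //   OMP_RTL(OMPRTL___kmpc_barrier, "__kmpc_barrier", false, Void, IdentPtr,
  //           Int32)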
633  
634    /// Collection of known OpenMP runtime functions.
635    DenseSet<const Function *> RTLFunctions;
636  
637    /// Indicates if we have already linked in the OpenMP device library.
638    bool OpenMPPostLink = false;
639  };
640  
641  template <typename Ty, bool InsertInvalidates = true>
642  struct BooleanStateWithSetVector : public BooleanState {
643    bool contains(const Ty &Elem) const { return Set.contains(Elem); }
644    bool insert(const Ty &Elem) {
645      if (InsertInvalidates)
646        BooleanState::indicatePessimisticFixpoint();
647      return Set.insert(Elem);
648    }
649  
650    const Ty &operator[](int Idx) const { return Set[Idx]; }
651    bool operator==(const BooleanStateWithSetVector &RHS) const {
652      return BooleanState::operator==(RHS) && Set == RHS.Set;
653    }
654    bool operator!=(const BooleanStateWithSetVector &RHS) const {
655      return !(*this == RHS);
656    }
657  
658    bool empty() const { return Set.empty(); }
659    size_t size() const { return Set.size(); }
660  
661    /// "Clamp" this state with \p RHS.
662    BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
663      BooleanState::operator^=(RHS);
664      Set.insert(RHS.Set.begin(), RHS.Set.end());
665      return *this;
666    }
667  
668  private:
669    /// A set to keep track of elements.
670    SetVector<Ty> Set;
671  
672  public:
673    typename decltype(Set)::iterator begin() { return Set.begin(); }
674    typename decltype(Set)::iterator end() { return Set.end(); }
675    typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
676    typename decltype(Set)::const_iterator end() const { return Set.end(); }
677  };
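// For illustration of the "clamp" semantics: operator^= unions the element
// sets and merges the boolean validity, so pessimism propagates, e.g.:
//
//   BooleanStateWithPtrSetVector<Instruction> A, B;
//   A.insert(I0); // Invalidates A since InsertInvalidates defaults to true.
//   B.insert(I1);
//   A ^= B;       // A now tracks {I0, I1} and stays invalid.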
678  
679  template <typename Ty, bool InsertInvalidates = true>
680  using BooleanStateWithPtrSetVector =
681      BooleanStateWithSetVector<Ty *, InsertInvalidates>;
682  
683  struct KernelInfoState : AbstractState {
684    /// Flag to track if we reached a fixpoint.
685    bool IsAtFixpoint = false;
686  
687    /// The parallel regions (identified by the outlined parallel functions) that
688    /// can be reached from the associated function.
689    BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
690        ReachedKnownParallelRegions;
691  
692    /// State to track what parallel region we might reach.
693    BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
694  
695    /// State to track if we are in SPMD-mode, assumed or known, and why we decided
696    /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
697    /// false.
698    BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
699  
700    /// The __kmpc_target_init call in this kernel, if any. If we find more than
701    /// one we abort as the kernel is malformed.
702    CallBase *KernelInitCB = nullptr;
703  
704    /// The constant kernel environment as taken from and passed to
705    /// __kmpc_target_init.
706    ConstantStruct *KernelEnvC = nullptr;
707  
708    /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
709    /// one we abort as the kernel is malformed.
710    CallBase *KernelDeinitCB = nullptr;
711  
712    /// Flag to indicate if the associated function is a kernel entry.
713    bool IsKernelEntry = false;
714  
715    /// State to track what kernel entries can reach the associated function.
716    BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
717  
718    /// State to indicate if we can track parallel level of the associated
719    /// function. We will give up tracking if we encounter an unknown caller or
720    /// the caller is __kmpc_parallel_51.
721    BooleanStateWithSetVector<uint8_t> ParallelLevels;
722  
723    /// Flag that indicates if the kernel has nested parallelism.
724    bool NestedParallelism = false;
725  
726    /// Abstract State interface
727    ///{
728  
729    KernelInfoState() = default;
730    KernelInfoState(bool BestState) {
731      if (!BestState)
732        indicatePessimisticFixpoint();
733    }
734  
735    /// See AbstractState::isValidState(...)
736    bool isValidState() const override { return true; }
737  
738    /// See AbstractState::isAtFixpoint(...)
739    bool isAtFixpoint() const override { return IsAtFixpoint; }
740  
741    /// See AbstractState::indicatePessimisticFixpoint(...)
742    ChangeStatus indicatePessimisticFixpoint() override {
743      IsAtFixpoint = true;
744      ParallelLevels.indicatePessimisticFixpoint();
745      ReachingKernelEntries.indicatePessimisticFixpoint();
746      SPMDCompatibilityTracker.indicatePessimisticFixpoint();
747      ReachedKnownParallelRegions.indicatePessimisticFixpoint();
748      ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
749      NestedParallelism = true;
750      return ChangeStatus::CHANGED;
751    }
752  
753    /// See AbstractState::indicateOptimisticFixpoint(...)
754    ChangeStatus indicateOptimisticFixpoint() override {
755      IsAtFixpoint = true;
756      ParallelLevels.indicateOptimisticFixpoint();
757      ReachingKernelEntries.indicateOptimisticFixpoint();
758      SPMDCompatibilityTracker.indicateOptimisticFixpoint();
759      ReachedKnownParallelRegions.indicateOptimisticFixpoint();
760      ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
761      return ChangeStatus::UNCHANGED;
762    }
763  
764    /// Return the assumed state
765    KernelInfoState &getAssumed() { return *this; }
766    const KernelInfoState &getAssumed() const { return *this; }
767  
768    bool operator==(const KernelInfoState &RHS) const {
769      if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
770        return false;
771      if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
772        return false;
773      if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
774        return false;
775      if (ReachingKernelEntries != RHS.ReachingKernelEntries)
776        return false;
777      if (ParallelLevels != RHS.ParallelLevels)
778        return false;
779      if (NestedParallelism != RHS.NestedParallelism)
780        return false;
781      return true;
782    }
783  
784    /// Returns true if this kernel contains any OpenMP parallel regions.
785    bool mayContainParallelRegion() {
786      return !ReachedKnownParallelRegions.empty() ||
787             !ReachedUnknownParallelRegions.empty();
788    }
789  
790    /// Return empty set as the best state of potential values.
791    static KernelInfoState getBestState() { return KernelInfoState(true); }
792  
793    static KernelInfoState getBestState(KernelInfoState &KIS) {
794      return getBestState();
795    }
796  
797    /// Return full set as the worst state of potential values.
798    static KernelInfoState getWorstState() { return KernelInfoState(false); }
799  
800    /// "Clamp" this state with \p KIS.
801    KernelInfoState operator^=(const KernelInfoState &KIS) {
802      // Do not merge two different _init and _deinit call sites.
803      if (KIS.KernelInitCB) {
804        if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
805          llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
806                           "assumptions.");
807        KernelInitCB = KIS.KernelInitCB;
808      }
809      if (KIS.KernelDeinitCB) {
810        if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
811          llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
812                           "assumptions.");
813        KernelDeinitCB = KIS.KernelDeinitCB;
814      }
815      if (KIS.KernelEnvC) {
816        if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
817          llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
818                           "assumptions.");
819        KernelEnvC = KIS.KernelEnvC;
820      }
821      SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
822      ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
823      ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
824      NestedParallelism |= KIS.NestedParallelism;
825      return *this;
826    }
827  
828    KernelInfoState operator&=(const KernelInfoState &KIS) {
829      return (*this ^= KIS);
830    }
831  
832    ///}
833  };
834  
835  /// Used to map the values physically (in the IR) stored in an offload
836  /// array to a vector in memory.
837  struct OffloadArray {
838    /// Physical array (in the IR).
839    AllocaInst *Array = nullptr;
840    /// Mapped values.
841    SmallVector<Value *, 8> StoredValues;
842    /// Last stores made in the offload array.
843    SmallVector<StoreInst *, 8> LastAccesses;
844  
845    OffloadArray() = default;
846  
847    /// Initializes the OffloadArray with the values stored in \p Array before
848    /// instruction \p Before is reached. Returns false if the initialization
849    /// fails.
850    /// This MUST be used immediately after the construction of the object.
851    bool initialize(AllocaInst &Array, Instruction &Before) {
852      if (!Array.getAllocatedType()->isArrayTy())
853        return false;
854  
855      if (!getValues(Array, Before))
856        return false;
857  
858      this->Array = &Array;
859      return true;
860    }
861  
862    static const unsigned DeviceIDArgNum = 1;
863    static const unsigned BasePtrsArgNum = 3;
864    static const unsigned PtrsArgNum = 4;
865    static const unsigned SizesArgNum = 5;
866  
867  private:
868    /// Traverses the BasicBlock where \p Array is, collecting the stores made to
869    /// \p Array, leaving StoredValues with the values stored before the
870    /// instruction \p Before is reached.
871    bool getValues(AllocaInst &Array, Instruction &Before) {
872      // Initialize container.
873      const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
874      StoredValues.assign(NumValues, nullptr);
875      LastAccesses.assign(NumValues, nullptr);
876  
877      // TODO: This assumes the instruction \p Before is in the same
878      //  BasicBlock as Array. Make it general, for any control flow graph.
879      BasicBlock *BB = Array.getParent();
880      if (BB != Before.getParent())
881        return false;
882  
883      const DataLayout &DL = Array.getDataLayout();
884      const unsigned int PointerSize = DL.getPointerSize();
885  
886      for (Instruction &I : *BB) {
887        if (&I == &Before)
888          break;
889  
890        if (!isa<StoreInst>(&I))
891          continue;
892  
893        auto *S = cast<StoreInst>(&I);
894        int64_t Offset = -1;
895        auto *Dst =
896            GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
897        if (Dst == &Array) {
898          int64_t Idx = Offset / PointerSize;
899          StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
900          LastAccesses[Idx] = S;
901        }
902      }
903  
904      return isFilled();
905    }
906  
907    /// Returns true if all values in StoredValues and
908    /// LastAccesses are not nullptrs.
909    bool isFilled() {
910      const unsigned NumValues = StoredValues.size();
911      for (unsigned I = 0; I < NumValues; ++I) {
912        if (!StoredValues[I] || !LastAccesses[I])
913          return false;
914      }
915  
916      return true;
917    }
918  };
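// For illustration, a hypothetical client inspects the values stored into one
// of the offload argument arrays right before the runtime call that consumes
// it (SizesAlloca and RuntimeCall are stand-in names):
//
//   OffloadArray OffloadSizes;
//   if (OffloadSizes.initialize(*SizesAlloca, *RuntimeCall))
//     Value *FirstSize = OffloadSizes.StoredValues[0];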
919  
920  struct OpenMPOpt {
921  
922    using OptimizationRemarkGetter =
923        function_ref<OptimizationRemarkEmitter &(Function *)>;
924  
925    OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
926              OptimizationRemarkGetter OREGetter,
927              OMPInformationCache &OMPInfoCache, Attributor &A)
928        : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
929          OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
930  
931    /// Check if any remarks are enabled for openmp-opt
932    bool remarksEnabled() {
933      auto &Ctx = M.getContext();
934      return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
935    }
936  
937    /// Run all OpenMP optimizations on the underlying SCC.
938    bool run(bool IsModulePass) {
939      if (SCC.empty())
940        return false;
941  
942      bool Changed = false;
943  
944      LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
945                        << " functions\n");
946  
947      if (IsModulePass) {
948        Changed |= runAttributor(IsModulePass);
949  
950        // Recollect uses, in case Attributor deleted any.
951        OMPInfoCache.recollectUses();
952  
953        // TODO: This should be folded into buildCustomStateMachine.
954        Changed |= rewriteDeviceCodeStateMachine();
955  
956        if (remarksEnabled())
957          analysisGlobalization();
958      } else {
959        if (PrintICVValues)
960          printICVs();
961        if (PrintOpenMPKernels)
962          printKernels();
963  
964        Changed |= runAttributor(IsModulePass);
965  
966        // Recollect uses, in case Attributor deleted any.
967        OMPInfoCache.recollectUses();
968  
969        Changed |= deleteParallelRegions();
970  
971        if (HideMemoryTransferLatency)
972          Changed |= hideMemTransfersLatency();
973        Changed |= deduplicateRuntimeCalls();
974        if (EnableParallelRegionMerging) {
975          if (mergeParallelRegions()) {
976            deduplicateRuntimeCalls();
977            Changed = true;
978          }
979        }
980      }
981  
982      if (OMPInfoCache.OpenMPPostLink)
983        Changed |= removeRuntimeSymbols();
984  
985      return Changed;
986    }
987  
988    /// Print initial ICV values for testing.
989    /// FIXME: This should be done from the Attributor once it is added.
990    void printICVs() const {
991      InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
992                                   ICV_proc_bind};
993  
994      for (Function *F : SCC) {
995        for (auto ICV : ICVs) {
996          auto ICVInfo = OMPInfoCache.ICVs[ICV];
997          auto Remark = [&](OptimizationRemarkAnalysis ORA) {
998            return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
999                       << " Value: "
1000                       << (ICVInfo.InitValue
1001                               ? toString(ICVInfo.InitValue->getValue(), 10, true)
1002                               : "IMPLEMENTATION_DEFINED");
1003          };
1004  
1005          emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
1006        }
1007      }
1008    }
1009  
1010    /// Print OpenMP GPU kernels for testing.
1011    void printKernels() const {
1012      for (Function *F : SCC) {
1013        if (!omp::isOpenMPKernel(*F))
1014          continue;
1015  
1016        auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1017          return ORA << "OpenMP GPU kernel "
1018                     << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
1019        };
1020  
1021        emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
1022      }
1023    }
1024  
1025    /// Return the call if \p U is a callee use in a regular call. If \p RFI is
1026    /// given, it has to be the callee; otherwise nullptr is returned.
1027    static CallInst *getCallIfRegularCall(
1028        Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1029      CallInst *CI = dyn_cast<CallInst>(U.getUser());
1030      if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
1031          (!RFI ||
1032           (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1033        return CI;
1034      return nullptr;
1035    }
1036  
1037    /// Return the call if \p V is a regular call. If \p RFI is given, it has to
1038    /// be the callee; otherwise nullptr is returned.
1039    static CallInst *getCallIfRegularCall(
1040        Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1041      CallInst *CI = dyn_cast<CallInst>(&V);
1042      if (CI && !CI->hasOperandBundles() &&
1043          (!RFI ||
1044           (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1045        return CI;
1046      return nullptr;
1047    }
1048  
1049  private:
1050    /// Merge parallel regions when it is safe.
1051    bool mergeParallelRegions() {
1052      const unsigned CallbackCalleeOperand = 2;
1053      const unsigned CallbackFirstArgOperand = 3;
1054      using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1055  
1056      // Check if there are any __kmpc_fork_call calls to merge.
1057      OMPInformationCache::RuntimeFunctionInfo &RFI =
1058          OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1059  
1060      if (!RFI.Declaration)
1061        return false;
1062  
1063      // Unmergable calls that prevent merging a parallel region.
1064      OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
1065          OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
1066          OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
1067      };
1068  
1069      bool Changed = false;
1070      LoopInfo *LI = nullptr;
1071      DominatorTree *DT = nullptr;
1072  
1073      SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
1074  
1075      BasicBlock *StartBB = nullptr, *EndBB = nullptr;
1076      auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1077        BasicBlock *CGStartBB = CodeGenIP.getBlock();
1078        BasicBlock *CGEndBB =
1079            SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1080        assert(StartBB != nullptr && "StartBB should not be null");
1081        CGStartBB->getTerminator()->setSuccessor(0, StartBB);
1082        assert(EndBB != nullptr && "EndBB should not be null");
1083        EndBB->getTerminator()->setSuccessor(0, CGEndBB);
1084      };
1085  
1086      auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
1087                        Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
1088        ReplacementValue = &Inner;
1089        return CodeGenIP;
1090      };
1091  
1092      auto FiniCB = [&](InsertPointTy CodeGenIP) {};
1093  
1094      /// Create a sequential execution region within a merged parallel region,
1095      /// encapsulated in a master construct with a barrier for synchronization.
1096      auto CreateSequentialRegion = [&](Function *OuterFn,
1097                                        BasicBlock *OuterPredBB,
1098                                        Instruction *SeqStartI,
1099                                        Instruction *SeqEndI) {
1100        // Isolate the instructions of the sequential region to a separate
1101        // block.
1102        BasicBlock *ParentBB = SeqStartI->getParent();
1103        BasicBlock *SeqEndBB =
1104            SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
1105        BasicBlock *SeqAfterBB =
1106            SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
1107        BasicBlock *SeqStartBB =
1108            SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
1109  
1110        assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
1111               "Expected a different CFG");
1112        const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
1113        ParentBB->getTerminator()->eraseFromParent();
1114  
1115        auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1116          BasicBlock *CGStartBB = CodeGenIP.getBlock();
1117          BasicBlock *CGEndBB =
1118              SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1119          assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
1120          CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
1121          assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
1122          SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
1123        };
1124        auto FiniCB = [&](InsertPointTy CodeGenIP) {};
1125  
1126        // Find outputs from the sequential region to outside users and
1127        // broadcast their values to them.
1128        for (Instruction &I : *SeqStartBB) {
1129          SmallPtrSet<Instruction *, 4> OutsideUsers;
1130          for (User *Usr : I.users()) {
1131            Instruction &UsrI = *cast<Instruction>(Usr);
1132            // Ignore outputs to lifetime intrinsics; code extraction for the
1133            // merged parallel region will fix them.
1134            if (UsrI.isLifetimeStartOrEnd())
1135              continue;
1136  
1137            if (UsrI.getParent() != SeqStartBB)
1138              OutsideUsers.insert(&UsrI);
1139          }
1140  
1141          if (OutsideUsers.empty())
1142            continue;
1143  
1144          // Emit an alloca in the outer region to store the broadcasted
1145          // value.
1146          const DataLayout &DL = M.getDataLayout();
1147          AllocaInst *AllocaI = new AllocaInst(
1148              I.getType(), DL.getAllocaAddrSpace(), nullptr,
1149              I.getName() + ".seq.output.alloc", OuterFn->front().begin());
1150  
1151          // Emit a store instruction in the sequential BB to update the
1152          // value.
1153          new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()->getIterator());
1154  
1155          // Emit a load instruction and replace the use of the output value
1156          // with it.
1157          for (Instruction *UsrI : OutsideUsers) {
1158            LoadInst *LoadI = new LoadInst(I.getType(), AllocaI,
1159                                           I.getName() + ".seq.output.load",
1160                                           UsrI->getIterator());
1161            UsrI->replaceUsesOfWith(&I, LoadI);
1162          }
1163        }
1164  
1165        OpenMPIRBuilder::LocationDescription Loc(
1166            InsertPointTy(ParentBB, ParentBB->end()), DL);
1167        InsertPointTy SeqAfterIP =
1168            OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
1169  
1170        OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
1171  
1172        BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
1173  
1174        LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
1175                          << "\n");
1176      };
1177  
1178      // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
1179      // contained in BB and only separated by instructions that can be
1180      // redundantly executed in parallel. The block BB is split before the first
1181      // call (in MergableCIs) and after the last so the entire region we merge
1182      // into a single parallel region is contained in a single basic block
1183      // without any other instructions. We use the OpenMPIRBuilder to outline
1184      // that block and call the resulting function via __kmpc_fork_call.
1185      auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
1186                       BasicBlock *BB) {
1187        // TODO: Change the interface to allow single CIs expanded, e.g., to
1188        // include an outer loop.
1189        assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1190  
1191        auto Remark = [&](OptimizationRemark OR) {
1192          OR << "Parallel region merged with parallel region"
1193             << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1194          for (auto *CI : llvm::drop_begin(MergableCIs)) {
1195            OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1196            if (CI != MergableCIs.back())
1197              OR << ", ";
1198          }
1199          return OR << ".";
1200        };
1201  
1202        emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1203  
1204        Function *OriginalFn = BB->getParent();
1205        LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1206                          << " parallel regions in " << OriginalFn->getName()
1207                          << "\n");
1208  
1209        // Isolate the calls to merge in a separate block.
1210        EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1211        BasicBlock *AfterBB =
1212            SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1213        StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1214                             "omp.par.merged");
1215  
1216        assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1217        const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1218        BB->getTerminator()->eraseFromParent();
1219  
1220        // Create sequential regions for sequential instructions that are
1221        // in-between mergable parallel regions.
1222        for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1223             It != End; ++It) {
1224          Instruction *ForkCI = *It;
1225          Instruction *NextForkCI = *(It + 1);
1226  
1227        // Continue if there are no in-between instructions.
1228          if (ForkCI->getNextNode() == NextForkCI)
1229            continue;
1230  
1231          CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1232                                 NextForkCI->getPrevNode());
1233        }
1234  
1235        OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1236                                                 DL);
1237        IRBuilder<>::InsertPoint AllocaIP(
1238            &OriginalFn->getEntryBlock(),
1239            OriginalFn->getEntryBlock().getFirstInsertionPt());
1240        // Create the merged parallel region with default proc binding, to
1241        // avoid overriding binding settings, and without explicit cancellation.
1242        InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1243            Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1244            OMP_PROC_BIND_default, /* IsCancellable */ false);
1245        BranchInst::Create(AfterBB, AfterIP.getBlock());
1246  
1247        // Perform the actual outlining.
1248        OMPInfoCache.OMPBuilder.finalize(OriginalFn);
1249  
1250        Function *OutlinedFn = MergableCIs.front()->getCaller();
1251  
1252        // Replace the __kmpc_fork_call calls with direct calls to the outlined
1253        // callbacks.
1254        SmallVector<Value *, 8> Args;
1255        for (auto *CI : MergableCIs) {
1256          Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
1257          FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
1258          Args.clear();
1259          Args.push_back(OutlinedFn->getArg(0));
1260          Args.push_back(OutlinedFn->getArg(1));
1261          for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1262               ++U)
1263            Args.push_back(CI->getArgOperand(U));
1264  
1265          CallInst *NewCI =
1266              CallInst::Create(FT, Callee, Args, "", CI->getIterator());
1267          if (CI->getDebugLoc())
1268            NewCI->setDebugLoc(CI->getDebugLoc());
1269  
1270          // Forward parameter attributes from the callback to the callee.
1271          for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1272               ++U)
1273            for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1274              NewCI->addParamAttr(
1275                  U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1276  
1277          // Emit an explicit barrier to replace the implicit fork-join barrier.
1278          if (CI != MergableCIs.back()) {
1279            // TODO: Remove barrier if the merged parallel region includes the
1280            // 'nowait' clause.
1281            OMPInfoCache.OMPBuilder.createBarrier(
1282                InsertPointTy(NewCI->getParent(),
1283                              NewCI->getNextNode()->getIterator()),
1284                OMPD_parallel);
1285          }
1286  
1287          CI->eraseFromParent();
1288        }
1289  
1290        assert(OutlinedFn != OriginalFn && "Outlining failed");
1291        CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1292        CGUpdater.reanalyzeFunction(*OriginalFn);
1293  
1294        NumOpenMPParallelRegionsMerged += MergableCIs.size();
1295  
1296        return true;
1297      };
1298  
1299      // Helper function that identifies sequences of
1300      // __kmpc_fork_call uses in a basic block.
1301      auto DetectPRsCB = [&](Use &U, Function &F) {
1302      CallInst *CI = getCallIfRegularCall(U, &RFI);
          // Guard against irregular uses that are not direct calls.
          if (!CI)
            return false;
1303      BB2PRMap[CI->getParent()].insert(CI);
1304  
1305        return false;
1306      };
1307  
1308      BB2PRMap.clear();
1309      RFI.foreachUse(SCC, DetectPRsCB);
1310      SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1311      // Find mergable parallel regions within a basic block that are
1312      // safe to merge, that is, any in-between instructions can safely
1313      // execute in parallel after merging.
1314      // TODO: support merging across basic-blocks.
1315      for (auto &It : BB2PRMap) {
1316        auto &CIs = It.getSecond();
1317        if (CIs.size() < 2)
1318          continue;
1319  
1320        BasicBlock *BB = It.getFirst();
1321        SmallVector<CallInst *, 4> MergableCIs;
1322  
1323        /// Returns true if the instruction is mergable, false otherwise.
1324        /// A terminator instruction is unmergable by definition since merging
1325        /// works within a BB. Instructions before the mergable region are
1326        /// mergable if they are not calls to OpenMP runtime functions that may
1327        /// set different execution parameters for subsequent parallel regions.
1328        /// Instructions in-between parallel regions are mergable if they are not
1329        /// calls to any non-intrinsic function since that may call a non-mergable
1330        /// OpenMP runtime function.
1331        auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1332          // We do not merge across BBs, hence return false (unmergable) if the
1333          // instruction is a terminator.
1334          if (I.isTerminator())
1335            return false;
1336  
1337          if (!isa<CallInst>(&I))
1338            return true;
1339  
1340          CallInst *CI = cast<CallInst>(&I);
1341          if (IsBeforeMergableRegion) {
1342            Function *CalledFunction = CI->getCalledFunction();
1343            if (!CalledFunction)
1344              return false;
1345            // Return false (unmergable) if the call before the parallel
1346            // region calls an explicit affinity (proc_bind) or number of
1347            // threads (num_threads) compiler-generated function. Those settings
1348            // may be incompatible with following parallel regions.
1349            // TODO: ICV tracking to detect compatibility.
1350            for (const auto &RFI : UnmergableCallsInfo) {
1351              if (CalledFunction == RFI.Declaration)
1352                return false;
1353            }
1354          } else {
1355            // Return false (unmergable) if there is a call instruction
1356            // in-between parallel regions when it is not an intrinsic. It
1357            // may call an unmergable OpenMP runtime function in its callpath.
1358            // TODO: Keep track of possible OpenMP calls in the callpath.
1359            if (!isa<IntrinsicInst>(CI))
1360              return false;
1361          }
1362  
1363          return true;
1364        };
1365        // Find maximal number of parallel region CIs that are safe to merge.
1366        for (auto It = BB->begin(), End = BB->end(); It != End;) {
1367          Instruction &I = *It;
1368          ++It;
1369  
1370          if (CIs.count(&I)) {
1371            MergableCIs.push_back(cast<CallInst>(&I));
1372            continue;
1373          }
1374  
1375          // Continue expanding if the instruction is mergable.
1376          if (IsMergable(I, MergableCIs.empty()))
1377            continue;
1378  
1379          // Forward the instruction iterator to skip the next parallel region
1380          // since there is an unmergable instruction which can affect it.
1381          for (; It != End; ++It) {
1382            Instruction &SkipI = *It;
1383            if (CIs.count(&SkipI)) {
1384              LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1385                                << " due to " << I << "\n");
1386              ++It;
1387              break;
1388            }
1389          }
1390  
1391          // Store mergable regions found.
1392          if (MergableCIs.size() > 1) {
1393            MergableCIsVector.push_back(MergableCIs);
1394            LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1395                              << " parallel regions in block " << BB->getName()
1396                              << " of function " << BB->getParent()->getName()
1397                              << "\n";);
1398          }
1399  
1400          MergableCIs.clear();
1401        }
1402  
1403        if (!MergableCIsVector.empty()) {
1404          Changed = true;
1405  
1406          for (auto &MergableCIs : MergableCIsVector)
1407            Merge(MergableCIs, BB);
1408          MergableCIsVector.clear();
1409        }
1410      }
1411  
1412      if (Changed) {
1413        /// Re-collect uses for fork calls, emitted barrier calls, and
1414        /// any emitted master/end_master calls.
1415        OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1416        OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1417        OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1418        OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1419      }
1420  
1421      return Changed;
1422    }
1423  
1424    /// Try to delete parallel regions if possible.
1425    bool deleteParallelRegions() {
1426      const unsigned CallbackCalleeOperand = 2;
1427  
1428      OMPInformationCache::RuntimeFunctionInfo &RFI =
1429          OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1430  
1431      if (!RFI.Declaration)
1432        return false;
1433  
1434      bool Changed = false;
1435      auto DeleteCallCB = [&](Use &U, Function &) {
1436        CallInst *CI = getCallIfRegularCall(U);
1437        if (!CI)
1438          return false;
1439        auto *Fn = dyn_cast<Function>(
1440            CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1441        if (!Fn)
1442          return false;
1443        if (!Fn->onlyReadsMemory())
1444          return false;
1445        if (!Fn->hasFnAttribute(Attribute::WillReturn))
1446          return false;
1447  
1448        LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1449                          << CI->getCaller()->getName() << "\n");
1450  
1451        auto Remark = [&](OptimizationRemark OR) {
1452          return OR << "Removing parallel region with no side-effects.";
1453        };
1454        emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1455  
1456        CI->eraseFromParent();
1457        Changed = true;
1458        ++NumOpenMPParallelRegionsDeleted;
1459        return true;
1460      };
1461  
1462      RFI.foreachUse(SCC, DeleteCallCB);
1463  
1464      return Changed;
1465    }
1466  
1467    /// Try to eliminate runtime calls by reusing existing ones.
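        /// For example (a sketch of the intent): two calls to
        /// omp_get_num_threads() in the same function are collapsed into one
        /// whose result is reused, since these query-only runtime calls are
        /// assumed to return the same value throughout the function.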
1468    bool deduplicateRuntimeCalls() {
1469      bool Changed = false;
1470  
1471      RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1472          OMPRTL_omp_get_num_threads,
1473          OMPRTL_omp_in_parallel,
1474          OMPRTL_omp_get_cancellation,
1475          OMPRTL_omp_get_supported_active_levels,
1476          OMPRTL_omp_get_level,
1477          OMPRTL_omp_get_ancestor_thread_num,
1478          OMPRTL_omp_get_team_size,
1479          OMPRTL_omp_get_active_level,
1480          OMPRTL_omp_in_final,
1481          OMPRTL_omp_get_proc_bind,
1482          OMPRTL_omp_get_num_places,
1483          OMPRTL_omp_get_num_procs,
1484          OMPRTL_omp_get_place_num,
1485          OMPRTL_omp_get_partition_num_places,
1486          OMPRTL_omp_get_partition_place_nums};
1487  
1488      // Global-tid is handled separately.
1489      SmallSetVector<Value *, 16> GTIdArgs;
1490      collectGlobalThreadIdArguments(GTIdArgs);
1491      LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1492                        << " global thread ID arguments\n");
1493  
1494      for (Function *F : SCC) {
1495        for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1496          Changed |= deduplicateRuntimeCalls(
1497              *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1498  
1499        // __kmpc_global_thread_num is special as we can replace it with an
1500        // argument in enough cases to make it worth trying.
1501        Value *GTIdArg = nullptr;
1502        for (Argument &Arg : F->args())
1503          if (GTIdArgs.count(&Arg)) {
1504            GTIdArg = &Arg;
1505            break;
1506          }
1507        Changed |= deduplicateRuntimeCalls(
1508            *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1509      }
1510  
1511      return Changed;
1512    }
1513  
1514    /// Tries to remove known runtime symbols that are optional from the module.
1515    bool removeRuntimeSymbols() {
1516      // The RPC client symbol is defined in `libc` and indicates that something
1517      // required an RPC server. If its users were all optimized out then we can
1518      // safely remove it.
1519      // TODO: This should be somewhere more common in the future.
1520      if (GlobalVariable *GV = M.getNamedGlobal("__llvm_libc_rpc_client")) {
1521        if (!GV->getType()->isPointerTy())
1522          return false;
1523  
1524        Constant *C = GV->getInitializer();
1525        if (!C)
1526          return false;
1527  
1528        // Check to see if the only user of the RPC client is the external handle.
1529        GlobalVariable *Client = dyn_cast<GlobalVariable>(C->stripPointerCasts());
1530        if (!Client || Client->getNumUses() > 1 ||
1531            Client->user_back() != GV->getInitializer())
1532          return false;
1533  
1534        Client->replaceAllUsesWith(PoisonValue::get(Client->getType()));
1535        Client->eraseFromParent();
1536  
1537        GV->replaceAllUsesWith(PoisonValue::get(GV->getType()));
1538        GV->eraseFromParent();
1539  
1540        return true;
1541      }
1542      return false;
1543    }
1544  
1545    /// Tries to hide the latency of runtime calls that involve host to
1546    /// device memory transfers by splitting them into their "issue" and "wait"
1547    /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1548    /// moved downwards as much as possible. The "issue" starts the memory transfer
1549    /// asynchronously, returning a handle. The "wait" waits on the returned
1550    /// handle for the memory transfer to finish.
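        /// Sketch of the split (argument lists abbreviated):
        ///   call void @__tgt_target_data_begin_mapper(...)
        /// becomes
        ///   %handle = alloca %struct.__tgt_async_info
        ///   call void @__tgt_target_data_begin_mapper_issue(..., ptr %handle)
        ///   ; instructions without conflicting memory effects may run here
        ///   call void @__tgt_target_data_begin_mapper_wait(%device_id, ptr %handle)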
1551    bool hideMemTransfersLatency() {
1552      auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1553      bool Changed = false;
1554      auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1555        auto *RTCall = getCallIfRegularCall(U, &RFI);
1556        if (!RTCall)
1557          return false;
1558  
1559        OffloadArray OffloadArrays[3];
1560        if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1561          return false;
1562  
1563        LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1564  
1565        // TODO: Check if can be moved upwards.
1566        bool WasSplit = false;
1567        Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1568        if (WaitMovementPoint)
1569          WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1570  
1571        Changed |= WasSplit;
1572        return WasSplit;
1573      };
1574      if (OMPInfoCache.runtimeFnsAvailable(
1575              {OMPRTL___tgt_target_data_begin_mapper_issue,
1576               OMPRTL___tgt_target_data_begin_mapper_wait}))
1577        RFI.foreachUse(SCC, SplitMemTransfers);
1578  
1579      return Changed;
1580    }
1581  
1582    void analysisGlobalization() {
1583      auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1584  
1585      auto CheckGlobalization = [&](Use &U, Function &Decl) {
1586        if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1587          auto Remark = [&](OptimizationRemarkMissed ORM) {
1588            return ORM
1589                   << "Found thread data sharing on the GPU. "
1590                   << "Expect degraded performance due to data globalization.";
1591          };
1592          emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1593        }
1594  
1595        return false;
1596      };
1597  
1598      RFI.foreachUse(SCC, CheckGlobalization);
1599    }
1600  
1601    /// Maps the values stored in the offload arrays passed as arguments to
1602    /// \p RuntimeCall into the offload arrays in \p OAs.
1603    bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1604                                  MutableArrayRef<OffloadArray> OAs) {
1605      assert(OAs.size() == 3 && "Need space for three offload arrays!");
1606  
1607      // A runtime call that involves memory offloading looks something like:
1608      // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1609      //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1610      // ...)
1611      // So, the idea is to access the allocas that allocate space for these
1612      // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1613      // Therefore:
1614      // i8** %offload_baseptrs.
1615      Value *BasePtrsArg =
1616          RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1617      // i8** %offload_ptrs.
1618      Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1619      // i8** %offload_sizes.
1620      Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1621  
1622      // Get values stored in **offload_baseptrs.
1623      auto *V = getUnderlyingObject(BasePtrsArg);
1624      if (!isa<AllocaInst>(V))
1625        return false;
1626      auto *BasePtrsArray = cast<AllocaInst>(V);
1627      if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1628        return false;
1629  
1630      // Get values stored in **offload_ptrs.
1631      V = getUnderlyingObject(PtrsArg);
1632      if (!isa<AllocaInst>(V))
1633        return false;
1634      auto *PtrsArray = cast<AllocaInst>(V);
1635      if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1636        return false;
1637  
1638      // Get values stored in **offload_sizes.
1639      V = getUnderlyingObject(SizesArg);
1640      // If it's a [constant] global array don't analyze it.
1641      if (isa<GlobalValue>(V))
1642        return isa<Constant>(V);
1643      if (!isa<AllocaInst>(V))
1644        return false;
1645  
1646      auto *SizesArray = cast<AllocaInst>(V);
1647      if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1648        return false;
1649  
1650      return true;
1651    }
1652  
1653    /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1654    /// For now this is a way to test that the function getValuesInOffloadArrays
1655    /// is working properly.
1656    /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1657    void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1658      assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1659  
1660      LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1661      std::string ValuesStr;
1662      raw_string_ostream Printer(ValuesStr);
1663      std::string Separator = " --- ";
1664  
1665      for (auto *BP : OAs[0].StoredValues) {
1666        BP->print(Printer);
1667        Printer << Separator;
1668      }
1669      LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << ValuesStr << "\n");
1670      ValuesStr.clear();
1671  
1672      for (auto *P : OAs[1].StoredValues) {
1673        P->print(Printer);
1674        Printer << Separator;
1675      }
1676      LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << ValuesStr << "\n");
1677      ValuesStr.clear();
1678  
1679      for (auto *S : OAs[2].StoredValues) {
1680        S->print(Printer);
1681        Printer << Separator;
1682      }
1683      LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << ValuesStr << "\n");
1684    }
1685  
1686    /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be
1687    /// moved. Returns nullptr if the movement is not possible, or not worth it.
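        /// The scan below walks forward from the call and stops at the first
        /// instruction that may read or write memory; sinking the "wait" past
        /// at least one side-effect-free instruction is considered worthwhile.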
1688    Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1689      // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1690      //  Make it traverse the CFG.
1691  
1692      Instruction *CurrentI = &RuntimeCall;
1693      bool IsWorthIt = false;
1694      while ((CurrentI = CurrentI->getNextNode())) {
1695  
1696        // TODO: Once we detect the regions to be offloaded we should use the
1697        //  alias analysis manager to check if CurrentI may modify one of
1698        //  the offloaded regions.
1699        if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1700          if (IsWorthIt)
1701            return CurrentI;
1702  
1703          return nullptr;
1704        }
1705  
1706      // FIXME: For now, moving it over anything without side effects is
1707      //  considered worth it.
1708        IsWorthIt = true;
1709      }
1710  
1711      // Return end of BasicBlock.
1712      return RuntimeCall.getParent()->getTerminator();
1713    }
1714  
1715    /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1716    bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1717                                 Instruction &WaitMovementPoint) {
1718      // Create stack allocated handle (__tgt_async_info) at the beginning of the
1719      // function. Used for storing information of the async transfer, allowing us
1720      // to wait on it later.
1721      auto &IRBuilder = OMPInfoCache.OMPBuilder;
1722      Function *F = RuntimeCall.getCaller();
1723      BasicBlock &Entry = F->getEntryBlock();
1724      IRBuilder.Builder.SetInsertPoint(&Entry,
1725                                       Entry.getFirstNonPHIOrDbgOrAlloca());
1726      Value *Handle = IRBuilder.Builder.CreateAlloca(
1727          IRBuilder.AsyncInfo, /*ArraySize=*/nullptr, "handle");
1728      Handle =
1729          IRBuilder.Builder.CreateAddrSpaceCast(Handle, IRBuilder.AsyncInfoPtr);
1730  
1731      // Add "issue" runtime call declaration:
1732      // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1733      //   i8**, i8**, i64*, i64*)
1734      FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1735          M, OMPRTL___tgt_target_data_begin_mapper_issue);
1736  
1737      // Change RuntimeCall call site for its asynchronous version.
1738      SmallVector<Value *, 16> Args;
1739      for (auto &Arg : RuntimeCall.args())
1740        Args.push_back(Arg.get());
1741      Args.push_back(Handle);
1742  
1743      CallInst *IssueCallsite = CallInst::Create(IssueDecl, Args, /*NameStr=*/"",
1744                                                 RuntimeCall.getIterator());
1745      OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1746      RuntimeCall.eraseFromParent();
1747  
1748      // Add "wait" runtime call declaration:
1749      // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1750      FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1751          M, OMPRTL___tgt_target_data_begin_mapper_wait);
1752  
1753      Value *WaitParams[2] = {
1754          IssueCallsite->getArgOperand(
1755              OffloadArray::DeviceIDArgNum), // device_id.
1756          Handle                             // handle to wait on.
1757      };
1758      CallInst *WaitCallsite = CallInst::Create(
1759          WaitDecl, WaitParams, /*NameStr=*/"", WaitMovementPoint.getIterator());
1760      OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1761  
1762      return true;
1763    }
1764  
1765    static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1766                                      bool GlobalOnly, bool &SingleChoice) {
1767      if (CurrentIdent == NextIdent)
1768        return CurrentIdent;
1769  
1770      // TODO: Figure out how to actually combine multiple debug locations. For
1771      //       now we just keep an existing one if there is a single choice.
1772      if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1773        SingleChoice = !CurrentIdent;
1774        return NextIdent;
1775      }
1776      return nullptr;
1777    }
1778  
1779    /// Return a `struct ident_t*` value that represents the ones used in the
1780    /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1781    /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1782    /// return value we create one from scratch. We also do not yet combine
1783    /// information, e.g., the source locations, see combinedIdentStruct.
1784    Value *
1785    getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1786                                   Function &F, bool GlobalOnly) {
1787      bool SingleChoice = true;
1788      Value *Ident = nullptr;
1789      auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1790        CallInst *CI = getCallIfRegularCall(U, &RFI);
1791        if (!CI || &F != &Caller)
1792          return false;
1793        Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1794                                    /* GlobalOnly */ true, SingleChoice);
1795        return false;
1796      };
1797      RFI.foreachUse(SCC, CombineIdentStruct);
1798  
1799      if (!Ident || !SingleChoice) {
1800        // The IRBuilder uses the insertion block to get to the module; this is
1801        // unfortunate, but we work around it for now.
1802        if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1803          OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1804              &F.getEntryBlock(), F.getEntryBlock().begin()));
1805        // Create a fallback location if none was found.
1806        // TODO: Use the debug locations of the calls instead.
1807        uint32_t SrcLocStrSize;
1808        Constant *Loc =
1809            OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1810        Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1811      }
1812      return Ident;
1813    }
1814  
1815    /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1816    /// \p ReplVal if given.
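        /// E.g. (sketch): if \p F already receives the global thread id as an
        /// argument %gtid, every `%x = call i32 @__kmpc_global_thread_num(...)`
        /// in \p F can be replaced by %gtid and erased.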
1817    bool deduplicateRuntimeCalls(Function &F,
1818                                 OMPInformationCache::RuntimeFunctionInfo &RFI,
1819                                 Value *ReplVal = nullptr) {
1820      auto *UV = RFI.getUseVector(F);
1821      if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1822        return false;
1823  
1824      LLVM_DEBUG(
1825          dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1826                 << (ReplVal ? " with an existing value\n" : "\n") << "\n");
1827  
1828      assert((!ReplVal || (isa<Argument>(ReplVal) &&
1829                           cast<Argument>(ReplVal)->getParent() == &F)) &&
1830             "Unexpected replacement value!");
1831  
1832      // TODO: Use dominance to find a good position instead.
1833      auto CanBeMoved = [this](CallBase &CB) {
1834        unsigned NumArgs = CB.arg_size();
1835        if (NumArgs == 0)
1836          return true;
1837        if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1838          return false;
1839        for (unsigned U = 1; U < NumArgs; ++U)
1840          if (isa<Instruction>(CB.getArgOperand(U)))
1841            return false;
1842        return true;
1843      };
1844  
1845      if (!ReplVal) {
1846        auto *DT =
1847            OMPInfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(F);
1848        if (!DT)
1849          return false;
1850        Instruction *IP = nullptr;
1851        for (Use *U : *UV) {
1852          if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1853            if (IP)
1854              IP = DT->findNearestCommonDominator(IP, CI);
1855            else
1856              IP = CI;
1857            if (!CanBeMoved(*CI))
1858              continue;
1859            if (!ReplVal)
1860              ReplVal = CI;
1861          }
1862        }
1863        if (!ReplVal)
1864          return false;
1865        assert(IP && "Expected insertion point!");
1866        cast<Instruction>(ReplVal)->moveBefore(IP);
1867      }
1868  
1869      // If we use a call as a replacement value we need to make sure the ident is
1870      // valid at the new location. For now we just pick a global one, either
1871      // existing and used by one of the calls, or created from scratch.
1872      if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1873        if (!CI->arg_empty() &&
1874            CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1875          Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1876                                                        /* GlobalOnly */ true);
1877          CI->setArgOperand(0, Ident);
1878        }
1879      }
1880  
1881      bool Changed = false;
1882      auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1883        CallInst *CI = getCallIfRegularCall(U, &RFI);
1884        if (!CI || CI == ReplVal || &F != &Caller)
1885          return false;
1886        assert(CI->getCaller() == &F && "Unexpected call!");
1887  
1888        auto Remark = [&](OptimizationRemark OR) {
1889          return OR << "OpenMP runtime call "
1890                    << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1891        };
1892        if (CI->getDebugLoc())
1893          emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1894        else
1895          emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1896  
1897        CI->replaceAllUsesWith(ReplVal);
1898        CI->eraseFromParent();
1899        ++NumOpenMPRuntimeCallsDeduplicated;
1900        Changed = true;
1901        return true;
1902      };
1903      RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1904  
1905      return Changed;
1906    }
1907  
1908    /// Collect arguments that represent the global thread id in \p GTIdArgs.
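        /// E.g. (sketch): given `%gtid = call i32 @__kmpc_global_thread_num(...)`
        /// followed by `call void @foo(i32 %gtid)`, the matching parameter of
        /// the internal function @foo is recorded as a GTId argument, provided
        /// every call site of @foo passes a known GTId in that position.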
1909    void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1910      // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1911      //       initialization. We could define an AbstractAttribute instead and
1912      //       run the Attributor here once it can be run as an SCC pass.
1913  
1914      // Helper to check the argument \p ArgNo at all call sites of \p F for
1915      // a GTId.
1916      auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1917        if (!F.hasLocalLinkage())
1918          return false;
1919        for (Use &U : F.uses()) {
1920          if (CallInst *CI = getCallIfRegularCall(U)) {
1921            Value *ArgOp = CI->getArgOperand(ArgNo);
1922            if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1923                getCallIfRegularCall(
1924                    *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1925              continue;
1926          }
1927          return false;
1928        }
1929        return true;
1930      };
1931  
1932      // Helper to identify uses of a GTId as GTId arguments.
1933      auto AddUserArgs = [&](Value &GTId) {
1934        for (Use &U : GTId.uses())
1935          if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1936            if (CI->isArgOperand(&U))
1937              if (Function *Callee = CI->getCalledFunction())
1938                if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1939                  GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1940      };
1941  
1942      // The argument users of __kmpc_global_thread_num calls are GTIds.
1943      OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1944          OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1945  
1946      GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1947        if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1948          AddUserArgs(*CI);
1949        return false;
1950      });
1951  
1952      // Transitively search for more arguments by looking at the users of the
1953      // ones we know already. During the search the GTIdArgs vector is extended
1954      // so we cannot cache the size nor can we use a range based for.
1955      for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1956        AddUserArgs(*GTIdArgs[U]);
1957    }
1958  
1959    /// Kernel (=GPU) optimizations and utility functions
1960    ///
1961    ///{{
1962  
1963    /// Cache to remember the unique kernel for a function.
1964    DenseMap<Function *, std::optional<Kernel>> UniqueKernelMap;
1965  
1966    /// Find the unique kernel that will execute \p F, if any.
1967    Kernel getUniqueKernelFor(Function &F);
1968  
1969    /// Find the unique kernel that will execute \p I, if any.
1970    Kernel getUniqueKernelFor(Instruction &I) {
1971      return getUniqueKernelFor(*I.getFunction());
1972    }
1973  
1974    /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
1975    /// the cases we can avoid taking the address of a function.
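        /// Sketch, using a hypothetical parallel body function @pbody: its uses
        /// in state-machine equality comparisons and as the wrapper argument of
        /// __kmpc_parallel_51 are redirected to a private one-byte global
        /// @pbody.ID, so @pbody is no longer address-taken and only direct
        /// calls to it remain.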
1976    bool rewriteDeviceCodeStateMachine();
1977  
1978    ///
1979    ///}}
1980  
1981    /// Emit a remark generically
1982    ///
1983    /// This template function can be used to generically emit a remark. The
1984    /// RemarkKind should be one of the following:
1985    ///   - OptimizationRemark to indicate a successful optimization attempt
1986    ///   - OptimizationRemarkMissed to report a failed optimization attempt
1987    ///   - OptimizationRemarkAnalysis to provide additional information about an
1988    ///     optimization attempt
1989    ///
1990    /// The remark is built using a callback function provided by the caller that
1991    /// takes a RemarkKind as input and returns a RemarkKind.
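        /// Typical use, mirroring the call sites in this file:
        ///   auto Remark = [&](OptimizationRemark OR) {
        ///     return OR << "Removing parallel region with no side-effects.";
        ///   };
        ///   emitRemark<OptimizationRemark>(CI, "OMP160", Remark);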
1992    template <typename RemarkKind, typename RemarkCallBack>
1993    void emitRemark(Instruction *I, StringRef RemarkName,
1994                    RemarkCallBack &&RemarkCB) const {
1995      Function *F = I->getParent()->getParent();
1996      auto &ORE = OREGetter(F);
1997  
1998      if (RemarkName.starts_with("OMP"))
1999        ORE.emit([&]() {
2000          return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
2001                 << " [" << RemarkName << "]";
2002        });
2003      else
2004        ORE.emit(
2005            [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
2006    }
2007  
2008    /// Emit a remark on a function.
2009    template <typename RemarkKind, typename RemarkCallBack>
2010    void emitRemark(Function *F, StringRef RemarkName,
2011                    RemarkCallBack &&RemarkCB) const {
2012      auto &ORE = OREGetter(F);
2013  
2014      if (RemarkName.starts_with("OMP"))
2015        ORE.emit([&]() {
2016          return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
2017                 << " [" << RemarkName << "]";
2018        });
2019      else
2020        ORE.emit(
2021            [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
2022    }
2023  
2024    /// The underlying module.
2025    Module &M;
2026  
2027    /// The SCC we are operating on.
2028    SmallVectorImpl<Function *> &SCC;
2029  
2030    /// Callback to update the call graph, the first argument is a removed call,
2031    /// the second an optional replacement call.
2032    CallGraphUpdater &CGUpdater;
2033  
2034    /// Callback to get an OptimizationRemarkEmitter from a Function *
2035    OptimizationRemarkGetter OREGetter;
2036  
2037    /// OpenMP-specific information cache. Also used for Attributor runs.
2038    OMPInformationCache &OMPInfoCache;
2039  
2040    /// Attributor instance.
2041    Attributor &A;
2042  
2043    /// Helper function to run Attributor on SCC.
2044    bool runAttributor(bool IsModulePass) {
2045      if (SCC.empty())
2046        return false;
2047  
2048      registerAAs(IsModulePass);
2049  
2050      ChangeStatus Changed = A.run();
2051  
2052      LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
2053                        << " functions, result: " << Changed << ".\n");
2054  
2055      if (Changed == ChangeStatus::CHANGED)
2056        OMPInfoCache.invalidateAnalyses();
2057  
2058      return Changed == ChangeStatus::CHANGED;
2059    }
2060  
2061    void registerFoldRuntimeCall(RuntimeFunction RF);
2062  
2063    /// Populate the Attributor with abstract attribute opportunities in the
2064    /// functions.
2065    void registerAAs(bool IsModulePass);
2066  
2067  public:
2068    /// Callback to register AAs for live functions, including internal functions
2069    /// marked live during the traversal.
2070    static void registerAAsForFunction(Attributor &A, const Function &F);
2071  };
2072  
2073  Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
2074    if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() &&
2075        !OMPInfoCache.CGSCC->contains(&F))
2076      return nullptr;
2077  
2078    // Use a scope to keep the lifetime of the CachedKernel short.
2079    {
2080      std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
2081      if (CachedKernel)
2082        return *CachedKernel;
2083  
2084      // TODO: We should use an AA to create an (optimistic and callback
2085      //       call-aware) call graph. For now we stick to simple patterns that
2086      //       are less powerful, basically the worst fixpoint.
2087      if (isOpenMPKernel(F)) {
2088        CachedKernel = Kernel(&F);
2089        return *CachedKernel;
2090      }
2091  
2092      CachedKernel = nullptr;
2093      if (!F.hasLocalLinkage()) {
2094  
2095        // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
2096        auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2097          return ORA << "Potentially unknown OpenMP target region caller.";
2098        };
2099        emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
2100  
2101        return nullptr;
2102      }
2103    }
2104  
2105    auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
2106      if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
2107        // Allow use in equality comparisons.
2108        if (Cmp->isEquality())
2109          return getUniqueKernelFor(*Cmp);
2110        return nullptr;
2111      }
2112      if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
2113        // Allow direct calls.
2114        if (CB->isCallee(&U))
2115          return getUniqueKernelFor(*CB);
2116  
2117        OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2118            OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2119        // Allow the use in __kmpc_parallel_51 calls.
2120        if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
2121          return getUniqueKernelFor(*CB);
2122        return nullptr;
2123      }
2124      // Disallow every other use.
2125      return nullptr;
2126    };
2127  
2128    // TODO: In the future we want to track more than just a unique kernel.
2129    SmallPtrSet<Kernel, 2> PotentialKernels;
2130    OMPInformationCache::foreachUse(F, [&](const Use &U) {
2131      PotentialKernels.insert(GetUniqueKernelForUse(U));
2132    });
2133  
2134    Kernel K = nullptr;
2135    if (PotentialKernels.size() == 1)
2136      K = *PotentialKernels.begin();
2137  
2138    // Cache the result.
2139    UniqueKernelMap[&F] = K;
2140  
2141    return K;
2142  }
2143  
2144  bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2145    OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2146        OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2147  
2148    bool Changed = false;
2149    if (!KernelParallelRFI)
2150      return Changed;
2151  
2152    // If we have disabled state machine changes, exit
2153    if (DisableOpenMPOptStateMachineRewrite)
2154      return Changed;
2155  
2156    for (Function *F : SCC) {
2157  
2158      // Check if the function is used in a __kmpc_parallel_51 call at
2159      // all.
2160      bool UnknownUse = false;
2161      bool KernelParallelUse = false;
2162      unsigned NumDirectCalls = 0;
2163  
2164      SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2165      OMPInformationCache::foreachUse(*F, [&](Use &U) {
2166        if (auto *CB = dyn_cast<CallBase>(U.getUser()))
2167          if (CB->isCallee(&U)) {
2168            ++NumDirectCalls;
2169            return;
2170          }
2171  
2172        if (isa<ICmpInst>(U.getUser())) {
2173          ToBeReplacedStateMachineUses.push_back(&U);
2174          return;
2175        }
2176  
2177        // Find wrapper functions that represent parallel kernels.
2178        CallInst *CI =
2179            OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
2180        const unsigned int WrapperFunctionArgNo = 6;
2181        if (!KernelParallelUse && CI &&
2182            CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
2183          KernelParallelUse = true;
2184          ToBeReplacedStateMachineUses.push_back(&U);
2185          return;
2186        }
2187        UnknownUse = true;
2188      });
2189  
2190      // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2191      // use.
2192      if (!KernelParallelUse)
2193        continue;
2194  
2195      // If this ever hits, we should investigate.
2196      // TODO: Checking the number of uses is not a necessary restriction and
2197      // should be lifted.
2198      if (UnknownUse || NumDirectCalls != 1 ||
2199          ToBeReplacedStateMachineUses.size() > 2) {
2200        auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2201          return ORA << "Parallel region is used in "
2202                     << (UnknownUse ? "unknown" : "unexpected")
2203                     << " ways. Will not attempt to rewrite the state machine.";
2204        };
2205        emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2206        continue;
2207      }
2208  
2209      // Even if we have __kmpc_parallel_51 calls, we (for now) give
2210      // up if the function is not called from a unique kernel.
2211      Kernel K = getUniqueKernelFor(*F);
2212      if (!K) {
2213        auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2214          return ORA << "Parallel region is not called from a unique kernel. "
2215                        "Will not attempt to rewrite the state machine.";
2216        };
2217        emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2218        continue;
2219      }
2220  
2221      // We now know F is a parallel body function called only from the kernel K.
2222      // We also identified the state machine uses in which we replace the
2223      // function pointer by a new global symbol for identification purposes. This
2224      // ensures only direct calls to the function are left.
2225  
2226      Module &M = *F->getParent();
2227      Type *Int8Ty = Type::getInt8Ty(M.getContext());
2228  
2229      auto *ID = new GlobalVariable(
2230          M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2231          UndefValue::get(Int8Ty), F->getName() + ".ID");
2232  
2233      for (Use *U : ToBeReplacedStateMachineUses)
2234        U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2235            ID, U->get()->getType()));
2236  
2237      ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2238  
2239      Changed = true;
2240    }
2241  
2242    return Changed;
2243  }
2244  
2245  /// Abstract Attribute for tracking ICV values.
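      /// E.g. (sketch): after `call void @omp_set_num_threads(i32 4)`, a later
      /// `call i32 @omp_get_max_threads()` in the same function may be assumed
      /// to observe the tracked nthreads value 4 and be folded, unless an
      /// intervening call might change the ICV.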
2246  struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2247    using Base = StateWrapper<BooleanState, AbstractAttribute>;
2248    AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2249  
2250    /// Returns true if value is assumed to be tracked.
2251    bool isAssumedTracked() const { return getAssumed(); }
2252  
2253    /// Returns true if value is known to be tracked.
2254    bool isKnownTracked() const { return getAssumed(); }
2255  
2256    /// Create an abstract attribute view for the position \p IRP.
2257    static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2258  
2259    /// Return the value with which \p I can be replaced for specific \p ICV.
2260    virtual std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2261                                                       const Instruction *I,
2262                                                       Attributor &A) const {
2263      return std::nullopt;
2264    }
2265  
2266    /// Return an assumed unique ICV value if a single candidate is found. If
2267    /// there cannot be one, return nullptr. If it is not clear yet, return
2268    /// std::nullopt.
2269    virtual std::optional<Value *>
2270    getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2271  
2272    // Currently only nthreads is being tracked.
2273    // This array will only grow with time.
2274    InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2275  
2276    /// See AbstractAttribute::getName()
2277    const std::string getName() const override { return "AAICVTracker"; }
2278  
2279    /// See AbstractAttribute::getIdAddr()
2280    const char *getIdAddr() const override { return &ID; }
2281  
2282    /// This function should return true if the type of the \p AA is AAICVTracker
2283    static bool classof(const AbstractAttribute *AA) {
2284      return (AA->getIdAddr() == &ID);
2285    }
2286  
2287    static const char ID;
2288  };
2289  
2290  struct AAICVTrackerFunction : public AAICVTracker {
2291    AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2292        : AAICVTracker(IRP, A) {}
2293  
2294    // FIXME: come up with better string.
2295    const std::string getAsStr(Attributor *) const override {
2296      return "ICVTrackerFunction";
2297    }
2298  
2299    // FIXME: come up with some stats.
2300    void trackStatistics() const override {}
2301  
2302    /// We don't manifest anything for this AA.
2303    ChangeStatus manifest(Attributor &A) override {
2304      return ChangeStatus::UNCHANGED;
2305    }
2306  
2307    // Map of ICVs to their values at specific program points.
2308    EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2309                    InternalControlVar::ICV___last>
2310        ICVReplacementValuesMap;
2311  
2312    ChangeStatus updateImpl(Attributor &A) override {
2313      ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2314  
2315      Function *F = getAnchorScope();
2316  
2317      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2318  
2319      for (InternalControlVar ICV : TrackableICVs) {
2320        auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2321  
2322        auto &ValuesMap = ICVReplacementValuesMap[ICV];
2323        auto TrackValues = [&](Use &U, Function &) {
2324          CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2325          if (!CI)
2326            return false;
2327  
2328        // FIXME: handle setters with more than one argument.
2329          /// Track new value.
2330          if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2331            HasChanged = ChangeStatus::CHANGED;
2332  
2333          return false;
2334        };
2335  
2336        auto CallCheck = [&](Instruction &I) {
2337          std::optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2338          if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2339            HasChanged = ChangeStatus::CHANGED;
2340  
2341          return true;
2342        };
2343  
2344        // Track all changes of an ICV.
2345        SetterRFI.foreachUse(TrackValues, F);
2346  
2347        bool UsedAssumedInformation = false;
2348        A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2349                                  UsedAssumedInformation,
2350                                  /* CheckBBLivenessOnly */ true);
2351  
2352        /// TODO: Figure out a way to avoid adding an entry to
2353        /// ICVReplacementValuesMap.
2354        Instruction *Entry = &F->getEntryBlock().front();
2355        if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2356          ValuesMap.insert(std::make_pair(Entry, nullptr));
2357      }
2358  
2359      return HasChanged;
2360    }
2361  
2362    /// Helper to check if \p I is a call and get the value for it if it is
2363    /// unique.
2364    std::optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2365                                           InternalControlVar &ICV) const {
2366  
2367      const auto *CB = dyn_cast<CallBase>(&I);
2368      if (!CB || CB->hasFnAttr("no_openmp") ||
2369          CB->hasFnAttr("no_openmp_routines"))
2370        return std::nullopt;
2371  
2372      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2373      auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2374      auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2375      Function *CalledFunction = CB->getCalledFunction();
2376  
2377      // Indirect call, assume ICV changes.
2378      if (CalledFunction == nullptr)
2379        return nullptr;
2380      if (CalledFunction == GetterRFI.Declaration)
2381        return std::nullopt;
2382      if (CalledFunction == SetterRFI.Declaration) {
2383        if (ICVReplacementValuesMap[ICV].count(&I))
2384          return ICVReplacementValuesMap[ICV].lookup(&I);
2385  
2386        return nullptr;
2387      }
2388  
2389      // Since we don't know, assume it changes the ICV.
2390      if (CalledFunction->isDeclaration())
2391        return nullptr;
2392  
2393      const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2394          *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2395  
2396      if (ICVTrackingAA->isAssumedTracked()) {
2397        std::optional<Value *> URV =
2398            ICVTrackingAA->getUniqueReplacementValue(ICV);
2399        if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),
2400                                                   OMPInfoCache)))
2401          return URV;
2402      }
2403  
2404      // If we don't know, assume it changes.
2405      return nullptr;
2406    }
2407  
2408    // We don't check unique value for a function, so return std::nullopt.
2409    std::optional<Value *>
2410    getUniqueReplacementValue(InternalControlVar ICV) const override {
2411      return std::nullopt;
2412    }
2413  
2414    /// Return the value with which \p I can be replaced for specific \p ICV.
2415    std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2416                                               const Instruction *I,
2417                                               Attributor &A) const override {
2418      const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2419      if (ValuesMap.count(I))
2420        return ValuesMap.lookup(I);
2421  
2422      SmallVector<const Instruction *, 16> Worklist;
2423      SmallPtrSet<const Instruction *, 16> Visited;
2424      Worklist.push_back(I);
2425  
2426      std::optional<Value *> ReplVal;
2427  
2428      while (!Worklist.empty()) {
2429        const Instruction *CurrInst = Worklist.pop_back_val();
2430        if (!Visited.insert(CurrInst).second)
2431          continue;
2432  
2433        const BasicBlock *CurrBB = CurrInst->getParent();
2434  
2435        // Go up and look for all potential setters/calls that might change the
2436        // ICV.
2437        while ((CurrInst = CurrInst->getPrevNode())) {
2438          if (ValuesMap.count(CurrInst)) {
2439            std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2440            // Unknown value, track new.
2441            if (!ReplVal) {
2442              ReplVal = NewReplVal;
2443              break;
2444            }
2445  
2446            // If we found a new value, we can't know the ICV value anymore.
2447            if (NewReplVal)
2448              if (ReplVal != NewReplVal)
2449                return nullptr;
2450  
2451            break;
2452          }
2453  
2454          std::optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
2455          if (!NewReplVal)
2456            continue;
2457  
2458          // Unknown value, track new.
2459          if (!ReplVal) {
2460            ReplVal = NewReplVal;
2461            break;
2462          }
2463  
2464          // We found a new value, so we can't know the ICV value anymore.
2466          if (ReplVal != NewReplVal)
2467            return nullptr;
2468        }
2469  
2470        // If we are in the same BB and we have a value, we are done.
2471        if (CurrBB == I->getParent() && ReplVal)
2472          return ReplVal;
2473  
2474        // Go through all predecessors and add terminators for analysis.
2475        for (const BasicBlock *Pred : predecessors(CurrBB))
2476          if (const Instruction *Terminator = Pred->getTerminator())
2477            Worklist.push_back(Terminator);
2478      }
2479  
2480      return ReplVal;
2481    }
2482  };
2483  
2484  struct AAICVTrackerFunctionReturned : AAICVTracker {
2485    AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2486        : AAICVTracker(IRP, A) {}
2487  
2488    // FIXME: come up with better string.
2489    const std::string getAsStr(Attributor *) const override {
2490      return "ICVTrackerFunctionReturned";
2491    }
2492  
2493    // FIXME: come up with some stats.
2494    void trackStatistics() const override {}
2495  
2496    /// We don't manifest anything for this AA.
2497    ChangeStatus manifest(Attributor &A) override {
2498      return ChangeStatus::UNCHANGED;
2499    }
2500  
2501    // Map of ICV to their values at specific program point.
2502    // Map of ICVs to their values at specific program points.
2503                    InternalControlVar::ICV___last>
2504        ICVReplacementValuesMap;
2505  
2506    /// Return the unique value of \p ICV seen on all return paths, if any.
2507    std::optional<Value *>
2508    getUniqueReplacementValue(InternalControlVar ICV) const override {
2509      return ICVReplacementValuesMap[ICV];
2510    }
2511  
2512    ChangeStatus updateImpl(Attributor &A) override {
2513      ChangeStatus Changed = ChangeStatus::UNCHANGED;
2514      const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2515          *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2516  
2517      if (!ICVTrackingAA->isAssumedTracked())
2518        return indicatePessimisticFixpoint();
2519  
2520      for (InternalControlVar ICV : TrackableICVs) {
2521        std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2522        std::optional<Value *> UniqueICVValue;
2523  
2524        auto CheckReturnInst = [&](Instruction &I) {
2525          std::optional<Value *> NewReplVal =
2526              ICVTrackingAA->getReplacementValue(ICV, &I, A);
2527  
2528          // If we found a second ICV value there is no unique returned value.
2529          // If we found a second ICV value, there is no unique returned value.
2530            return false;
2531  
2532          UniqueICVValue = NewReplVal;
2533  
2534          return true;
2535        };
2536  
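            // Visit all live return instructions; if not all of them can be
            // analyzed, conservatively record that there is no unique value
            // (nullptr).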
2537        bool UsedAssumedInformation = false;
2538        if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2539                                       UsedAssumedInformation,
2540                                       /* CheckBBLivenessOnly */ true))
2541          UniqueICVValue = nullptr;
2542  
2543        if (UniqueICVValue == ReplVal)
2544          continue;
2545  
2546        ReplVal = UniqueICVValue;
2547        Changed = ChangeStatus::CHANGED;
2548      }
2549  
2550      return Changed;
2551    }
2552  };
2553  
2554  struct AAICVTrackerCallSite : AAICVTracker {
2555    AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2556        : AAICVTracker(IRP, A) {}
2557  
2558    void initialize(Attributor &A) override {
2559      assert(getAnchorScope() && "Expected anchor function");
2560  
2561      // We only initialize this AA for getters, so we need to know which ICV it
2562      // gets.
2563      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2564      for (InternalControlVar ICV : TrackableICVs) {
2565        auto ICVInfo = OMPInfoCache.ICVs[ICV];
2566        auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2567        if (Getter.Declaration == getAssociatedFunction()) {
2568          AssociatedICV = ICVInfo.Kind;
2569          return;
2570        }
2571      }
2572  
2573      // Unknown ICV.
2574      indicatePessimisticFixpoint();
2575    }
2576  
2577    ChangeStatus manifest(Attributor &A) override {
2578      if (!ReplVal || !*ReplVal)
2579        return ChangeStatus::UNCHANGED;
2580  
2581      A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal);
2582      A.deleteAfterManifest(*getCtxI());
2583  
2584      return ChangeStatus::CHANGED;
2585    }
2586  
2587    // FIXME: come up with better string.
2588    const std::string getAsStr(Attributor *) const override {
2589      return "ICVTrackerCallSite";
2590    }
2591  
2592    // FIXME: come up with some stats.
2593    void trackStatistics() const override {}
2594  
2595    InternalControlVar AssociatedICV;
2596    std::optional<Value *> ReplVal;
2597  
2598    ChangeStatus updateImpl(Attributor &A) override {
2599      const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2600          *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2601  
2602      // We don't have any information, so we assume it changes the ICV.
2603      if (!ICVTrackingAA->isAssumedTracked())
2604        return indicatePessimisticFixpoint();
2605  
2606      std::optional<Value *> NewReplVal =
2607          ICVTrackingAA->getReplacementValue(AssociatedICV, getCtxI(), A);
2608  
2609      if (ReplVal == NewReplVal)
2610        return ChangeStatus::UNCHANGED;
2611  
2612      ReplVal = NewReplVal;
2613      return ChangeStatus::CHANGED;
2614    }
2615  
2616    /// Return the value with which the associated value can be replaced for
2617    /// the specific \p ICV.
2618    std::optional<Value *>
2619    getUniqueReplacementValue(InternalControlVar ICV) const override {
2620      return ReplVal;
2621    }
2622  };
2623  
2624  struct AAICVTrackerCallSiteReturned : AAICVTracker {
2625    AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2626        : AAICVTracker(IRP, A) {}
2627  
2628    // FIXME: come up with better string.
2629    const std::string getAsStr(Attributor *) const override {
2630      return "ICVTrackerCallSiteReturned";
2631    }
2632  
2633    // FIXME: come up with some stats.
2634    void trackStatistics() const override {}
2635  
2636    /// We don't manifest anything for this AA.
2637    ChangeStatus manifest(Attributor &A) override {
2638      return ChangeStatus::UNCHANGED;
2639    }
2640  
2641    // Map of ICVs to their values at a specific program point.
2642    EnumeratedArray<std::optional<Value *>, InternalControlVar,
2643                    InternalControlVar::ICV___last>
2644        ICVReplacementValuesMap;
2645  
2646    /// Return the value with which the associated value can be replaced for
2647    /// the specific \p ICV.
2648    std::optional<Value *>
2649    getUniqueReplacementValue(InternalControlVar ICV) const override {
2650      return ICVReplacementValuesMap[ICV];
2651    }
2652  
2653    ChangeStatus updateImpl(Attributor &A) override {
2654      ChangeStatus Changed = ChangeStatus::UNCHANGED;
2655      const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2656          *this, IRPosition::returned(*getAssociatedFunction()),
2657          DepClassTy::REQUIRED);
2658  
2659      // We don't have any information, so we assume it changes the ICV.
2660      if (!ICVTrackingAA->isAssumedTracked())
2661        return indicatePessimisticFixpoint();
2662  
2663      for (InternalControlVar ICV : TrackableICVs) {
2664        std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2665        std::optional<Value *> NewReplVal =
2666            ICVTrackingAA->getUniqueReplacementValue(ICV);
2667  
2668        if (ReplVal == NewReplVal)
2669          continue;
2670  
2671        ReplVal = NewReplVal;
2672        Changed = ChangeStatus::CHANGED;
2673      }
2674      return Changed;
2675    }
2676  };
2677  
2678  /// Determines if \p BB itself exits the function unconditionally, or reaches
2679  /// a block that does so through a chain of unique successors.
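      /// For intuition (hypothetical CFG): with blocks A -> B -> R, where R
      /// ends in a return, this holds for A, B, and R; it does not hold for a
      /// block with two successors, as that block has no unique successor.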
2680  static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) {
2681    if (succ_empty(BB))
2682      return true;
2683    const BasicBlock *const Successor = BB->getUniqueSuccessor();
2684    if (!Successor)
2685      return false;
2686    return hasFunctionEndAsUniqueSuccessor(Successor);
2687  }
2688  
2689  struct AAExecutionDomainFunction : public AAExecutionDomain {
2690    AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2691        : AAExecutionDomain(IRP, A) {}
2692  
2693    ~AAExecutionDomainFunction() { delete RPOT; }
2694  
2695    void initialize(Attributor &A) override {
2696      Function *F = getAnchorScope();
2697      assert(F && "Expected anchor function");
2698      RPOT = new ReversePostOrderTraversal<Function *>(F);
2699    }
2700  
2701    const std::string getAsStr(Attributor *) const override {
2702      unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;
2703      for (auto &It : BEDMap) {
2704        if (!It.getFirst())
2705          continue;
2706        TotalBlocks++;
2707        InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
2708        AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&
2709                         It.getSecond().IsReachingAlignedBarrierOnly;
2710      }
2711      return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" +
2712             std::to_string(AlignedBlocks) + " of " +
2713             std::to_string(TotalBlocks) +
2714             " executed by initial thread / aligned";
2715    }
2716  
2717    /// See AbstractAttribute::trackStatistics().
2718    void trackStatistics() const override {}
2719  
2720    ChangeStatus manifest(Attributor &A) override {
2721      LLVM_DEBUG({
2722        for (const BasicBlock &BB : *getAnchorScope()) {
2723          if (!isExecutedByInitialThreadOnly(BB))
2724            continue;
2725          dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2726                 << BB.getName() << " is executed by a single thread.\n";
2727        }
2728      });
2729  
2730      ChangeStatus Changed = ChangeStatus::UNCHANGED;
2731  
2732      if (DisableOpenMPOptBarrierElimination)
2733        return Changed;
2734  
2735      SmallPtrSet<CallBase *, 16> DeletedBarriers;
2736      auto HandleAlignedBarrier = [&](CallBase *CB) {
2737        const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[nullptr];
2738        if (!ED.IsReachedFromAlignedBarrierOnly ||
2739            ED.EncounteredNonLocalSideEffect)
2740          return;
2741        if (!ED.EncounteredAssumes.empty() && !A.isModulePass())
2742          return;
2743  
2744        // We can remove this barrier, if it is one, or aligned barriers reaching
2745        // the kernel end (if CB is nullptr). Aligned barriers reaching the kernel
2746        // end should only be removed if the kernel end is their unique successor;
2747        // otherwise, they may have side-effects that aren't accounted for in the
2748        // kernel end in their other successors. If those barriers have other
2749        // barriers reaching them, those can be transitively removed as well as
2750        // long as the kernel end is also their unique successor.
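            // For intuition (hypothetical kernel, pseudo-IR; the callee name
            // is illustrative):
            //   call void @__kmpc_barrier_simple_spmd(...) ; A
            //   call void @__kmpc_barrier_simple_spmd(...) ; B
            //   ret void
            // Once B, which reaches the kernel end, has been removed, A
            // reaches the kernel end through unique successors and can be
            // removed transitively.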
2751        if (CB) {
2752          DeletedBarriers.insert(CB);
2753          A.deleteAfterManifest(*CB);
2754          ++NumBarriersEliminated;
2755          Changed = ChangeStatus::CHANGED;
2756        } else if (!ED.AlignedBarriers.empty()) {
2757          Changed = ChangeStatus::CHANGED;
2758          SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(),
2759                                           ED.AlignedBarriers.end());
2760          SmallSetVector<CallBase *, 16> Visited;
2761          while (!Worklist.empty()) {
2762            CallBase *LastCB = Worklist.pop_back_val();
2763            if (!Visited.insert(LastCB))
2764              continue;
2765            if (LastCB->getFunction() != getAnchorScope())
2766              continue;
2767            if (!hasFunctionEndAsUniqueSuccessor(LastCB->getParent()))
2768              continue;
2769            if (!DeletedBarriers.count(LastCB)) {
2770              ++NumBarriersEliminated;
2771              A.deleteAfterManifest(*LastCB);
2772              continue;
2773            }
2774            // The final aligned barrier (LastCB) reaching the kernel end was
2775          // removed already. This means we can go one step further and remove
2776          // the barriers encountered last before it (LastCB).
2777            const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];
2778            Worklist.append(LastED.AlignedBarriers.begin(),
2779                            LastED.AlignedBarriers.end());
2780          }
2781        }
2782  
2783        // If we actually eliminated a barrier we need to eliminate the associated
2784        // llvm.assumes as well to avoid creating UB.
2785        if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
2786          for (auto *AssumeCB : ED.EncounteredAssumes)
2787            A.deleteAfterManifest(*AssumeCB);
2788      };
2789  
2790      for (auto *CB : AlignedBarriers)
2791        HandleAlignedBarrier(CB);
2792  
2793      // Handle the "kernel end barrier" for kernels too.
2794      if (omp::isOpenMPKernel(*getAnchorScope()))
2795        HandleAlignedBarrier(nullptr);
2796  
2797      return Changed;
2798    }
2799  
2800    bool isNoOpFence(const FenceInst &FI) const override {
2801      return getState().isValidState() && !NonNoOpFences.count(&FI);
2802    }
2803  
2804    /// Merge barrier and assumption information from \p PredED into the successor
2805    /// \p ED.
2806    void
2807    mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED,
2808                                             const ExecutionDomainTy &PredED);
2809  
2810    /// Merge all information from \p PredED into the successor \p ED. If
2811    /// \p InitialEdgeOnly is set, only the initial edge will enter the block
2812    /// represented by \p ED from this predecessor.
2813    bool mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
2814                            const ExecutionDomainTy &PredED,
2815                            bool InitialEdgeOnly = false);
2816  
2817    /// Accumulate information for the entry block in \p EntryBBED.
2818    bool handleCallees(Attributor &A, ExecutionDomainTy &EntryBBED);
2819  
2820    /// See AbstractAttribute::updateImpl.
2821    ChangeStatus updateImpl(Attributor &A) override;
2822  
2823    /// Query interface, see AAExecutionDomain
2824    ///{
2825    bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2826      if (!isValidState())
2827        return false;
2828      assert(BB.getParent() == getAnchorScope() && "Block is out of scope!");
2829      return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;
2830    }
2831  
2832    bool isExecutedInAlignedRegion(Attributor &A,
2833                                   const Instruction &I) const override {
2834      assert(I.getFunction() == getAnchorScope() &&
2835             "Instruction is out of scope!");
2836      if (!isValidState())
2837        return false;
2838  
2839      bool ForwardIsOk = true;
2840      const Instruction *CurI;
2841  
2842      // Check forward until a call or the block end is reached.
2843      CurI = &I;
2844      do {
2845        auto *CB = dyn_cast<CallBase>(CurI);
2846        if (!CB)
2847          continue;
2848        if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2849          return true;
2850        const auto &It = CEDMap.find({CB, PRE});
2851        if (It == CEDMap.end())
2852          continue;
2853        if (!It->getSecond().IsReachingAlignedBarrierOnly)
2854          ForwardIsOk = false;
2855        break;
2856      } while ((CurI = CurI->getNextNonDebugInstruction()));
2857  
2858      if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)
2859        ForwardIsOk = false;
2860  
2861      // Check backward until a call or the block beginning is reached.
2862      CurI = &I;
2863      do {
2864        auto *CB = dyn_cast<CallBase>(CurI);
2865        if (!CB)
2866          continue;
2867        if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2868          return true;
2869        const auto &It = CEDMap.find({CB, POST});
2870        if (It == CEDMap.end())
2871          continue;
2872        if (It->getSecond().IsReachedFromAlignedBarrierOnly)
2873          break;
2874        return false;
2875      } while ((CurI = CurI->getPrevNonDebugInstruction()));
2876  
2877      // Delayed decision on the forward pass to allow aligned barrier detection
2878      // in the backwards traversal.
2879      if (!ForwardIsOk)
2880        return false;
2881  
2882      if (!CurI) {
2883        const BasicBlock *BB = I.getParent();
2884        if (BB == &BB->getParent()->getEntryBlock())
2885          return BEDMap.lookup(nullptr).IsReachedFromAlignedBarrierOnly;
2886        if (!llvm::all_of(predecessors(BB), [&](const BasicBlock *PredBB) {
2887              return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
2888            })) {
2889          return false;
2890        }
2891      }
2892  
2893      // On neither traversal did we find anything but aligned barriers.
2894      return true;
2895    }
2896  
2897    ExecutionDomainTy getExecutionDomain(const BasicBlock &BB) const override {
2898      assert(isValidState() &&
2899             "No request should be made against an invalid state!");
2900      return BEDMap.lookup(&BB);
2901    }
2902    std::pair<ExecutionDomainTy, ExecutionDomainTy>
2903    getExecutionDomain(const CallBase &CB) const override {
2904      assert(isValidState() &&
2905             "No request should be made against an invalid state!");
2906      return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})};
2907    }
2908    ExecutionDomainTy getFunctionExecutionDomain() const override {
2909      assert(isValidState() &&
2910             "No request should be made against an invalid state!");
2911      return InterProceduralED;
2912    }
2913    ///}
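        // A sketch of how another AA may use this query interface (the
        // getAAFor pattern mirrors handleCallees below; assumes an Attributor
        // A and an instruction I are in scope):
        //   const auto *EDAA = A.getAAFor<AAExecutionDomain>(
        //       *this, IRPosition::function(*I.getFunction()),
        //       DepClassTy::OPTIONAL);
        //   if (EDAA && EDAA->getState().isValidState() &&
        //       EDAA->isExecutedInAlignedRegion(A, I))
        //     ; // I is enclosed by aligned barriers on all paths.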
2914  
2915    // Check if the edge into the successor block contains a condition that only
2916    // lets the main thread execute it.
2917    static bool isInitialThreadOnlyEdge(Attributor &A, BranchInst *Edge,
2918                                        BasicBlock &SuccessorBB) {
2919      if (!Edge || !Edge->isConditional())
2920        return false;
2921      if (Edge->getSuccessor(0) != &SuccessorBB)
2922        return false;
2923  
2924      auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2925      if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2926        return false;
2927  
2928      ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2929      if (!C)
2930        return false;
2931  
2932      // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
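          // E.g., hypothetical IR (value names are illustrative):
          //   %r = call i32 @__kmpc_target_init(ptr @kernel_env, ptr %args)
          //   %c = icmp eq i32 %r, -1
          //   br i1 %c, label %initial.thread, label %workers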
2933      if (C->isAllOnesValue()) {
2934        auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2935        auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2936        auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2937        CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2938        if (!CB)
2939          return false;
2940        ConstantStruct *KernelEnvC =
2941            KernelInfo::getKernelEnvironementFromKernelInitCB(CB);
2942        ConstantInt *ExecModeC =
2943            KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
2944        return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC;
2945      }
2946  
2947      if (C->isZero()) {
2948        // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2949        if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2950          if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2951            return true;
2952  
2953        // Match: 0 == llvm.amdgcn.workitem.id.x()
2954        if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2955          if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2956            return true;
2957      }
2958  
2959      return false;
2960    };
2961  
2962    /// Mapping containing information about the function for other AAs.
2963    ExecutionDomainTy InterProceduralED;
2964  
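        // Whether a cached call-site execution domain describes the program
        // point just before (PRE) or just after (POST) the call.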
2965    enum Direction { PRE = 0, POST = 1 };
2966    /// Mapping containing information per block.
2967    DenseMap<const BasicBlock *, ExecutionDomainTy> BEDMap;
2968    DenseMap<PointerIntPair<const CallBase *, 1, Direction>, ExecutionDomainTy>
2969        CEDMap;
2970    SmallSetVector<CallBase *, 16> AlignedBarriers;
2971  
2972    ReversePostOrderTraversal<Function *> *RPOT = nullptr;
2973  
2974    /// Set \p R to \p V and report true if that changed \p R.
2975    static bool setAndRecord(bool &R, bool V) {
2976      bool Eq = (R == V);
2977      R = V;
2978      return !Eq;
2979    }
2980  
2981    /// Collection of fences known to be non-no-op. All fences not in this set
2982    /// can be assumed to be no-ops.
2983    SmallPtrSet<const FenceInst *, 8> NonNoOpFences;
2984  };
2985  
2986  void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
2987      Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) {
2988    for (auto *EA : PredED.EncounteredAssumes)
2989      ED.addAssumeInst(A, *EA);
2990  
2991    for (auto *AB : PredED.AlignedBarriers)
2992      ED.addAlignedBarrier(A, *AB);
2993  }
2994  
2995  bool AAExecutionDomainFunction::mergeInPredecessor(
2996      Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,
2997      bool InitialEdgeOnly) {
2998  
2999    bool Changed = false;
3000    Changed |=
3001        setAndRecord(ED.IsExecutedByInitialThreadOnly,
3002                     InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
3003                                         ED.IsExecutedByInitialThreadOnly));
3004  
3005    Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly,
3006                            ED.IsReachedFromAlignedBarrierOnly &&
3007                                PredED.IsReachedFromAlignedBarrierOnly);
3008    Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect,
3009                            ED.EncounteredNonLocalSideEffect |
3010                                PredED.EncounteredNonLocalSideEffect);
3011    // Do not track assumptions and barriers as part of Changed.
3012    if (ED.IsReachedFromAlignedBarrierOnly)
3013      mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);
3014    else
3015      ED.clearAssumeInstAndAlignedBarriers();
3016    return Changed;
3017  }
3018  
3019  bool AAExecutionDomainFunction::handleCallees(Attributor &A,
3020                                                ExecutionDomainTy &EntryBBED) {
3021    SmallVector<std::pair<ExecutionDomainTy, ExecutionDomainTy>, 4> CallSiteEDs;
3022    auto PredForCallSite = [&](AbstractCallSite ACS) {
3023      const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3024          *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
3025          DepClassTy::OPTIONAL);
3026      if (!EDAA || !EDAA->getState().isValidState())
3027        return false;
3028      CallSiteEDs.emplace_back(
3029          EDAA->getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
3030      return true;
3031    };
3032  
3033    ExecutionDomainTy ExitED;
3034    bool AllCallSitesKnown;
3035    if (A.checkForAllCallSites(PredForCallSite, *this,
3036                               /* RequiresAllCallSites */ true,
3037                               AllCallSitesKnown)) {
3038      for (const auto &[CSInED, CSOutED] : CallSiteEDs) {
3039        mergeInPredecessor(A, EntryBBED, CSInED);
3040        ExitED.IsReachingAlignedBarrierOnly &=
3041            CSOutED.IsReachingAlignedBarrierOnly;
3042      }
3043  
3044    } else {
3045      // We could not find all predecessors, so this is either a kernel or a
3046      // function with external linkage (or with some other weird uses).
3047      if (omp::isOpenMPKernel(*getAnchorScope())) {
3048        EntryBBED.IsExecutedByInitialThreadOnly = false;
3049        EntryBBED.IsReachedFromAlignedBarrierOnly = true;
3050        EntryBBED.EncounteredNonLocalSideEffect = false;
3051        ExitED.IsReachingAlignedBarrierOnly = false;
3052      } else {
3053        EntryBBED.IsExecutedByInitialThreadOnly = false;
3054        EntryBBED.IsReachedFromAlignedBarrierOnly = false;
3055        EntryBBED.EncounteredNonLocalSideEffect = true;
3056        ExitED.IsReachingAlignedBarrierOnly = false;
3057      }
3058    }
3059  
3060    bool Changed = false;
3061    auto &FnED = BEDMap[nullptr];
3062    Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly,
3063                            FnED.IsReachedFromAlignedBarrierOnly &
3064                                EntryBBED.IsReachedFromAlignedBarrierOnly);
3065    Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly,
3066                            FnED.IsReachingAlignedBarrierOnly &
3067                                ExitED.IsReachingAlignedBarrierOnly);
3068    Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly,
3069                            EntryBBED.IsExecutedByInitialThreadOnly);
3070    return Changed;
3071  }
3072  
3073  ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
3074  
3075    bool Changed = false;
3076  
3077    // Helper to deal with an aligned barrier encountered during the forward
3078    // traversal. \p CB is the aligned barrier, \p ED is the execution domain when
3079    // it was encountered.
3080    auto HandleAlignedBarrier = [&](CallBase &CB, ExecutionDomainTy &ED) {
3081      Changed |= AlignedBarriers.insert(&CB);
3082      // First, update the barrier ED kept in the separate CEDMap.
3083      auto &CallInED = CEDMap[{&CB, PRE}];
3084      Changed |= mergeInPredecessor(A, CallInED, ED);
3085      CallInED.IsReachingAlignedBarrierOnly = true;
3086      // Next adjust the ED we use for the traversal.
3087      ED.EncounteredNonLocalSideEffect = false;
3088      ED.IsReachedFromAlignedBarrierOnly = true;
3089      // Aligned barrier collection has to come last.
3090      ED.clearAssumeInstAndAlignedBarriers();
3091      ED.addAlignedBarrier(A, CB);
3092      auto &CallOutED = CEDMap[{&CB, POST}];
3093      Changed |= mergeInPredecessor(A, CallOutED, ED);
3094    };
3095  
3096    auto *LivenessAA =
3097        A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
3098  
3099    Function *F = getAnchorScope();
3100    BasicBlock &EntryBB = F->getEntryBlock();
3101    bool IsKernel = omp::isOpenMPKernel(*F);
3102  
3103    SmallVector<Instruction *> SyncInstWorklist;
3104    for (auto &RIt : *RPOT) {
3105      BasicBlock &BB = *RIt;
3106  
3107      bool IsEntryBB = &BB == &EntryBB;
3108      // TODO: We use local reasoning since we don't have a divergence analysis
3109      //       running as well. We could basically allow uniform branches here.
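          // The entry of a kernel is treated as an implicit aligned barrier:
          // all threads start execution together at the kernel entry.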
3110      bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
3111      bool IsExplicitlyAligned = IsEntryBB && IsKernel;
3112      ExecutionDomainTy ED;
3113      // Propagate "incoming edges" into information about this block.
3114      if (IsEntryBB) {
3115        Changed |= handleCallees(A, ED);
3116      } else {
3117        // For live non-entry blocks we only propagate
3118        // information via live edges.
3119        if (LivenessAA && LivenessAA->isAssumedDead(&BB))
3120          continue;
3121  
3122        for (auto *PredBB : predecessors(&BB)) {
3123          if (LivenessAA && LivenessAA->isEdgeDead(PredBB, &BB))
3124            continue;
3125          bool InitialEdgeOnly = isInitialThreadOnlyEdge(
3126              A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB);
3127          mergeInPredecessor(A, ED, BEDMap[PredBB], InitialEdgeOnly);
3128        }
3129      }
3130  
3131      // Now we traverse the block, accumulate effects in ED and attach
3132      // information to calls.
3133      for (Instruction &I : BB) {
3134        bool UsedAssumedInformation;
3135        if (A.isAssumedDead(I, *this, LivenessAA, UsedAssumedInformation,
3136                            /* CheckBBLivenessOnly */ false, DepClassTy::OPTIONAL,
3137                            /* CheckForDeadStore */ true))
3138          continue;
3139  
3140        // Assumes and "assume-like" intrinsics (dbg, lifetime, ...) are
3141        // handled first; the former are collected, the latter are ignored.
3142        if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
3143          if (auto *AI = dyn_cast_or_null<AssumeInst>(II)) {
3144            ED.addAssumeInst(A, *AI);
3145            continue;
3146          }
3147          // TODO: Should we also collect and delete lifetime markers?
3148          if (II->isAssumeLikeIntrinsic())
3149            continue;
3150        }
3151  
3152        if (auto *FI = dyn_cast<FenceInst>(&I)) {
3153          if (!ED.EncounteredNonLocalSideEffect) {
3154            // An aligned fence without non-local side-effects is a no-op.
3155            if (ED.IsReachedFromAlignedBarrierOnly)
3156              continue;
3157            // A non-aligned fence without non-local side-effects is a no-op
3158            // if the ordering only publishes non-local side-effects (or less).
3159            switch (FI->getOrdering()) {
3160            case AtomicOrdering::NotAtomic:
3161              continue;
3162            case AtomicOrdering::Unordered:
3163              continue;
3164            case AtomicOrdering::Monotonic:
3165              continue;
3166            case AtomicOrdering::Acquire:
3167              break;
3168            case AtomicOrdering::Release:
3169              continue;
3170            case AtomicOrdering::AcquireRelease:
3171              break;
3172            case AtomicOrdering::SequentiallyConsistent:
3173              break;
3174            };
3175          }
3176          NonNoOpFences.insert(FI);
3177        }
3178  
3179        auto *CB = dyn_cast<CallBase>(&I);
3180        bool IsNoSync = AA::isNoSyncInst(A, I, *this);
3181        bool IsAlignedBarrier =
3182            !IsNoSync && CB &&
3183            AANoSync::isAlignedBarrier(*CB, AlignedBarrierLastInBlock);
3184  
3185        AlignedBarrierLastInBlock &= IsNoSync;
3186        IsExplicitlyAligned &= IsNoSync;
3187  
3188        // Next we check for calls. Aligned barriers are handled
3189        // explicitly, everything else is kept for the backward traversal and will
3190        // also affect our state.
3191        if (CB) {
3192          if (IsAlignedBarrier) {
3193            HandleAlignedBarrier(*CB, ED);
3194            AlignedBarrierLastInBlock = true;
3195            IsExplicitlyAligned = true;
3196            continue;
3197          }
3198  
3199          // Check the pointer(s) of a memory intrinsic explicitly.
3200          if (isa<MemIntrinsic>(&I)) {
3201            if (!ED.EncounteredNonLocalSideEffect &&
3202                AA::isPotentiallyAffectedByBarrier(A, I, *this))
3203              ED.EncounteredNonLocalSideEffect = true;
3204            if (!IsNoSync) {
3205              ED.IsReachedFromAlignedBarrierOnly = false;
3206              SyncInstWorklist.push_back(&I);
3207            }
3208            continue;
3209          }
3210  
3211          // Record how we entered the call, then accumulate the effect of the
3212          // call in ED for potential use by the callee.
3213          auto &CallInED = CEDMap[{CB, PRE}];
3214          Changed |= mergeInPredecessor(A, CallInED, ED);
3215  
3216          // If we have a sync-definition we can check if it starts/ends in an
3217          // aligned barrier. If we are unsure we assume any sync breaks
3218          // alignment.
3219          Function *Callee = CB->getCalledFunction();
3220          if (!IsNoSync && Callee && !Callee->isDeclaration()) {
3221            const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3222                *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);
3223            if (EDAA && EDAA->getState().isValidState()) {
3224              const auto &CalleeED = EDAA->getFunctionExecutionDomain();
3225              ED.IsReachedFromAlignedBarrierOnly =
3226                  CalleeED.IsReachedFromAlignedBarrierOnly;
3227              AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
3228              if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
3229                ED.EncounteredNonLocalSideEffect |=
3230                    CalleeED.EncounteredNonLocalSideEffect;
3231              else
3232                ED.EncounteredNonLocalSideEffect =
3233                    CalleeED.EncounteredNonLocalSideEffect;
3234              if (!CalleeED.IsReachingAlignedBarrierOnly) {
3235                Changed |=
3236                    setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3237                SyncInstWorklist.push_back(&I);
3238              }
3239              if (CalleeED.IsReachedFromAlignedBarrierOnly)
3240                mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED);
3241              auto &CallOutED = CEDMap[{CB, POST}];
3242              Changed |= mergeInPredecessor(A, CallOutED, ED);
3243              continue;
3244            }
3245          }
3246          if (!IsNoSync) {
3247            ED.IsReachedFromAlignedBarrierOnly = false;
3248            Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3249            SyncInstWorklist.push_back(&I);
3250          }
3251          AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
3252          ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory();
3253          auto &CallOutED = CEDMap[{CB, POST}];
3254          Changed |= mergeInPredecessor(A, CallOutED, ED);
3255        }
3256  
3257        if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())
3258          continue;
3259  
3260        // If we have a callee we try to use fine-grained information to
3261        // determine local side-effects.
3262        if (CB) {
3263          const auto *MemAA = A.getAAFor<AAMemoryLocation>(
3264              *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);
3265  
3266          auto AccessPred = [&](const Instruction *I, const Value *Ptr,
3267                                AAMemoryLocation::AccessKind,
3268                                AAMemoryLocation::MemoryLocationsKind) {
3269            return !AA::isPotentiallyAffectedByBarrier(A, {Ptr}, *this, I);
3270          };
3271          if (MemAA && MemAA->getState().isValidState() &&
3272              MemAA->checkForAllAccessesToMemoryKind(
3273                  AccessPred, AAMemoryLocation::ALL_LOCATIONS))
3274            continue;
3275        }
3276  
3277        auto &InfoCache = A.getInfoCache();
3278        if (!I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(I))
3279          continue;
3280  
3281        if (auto *LI = dyn_cast<LoadInst>(&I))
3282          if (LI->hasMetadata(LLVMContext::MD_invariant_load))
3283            continue;
3284  
3285        if (!ED.EncounteredNonLocalSideEffect &&
3286            AA::isPotentiallyAffectedByBarrier(A, I, *this))
3287          ED.EncounteredNonLocalSideEffect = true;
3288      }
3289  
3290      bool IsEndAndNotReachingAlignedBarriersOnly = false;
3291      if (!isa<UnreachableInst>(BB.getTerminator()) &&
3292          !BB.getTerminator()->getNumSuccessors()) {
3293  
3294        Changed |= mergeInPredecessor(A, InterProceduralED, ED);
3295  
3296        auto &FnED = BEDMap[nullptr];
3297        if (IsKernel && !IsExplicitlyAligned)
3298          FnED.IsReachingAlignedBarrierOnly = false;
3299        Changed |= mergeInPredecessor(A, FnED, ED);
3300  
3301        if (!FnED.IsReachingAlignedBarrierOnly) {
3302          IsEndAndNotReachingAlignedBarriersOnly = true;
3303          SyncInstWorklist.push_back(BB.getTerminator());
3304          auto &BBED = BEDMap[&BB];
3305          Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly, false);
3306        }
3307      }
3308  
3309      ExecutionDomainTy &StoredED = BEDMap[&BB];
3310      ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &
3311                                        !IsEndAndNotReachingAlignedBarriersOnly;
3312  
3313      // Check if we computed anything different as part of the forward
3314      // traversal. We do not take assumptions and aligned barriers into account
3315      // as they do not influence the state we iterate. Backward traversal values
3316      // are handled later on.
3317      if (ED.IsExecutedByInitialThreadOnly !=
3318              StoredED.IsExecutedByInitialThreadOnly ||
3319          ED.IsReachedFromAlignedBarrierOnly !=
3320              StoredED.IsReachedFromAlignedBarrierOnly ||
3321          ED.EncounteredNonLocalSideEffect !=
3322              StoredED.EncounteredNonLocalSideEffect)
3323        Changed = true;
3324  
3325      // Update the state with the new value.
3326      StoredED = std::move(ED);
3327    }
3328  
3329    // Propagate (non-aligned) sync instruction effects backwards until the
3330    // entry or an aligned barrier is hit.
3331    SmallSetVector<BasicBlock *, 16> Visited;
3332    while (!SyncInstWorklist.empty()) {
3333      Instruction *SyncInst = SyncInstWorklist.pop_back_val();
3334      Instruction *CurInst = SyncInst;
3335      bool HitAlignedBarrierOrKnownEnd = false;
3336      while ((CurInst = CurInst->getPrevNode())) {
3337        auto *CB = dyn_cast<CallBase>(CurInst);
3338        if (!CB)
3339          continue;
3340        auto &CallOutED = CEDMap[{CB, POST}];
3341        Changed |= setAndRecord(CallOutED.IsReachingAlignedBarrierOnly, false);
3342        auto &CallInED = CEDMap[{CB, PRE}];
3343        HitAlignedBarrierOrKnownEnd =
3344            AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly;
3345        if (HitAlignedBarrierOrKnownEnd)
3346          break;
3347        Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3348      }
3349      if (HitAlignedBarrierOrKnownEnd)
3350        continue;
3351      BasicBlock *SyncBB = SyncInst->getParent();
3352      for (auto *PredBB : predecessors(SyncBB)) {
3353        if (LivenessAA && LivenessAA->isEdgeDead(PredBB, SyncBB))
3354          continue;
3355        if (!Visited.insert(PredBB))
3356          continue;
3357        auto &PredED = BEDMap[PredBB];
3358        if (setAndRecord(PredED.IsReachingAlignedBarrierOnly, false)) {
3359          Changed = true;
3360          SyncInstWorklist.push_back(PredBB->getTerminator());
3361        }
3362      }
3363      if (SyncBB != &EntryBB)
3364        continue;
3365      Changed |=
3366          setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly, false);
3367    }
3368  
3369    return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
3370  }
3371  
3372  /// Try to replace memory allocation calls called by a single thread with a
3373  /// static buffer of shared memory.
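      /// For illustration (hypothetical device IR; names are made up), this
      /// turns
      ///   %p = call ptr @__kmpc_alloc_shared(i64 24)
      ///   ...
      ///   call void @__kmpc_free_shared(ptr %p, i64 24)
      /// into a static buffer in the shared address space, e.g.
      ///   @x_shared = internal addrspace(3) global [24 x i8] poison
      /// replacing all uses of %p and deleting both runtime calls.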
3374  struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
3375    using Base = StateWrapper<BooleanState, AbstractAttribute>;
3376    AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3377  
3378    /// Create an abstract attribute view for the position \p IRP.
3379    static AAHeapToShared &createForPosition(const IRPosition &IRP,
3380                                             Attributor &A);
3381  
3382    /// Returns true if HeapToShared conversion is assumed to be possible.
3383    virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
3384  
3385    /// Returns true if HeapToShared conversion is assumed and the CB is a
3386    /// callsite to a free operation to be removed.
3387    virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
3388  
3389    /// See AbstractAttribute::getName().
3390    const std::string getName() const override { return "AAHeapToShared"; }
3391  
3392    /// See AbstractAttribute::getIdAddr().
3393    const char *getIdAddr() const override { return &ID; }
3394  
3395    /// This function should return true if the type of the \p AA is
3396    /// AAHeapToShared.
3397    static bool classof(const AbstractAttribute *AA) {
3398      return (AA->getIdAddr() == &ID);
3399    }
3400  
3401    /// Unique ID (due to the unique address)
3402    static const char ID;
3403  };
3404  
3405  struct AAHeapToSharedFunction : public AAHeapToShared {
3406    AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
3407        : AAHeapToShared(IRP, A) {}
3408  
3409    const std::string getAsStr(Attributor *) const override {
3410      return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
3411             " malloc calls eligible.";
3412    }
3413  
3414    /// See AbstractAttribute::trackStatistics().
3415    void trackStatistics() const override {}
3416  
3417    /// This function finds free calls that will be removed by the
3418    /// HeapToShared transformation.
3419    void findPotentialRemovedFreeCalls(Attributor &A) {
3420      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3421      auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3422  
3423      PotentialRemovedFreeCalls.clear();
3424      // Update free call users of found malloc calls.
3425      for (CallBase *CB : MallocCalls) {
3426        SmallVector<CallBase *, 4> FreeCalls;
3427        for (auto *U : CB->users()) {
3428          CallBase *C = dyn_cast<CallBase>(U);
3429          if (C && C->getCalledFunction() == FreeRFI.Declaration)
3430            FreeCalls.push_back(C);
3431        }
3432  
3433        if (FreeCalls.size() != 1)
3434          continue;
3435  
3436        PotentialRemovedFreeCalls.insert(FreeCalls.front());
3437      }
3438    }
3439  
3440    void initialize(Attributor &A) override {
3441      if (DisableOpenMPOptDeglobalization) {
3442        indicatePessimisticFixpoint();
3443        return;
3444      }
3445  
3446      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3447      auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3448      if (!RFI.Declaration)
3449        return;
3450  
3451      Attributor::SimplifictionCallbackTy SCB =
3452          [](const IRPosition &, const AbstractAttribute *,
3453             bool &) -> std::optional<Value *> { return nullptr; };
3454  
3455      Function *F = getAnchorScope();
3456      for (User *U : RFI.Declaration->users())
3457        if (CallBase *CB = dyn_cast<CallBase>(U)) {
3458          if (CB->getFunction() != F)
3459            continue;
3460          MallocCalls.insert(CB);
3461          A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),
3462                                           SCB);
3463        }
3464  
3465      findPotentialRemovedFreeCalls(A);
3466    }
3467  
3468    bool isAssumedHeapToShared(CallBase &CB) const override {
3469      return isValidState() && MallocCalls.count(&CB);
3470    }
3471  
3472    bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
3473      return isValidState() && PotentialRemovedFreeCalls.count(&CB);
3474    }
3475  
3476    ChangeStatus manifest(Attributor &A) override {
3477      if (MallocCalls.empty())
3478        return ChangeStatus::UNCHANGED;
3479  
3480      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3481      auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3482  
3483      Function *F = getAnchorScope();
3484      auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
3485                                              DepClassTy::OPTIONAL);
3486  
3487      ChangeStatus Changed = ChangeStatus::UNCHANGED;
3488      for (CallBase *CB : MallocCalls) {
3489        // Skip replacing this if HeapToStack has already claimed it.
3490        if (HS && HS->isAssumedHeapToStack(*CB))
3491          continue;
3492  
3493        // Find the unique free call to remove it.
3494        SmallVector<CallBase *, 4> FreeCalls;
3495        for (auto *U : CB->users()) {
3496          CallBase *C = dyn_cast<CallBase>(U);
3497          if (C && C->getCalledFunction() == FreeCall.Declaration)
3498            FreeCalls.push_back(C);
3499        }
3500        if (FreeCalls.size() != 1)
3501          continue;
3502  
3503        auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));
3504  
3505        if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
3506          LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
3507                            << " with shared memory."
3508                            << " Shared memory usage is limited to "
3509                            << SharedMemoryLimit << " bytes\n");
3510          continue;
3511        }
3512  
3513        LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
3514                          << " with " << AllocSize->getZExtValue()
3515                          << " bytes of shared memory\n");
3516  
3517        // Create a new shared memory buffer of the same size as the allocation
3518        // and replace all the uses of the original allocation with it.
3519        Module *M = CB->getModule();
3520        Type *Int8Ty = Type::getInt8Ty(M->getContext());
3521        Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
3522        auto *SharedMem = new GlobalVariable(
3523            *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
3524            PoisonValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
3525            GlobalValue::NotThreadLocal,
3526            static_cast<unsigned>(AddressSpace::Shared));
3527        auto *NewBuffer =
3528            ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
3529  
3530        auto Remark = [&](OptimizationRemark OR) {
3531          return OR << "Replaced globalized variable with "
3532                    << ore::NV("SharedMemory", AllocSize->getZExtValue())
3533                    << (AllocSize->isOne() ? " byte " : " bytes ")
3534                    << "of shared memory.";
3535        };
3536        A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
3537  
3538        MaybeAlign Alignment = CB->getRetAlign();
3539        assert(Alignment &&
3540               "HeapToShared on allocation without alignment attribute");
3541        SharedMem->setAlignment(*Alignment);
3542  
3543        A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);
3544        A.deleteAfterManifest(*CB);
3545        A.deleteAfterManifest(*FreeCalls.front());
3546  
3547        SharedMemoryUsed += AllocSize->getZExtValue();
3548        NumBytesMovedToSharedMemory = SharedMemoryUsed;
3549        Changed = ChangeStatus::CHANGED;
3550      }
3551  
3552      return Changed;
3553    }
3554  
3555    ChangeStatus updateImpl(Attributor &A) override {
3556      if (MallocCalls.empty())
3557        return indicatePessimisticFixpoint();
3558      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3559      auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3560      if (!RFI.Declaration)
3561        return ChangeStatus::UNCHANGED;
3562  
3563      Function *F = getAnchorScope();
3564  
3565      auto NumMallocCalls = MallocCalls.size();
3566  
3567      // Only consider constant-size malloc calls executed by a single thread.
3568      for (User *U : RFI.Declaration->users()) {
3569        if (CallBase *CB = dyn_cast<CallBase>(U)) {
3570          if (CB->getCaller() != F)
3571            continue;
3572          if (!MallocCalls.count(CB))
3573            continue;
3574          if (!isa<ConstantInt>(CB->getArgOperand(0))) {
3575            MallocCalls.remove(CB);
3576            continue;
3577          }
3578          const auto *ED = A.getAAFor<AAExecutionDomain>(
3579              *this, IRPosition::function(*F), DepClassTy::REQUIRED);
3580          if (!ED || !ED->isExecutedByInitialThreadOnly(*CB))
3581            MallocCalls.remove(CB);
3582        }
3583      }
3584  
3585      findPotentialRemovedFreeCalls(A);
3586  
3587      if (NumMallocCalls != MallocCalls.size())
3588        return ChangeStatus::CHANGED;
3589  
3590      return ChangeStatus::UNCHANGED;
3591    }
3592  
3593    /// Collection of all malloc calls in a function.
3594    SmallSetVector<CallBase *, 4> MallocCalls;
3595    /// Collection of potentially removed free calls in a function.
3596    SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
3597    /// The total amount of shared memory that has been used for HeapToShared.
3598    unsigned SharedMemoryUsed = 0;
3599  };
3600  
3601  struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
3602    using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
3603    AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3604  
3605    /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
3606    /// unknown callees.
3607    static bool requiresCalleeForCallBase() { return false; }
3608  
3609    /// Statistics are tracked as part of manifest for now.
3610    void trackStatistics() const override {}
3611  
3612    /// See AbstractAttribute::getAsStr()
3613    const std::string getAsStr(Attributor *) const override {
3614      if (!isValidState())
3615        return "<invalid>";
3616      return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
3617                                                              : "generic") +
3618             std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
3619                                                                 : "") +
3620             std::string(" #PRs: ") +
3621             (ReachedKnownParallelRegions.isValidState()
3622                  ? std::to_string(ReachedKnownParallelRegions.size())
3623                  : "<invalid>") +
3624             ", #Unknown PRs: " +
3625             (ReachedUnknownParallelRegions.isValidState()
3626                  ? std::to_string(ReachedUnknownParallelRegions.size())
3627                  : "<invalid>") +
3628             ", #Reaching Kernels: " +
3629             (ReachingKernelEntries.isValidState()
3630                  ? std::to_string(ReachingKernelEntries.size())
3631                  : "<invalid>") +
3632             ", #ParLevels: " +
3633             (ParallelLevels.isValidState()
3634                  ? std::to_string(ParallelLevels.size())
3635                  : "<invalid>") +
3636             ", NestedPar: " + (NestedParallelism ? "yes" : "no");
3637    }
3638  
3639    /// Create an abstract attribute view for the position \p IRP.
3640    static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
3641  
3642    /// See AbstractAttribute::getName()
3643    const std::string getName() const override { return "AAKernelInfo"; }
3644  
3645    /// See AbstractAttribute::getIdAddr()
3646    const char *getIdAddr() const override { return &ID; }
3647  
3648    /// This function should return true if the type of the \p AA is AAKernelInfo
3649    static bool classof(const AbstractAttribute *AA) {
3650      return (AA->getIdAddr() == &ID);
3651    }
3652  
3653    static const char ID;
3654  };
3655  
3656  /// The function kernel info abstract attribute; basically, what can we say
3657  /// about a function with regard to the KernelInfoState.
3658  struct AAKernelInfoFunction : AAKernelInfo {
3659    AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
3660        : AAKernelInfo(IRP, A) {}
3661  
3662    SmallPtrSet<Instruction *, 4> GuardedInstructions;
3663  
3664    SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
3665      return GuardedInstructions;
3666    }
3667  
3668    void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) {
3669      Constant *NewKernelEnvC = ConstantFoldInsertValueInstruction(
3670          KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx});
3671      assert(NewKernelEnvC && "Failed to create new kernel environment");
3672      KernelEnvC = cast<ConstantStruct>(NewKernelEnvC);
3673    }
3674  
3675  #define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER)                        \
3676    void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) {                 \
3677      ConstantStruct *ConfigC =                                                  \
3678          KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC);         \
3679      Constant *NewConfigC = ConstantFoldInsertValueInstruction(                 \
3680          ConfigC, NewVal, {KernelInfo::MEMBER##Idx});                           \
3681      assert(NewConfigC && "Failed to create new configuration environment");    \
3682      setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC));     \
3683    }
3684  
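        // E.g., the ExecMode instantiation below defines
        //   void setExecModeOfKernelEnvironment(ConstantInt *NewVal)
        // which folds NewVal into the ExecMode slot of the configuration
        // struct and re-installs it via setConfigurationOfKernelEnvironment.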
3685    KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(UseGenericStateMachine)
3686    KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MayUseNestedParallelism)
3687    KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode)
3688    KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinThreads)
3689    KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxThreads)
3690    KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinTeams)
3691    KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxTeams)
3692  
3693  #undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER
3694  
3695    /// See AbstractAttribute::initialize(...).
3696    void initialize(Attributor &A) override {
3697      // This is a high-level transform that might change the constant arguments
3698      // of the init and deinit calls. We need to tell the Attributor about this
3699      // to avoid other parts using the current constant value for simplification.
3700      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3701  
3702      Function *Fn = getAnchorScope();
3703  
3704      OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3705          OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3706      OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3707          OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3708  
3709      // For kernels we perform more initialization work, first we find the init
3710      // and deinit calls.
3711      auto StoreCallBase = [](Use &U,
3712                              OMPInformationCache::RuntimeFunctionInfo &RFI,
3713                              CallBase *&Storage) {
3714        CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3715        assert(CB &&
3716               "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3717        assert(!Storage &&
3718               "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3719        Storage = CB;
3720        return false;
3721      };
3722      InitRFI.foreachUse(
3723          [&](Use &U, Function &) {
3724            StoreCallBase(U, InitRFI, KernelInitCB);
3725            return false;
3726          },
3727          Fn);
3728      DeinitRFI.foreachUse(
3729          [&](Use &U, Function &) {
3730            StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3731            return false;
3732          },
3733          Fn);
3734  
3735      // Ignore kernels without initializers such as global constructors.
3736      if (!KernelInitCB || !KernelDeinitCB)
3737        return;
3738  
3739      // Add itself to the reaching kernel and set IsKernelEntry.
3740      ReachingKernelEntries.insert(Fn);
3741      IsKernelEntry = true;
3742  
3743      KernelEnvC =
3744          KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3745      GlobalVariable *KernelEnvGV =
3746          KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3747  
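          // Expose the evolving kernel environment constant to the Attributor.
          // Until this AA is at a fixpoint, queriers record a dependence so
          // they are re-run whenever the environment constant changes again.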
3748      Attributor::GlobalVariableSimplifictionCallbackTy
3749          KernelConfigurationSimplifyCB =
3750              [&](const GlobalVariable &GV, const AbstractAttribute *AA,
3751                  bool &UsedAssumedInformation) -> std::optional<Constant *> {
3752        if (!isAtFixpoint()) {
3753          if (!AA)
3754            return nullptr;
3755          UsedAssumedInformation = true;
3756          A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3757        }
3758        return KernelEnvC;
3759      };
3760  
3761      A.registerGlobalVariableSimplificationCallback(
3762          *KernelEnvGV, KernelConfigurationSimplifyCB);
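          // Effect (sketch): any device-side read of the configuration through
          // the kernel environment global is now answered with our assumed
          // KernelEnvC, with an OPTIONAL dependence recorded on the querying
          // AA, instead of the conservative initializer stored in the IR.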
3763  
3764      // Check if we know we are in SPMD-mode already.
3765      ConstantInt *ExecModeC =
3766          KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
3767      ConstantInt *AssumedExecModeC = ConstantInt::get(
3768          ExecModeC->getIntegerType(),
3769          ExecModeC->getSExtValue() | OMP_TGT_EXEC_MODE_GENERIC_SPMD);
3770      if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
3771        SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3772      else if (DisableOpenMPOptSPMDization)
3773        // This is a generic region but SPMDization is disabled so stop
3774        // tracking.
3775        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3776      else
3777        setExecModeOfKernelEnvironment(AssumedExecModeC);
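          // Illustration (flag values as defined in OMPDeviceConstants.h):
          // GENERIC = 1, SPMD = 2, GENERIC_SPMD = GENERIC | SPMD = 3. A
          // generic-mode kernel is thus optimistically assumed to run as
          // 1 | 3 == 3, i.e., generic-SPMD, until proven otherwise.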
3778  
3779      const Triple T(Fn->getParent()->getTargetTriple());
3780      auto *Int32Ty = Type::getInt32Ty(Fn->getContext());
3781      auto [MinThreads, MaxThreads] =
3782          OpenMPIRBuilder::readThreadBoundsForKernel(T, *Fn);
3783      if (MinThreads)
3784        setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads));
3785      if (MaxThreads)
3786        setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads));
3787      auto [MinTeams, MaxTeams] =
3788          OpenMPIRBuilder::readTeamBoundsForKernel(T, *Fn);
3789      if (MinTeams)
3790        setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams));
3791      if (MaxTeams)
3792        setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams));
3793  
3794      ConstantInt *MayUseNestedParallelismC =
3795          KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
3796      ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
3797          MayUseNestedParallelismC->getIntegerType(), NestedParallelism);
3798      setMayUseNestedParallelismOfKernelEnvironment(
3799          AssumedMayUseNestedParallelismC);
3800  
3801      if (!DisableOpenMPOptStateMachineRewrite) {
3802        ConstantInt *UseGenericStateMachineC =
3803            KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3804                KernelEnvC);
3805        ConstantInt *AssumedUseGenericStateMachineC =
3806            ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);
3807        setUseGenericStateMachineOfKernelEnvironment(
3808            AssumedUseGenericStateMachineC);
3809      }
3810  
3811      // Register virtual uses of functions we might need to preserve.
3812      auto RegisterVirtualUse = [&](RuntimeFunction RFKind,
3813                                    Attributor::VirtualUseCallbackTy &CB) {
3814        if (!OMPInfoCache.RFIs[RFKind].Declaration)
3815          return;
3816        A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
3817      };
3818  
3819      // Add a dependence to ensure updates if the state changes.
3820      auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
3821                              const AbstractAttribute *QueryingAA) {
3822        if (QueryingAA) {
3823          A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
3824        }
3825        return true;
3826      };
3827  
3828      Attributor::VirtualUseCallbackTy CustomStateMachineUseCB =
3829          [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3830            // Whenever we create a custom state machine we will insert calls to
3831            // __kmpc_get_hardware_num_threads_in_block,
3832            // __kmpc_get_warp_size,
3833            // __kmpc_barrier_simple_generic,
3834            // __kmpc_kernel_parallel, and
3835            // __kmpc_kernel_end_parallel.
3836            // Not needed if we are on track for SPMDzation.
3837            if (SPMDCompatibilityTracker.isValidState())
3838              return AddDependence(A, this, QueryingAA);
3839            // Not needed if we can't rewrite due to an invalid state.
3840            if (!ReachedKnownParallelRegions.isValidState())
3841              return AddDependence(A, this, QueryingAA);
3842            return false;
3843          };
3844  
3845      // Not needed before the device runtime is merged in (init is a declaration).
3846      if (!KernelInitCB->getCalledFunction()->isDeclaration()) {
3847        RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
3848                           CustomStateMachineUseCB);
3849        RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
3850        RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
3851                           CustomStateMachineUseCB);
3852        RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
3853                           CustomStateMachineUseCB);
3854        RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
3855                           CustomStateMachineUseCB);
3856      }
3857  
3858      // If we do not perform SPMDzation we do not need the virtual uses below.
3859      if (SPMDCompatibilityTracker.isAtFixpoint())
3860        return;
3861  
3862      Attributor::VirtualUseCallbackTy HWThreadIdUseCB =
3863          [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3864            // Whenever we perform SPMDzation we will insert
3865            // __kmpc_get_hardware_thread_id_in_block calls.
3866            if (!SPMDCompatibilityTracker.isValidState())
3867              return AddDependence(A, this, QueryingAA);
3868            return false;
3869          };
3870      RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
3871                         HWThreadIdUseCB);
3872  
3873      Attributor::VirtualUseCallbackTy SPMDBarrierUseCB =
3874          [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3875            // Whenever we perform SPMDzation with guarding we will insert
3876            // __kmpc_barrier_simple_spmd calls. If SPMDzation failed, there is
3877            // nothing to guard, or there are no parallel regions, we don't need
3878            // the calls.
3879            if (!SPMDCompatibilityTracker.isValidState())
3880              return AddDependence(A, this, QueryingAA);
3881            if (SPMDCompatibilityTracker.empty())
3882              return AddDependence(A, this, QueryingAA);
3883            if (!mayContainParallelRegion())
3884              return AddDependence(A, this, QueryingAA);
3885            return false;
3886          };
3887      RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
3888    }
3889  
3890    /// Sanitize the string \p S such that it is a suitable global symbol name.
3891    static std::string sanitizeForGlobalName(std::string S) {
3892      std::replace_if(
3893          S.begin(), S.end(),
3894          [](const char C) {
3895            return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3896                     (C >= '0' && C <= '9') || C == '_');
3897          },
3898          '.');
3899      return S;
3900    }
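        // Example (illustrative): sanitizeForGlobalName("foo.bar<int> tmp")
        // returns "foo.bar.int..tmp"; every character outside [A-Za-z0-9_]
        // becomes '.'.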
3901  
3902    /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3903    /// finished now.
3904    ChangeStatus manifest(Attributor &A) override {
3905      // If we are not looking at a kernel with __kmpc_target_init and
3906      // __kmpc_target_deinit call we cannot actually manifest the information.
3907      if (!KernelInitCB || !KernelDeinitCB)
3908        return ChangeStatus::UNCHANGED;
3909  
3910      ChangeStatus Changed = ChangeStatus::UNCHANGED;
3911  
3912      bool HasBuiltStateMachine = true;
3913      if (!changeToSPMDMode(A, Changed)) {
3914        if (!KernelInitCB->getCalledFunction()->isDeclaration())
3915          HasBuiltStateMachine = buildCustomStateMachine(A, Changed);
3916        else
3917          HasBuiltStateMachine = false;
3918      }
3919  
3920      // Reset the UseGenericStateMachine value if no custom state machine was built.
3921      ConstantStruct *ExistingKernelEnvC =
3922          KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3923      ConstantInt *OldUseGenericStateMachineVal =
3924          KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3925              ExistingKernelEnvC);
3926      if (!HasBuiltStateMachine)
3927        setUseGenericStateMachineOfKernelEnvironment(
3928            OldUseGenericStateMachineVal);
3929  
3930      // At last, update the KernelEnvC.
3931      GlobalVariable *KernelEnvGV =
3932          KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3933      if (KernelEnvGV->getInitializer() != KernelEnvC) {
3934        KernelEnvGV->setInitializer(KernelEnvC);
3935        Changed = ChangeStatus::CHANGED;
3936      }
3937  
3938      return Changed;
3939    }
3940  
3941    void insertInstructionGuardsHelper(Attributor &A) {
3942      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3943  
3944      auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3945                                     Instruction *RegionEndI) {
3946        LoopInfo *LI = nullptr;
3947        DominatorTree *DT = nullptr;
3948        MemorySSAUpdater *MSU = nullptr;
3949        using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3950  
3951        BasicBlock *ParentBB = RegionStartI->getParent();
3952        Function *Fn = ParentBB->getParent();
3953        Module &M = *Fn->getParent();
3954  
3955        // Create all the blocks and logic.
3956        // ParentBB:
3957        //    goto RegionCheckTidBB
3958        // RegionCheckTidBB:
3959        //    Tid = __kmpc_hardware_thread_id()
3960        //    if (Tid != 0)
3961        //        goto RegionBarrierBB
3962        // RegionStartBB:
3963        //    <execute instructions guarded>
3964        //    goto RegionEndBB
3965        // RegionEndBB:
3966        //    <store escaping values to shared mem>
3967        //    goto RegionBarrierBB
3968        // RegionBarrierBB:
3969        //    __kmpc_barrier_simple_spmd()
3970        //    // the second barrier is omitted if there are no escaping values.
3971        //    <load escaping values from shared mem>
3972        //    __kmpc_barrier_simple_spmd()
3973        //    goto RegionExitBB
3974        // RegionExitBB:
3975        //    <execute rest of instructions>
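            // In pseudo-C the guarded region behaves roughly like (a sketch):
            //
            //   if (__kmpc_get_hardware_thread_id_in_block() == 0) {
            //     /* guarded instructions */
            //     shared_tmp = escaping_value;   // publish escapees
            //   }
            //   __kmpc_barrier_simple_spmd(...); // wait for the main thread
            //   escaping_value = shared_tmp;     // every thread reloads them
            //   __kmpc_barrier_simple_spmd(...); // omitted if nothing escapes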
3976  
3977        BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3978                                             DT, LI, MSU, "region.guarded.end");
3979        BasicBlock *RegionBarrierBB =
3980            SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3981                       MSU, "region.barrier");
3982        BasicBlock *RegionExitBB =
3983            SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3984                       DT, LI, MSU, "region.exit");
3985        BasicBlock *RegionStartBB =
3986            SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3987  
3988        assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3989               "Expected a different CFG");
3990  
3991        BasicBlock *RegionCheckTidBB = SplitBlock(
3992            ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3993  
3994        // Register basic blocks with the Attributor.
3995        A.registerManifestAddedBasicBlock(*RegionEndBB);
3996        A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3997        A.registerManifestAddedBasicBlock(*RegionExitBB);
3998        A.registerManifestAddedBasicBlock(*RegionStartBB);
3999        A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
4000  
4001        bool HasBroadcastValues = false;
4002        // Find escaping outputs from the guarded region to outside users and
4003        // broadcast their values to them.
4004        for (Instruction &I : *RegionStartBB) {
4005          SmallVector<Use *, 4> OutsideUses;
4006          for (Use &U : I.uses()) {
4007            Instruction &UsrI = *cast<Instruction>(U.getUser());
4008            if (UsrI.getParent() != RegionStartBB)
4009              OutsideUses.push_back(&U);
4010          }
4011  
4012          if (OutsideUses.empty())
4013            continue;
4014  
4015          HasBroadcastValues = true;
4016  
4017          // Emit a global variable in shared memory to store the broadcasted
4018          // value.
4019          auto *SharedMem = new GlobalVariable(
4020              M, I.getType(), /* IsConstant */ false,
4021              GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
4022              sanitizeForGlobalName(
4023                  (I.getName() + ".guarded.output.alloc").str()),
4024              nullptr, GlobalValue::NotThreadLocal,
4025              static_cast<unsigned>(AddressSpace::Shared));
4026  
4027          // Emit a store instruction to update the value.
4028          new StoreInst(&I, SharedMem,
4029                        RegionEndBB->getTerminator()->getIterator());
4030  
4031          LoadInst *LoadI = new LoadInst(
4032              I.getType(), SharedMem, I.getName() + ".guarded.output.load",
4033              RegionBarrierBB->getTerminator()->getIterator());
4034  
4035          // Emit a load instruction and replace uses of the output value.
4036          for (Use *U : OutsideUses)
4037            A.changeUseAfterManifest(*U, *LoadI);
4038        }
4039  
4040        auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4041  
4042        // Go to tid check BB in ParentBB.
4043        const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
4044        ParentBB->getTerminator()->eraseFromParent();
4045        OpenMPIRBuilder::LocationDescription Loc(
4046            InsertPointTy(ParentBB, ParentBB->end()), DL);
4047        OMPInfoCache.OMPBuilder.updateToLocation(Loc);
4048        uint32_t SrcLocStrSize;
4049        auto *SrcLocStr =
4050            OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4051        Value *Ident =
4052            OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4053        BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
4054  
4055        // Add check for Tid in RegionCheckTidBB
4056        RegionCheckTidBB->getTerminator()->eraseFromParent();
4057        OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
4058            InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
4059        OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
4060        FunctionCallee HardwareTidFn =
4061            OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4062                M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4063        CallInst *Tid =
4064            OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
4065        Tid->setDebugLoc(DL);
4066        OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
4067        Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
4068        OMPInfoCache.OMPBuilder.Builder
4069            .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
4070            ->setDebugLoc(DL);
4071  
4072        // First barrier for synchronization, ensures main thread has updated
4073        // values.
4074        FunctionCallee BarrierFn =
4075            OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4076                M, OMPRTL___kmpc_barrier_simple_spmd);
4077        OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
4078            RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
4079        CallInst *Barrier =
4080            OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
4081        Barrier->setDebugLoc(DL);
4082        OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4083  
4084        // Second barrier ensures workers have read broadcast values.
4085        if (HasBroadcastValues) {
4086          CallInst *Barrier =
4087              CallInst::Create(BarrierFn, {Ident, Tid}, "",
4088                               RegionBarrierBB->getTerminator()->getIterator());
4089          Barrier->setDebugLoc(DL);
4090          OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4091        }
4092      };
4093  
4094      auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4095      SmallPtrSet<BasicBlock *, 8> Visited;
4096      for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4097        BasicBlock *BB = GuardedI->getParent();
4098        if (!Visited.insert(BB).second)
4099          continue;
4100  
4101        SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
4102        Instruction *LastEffect = nullptr;
4103        BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
4104        while (++IP != IPEnd) {
4105          if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
4106            continue;
4107          Instruction *I = &*IP;
4108          if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
4109            continue;
4110          if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
4111            LastEffect = nullptr;
4112            continue;
4113          }
4114          if (LastEffect)
4115            Reorders.push_back({I, LastEffect});
4116          LastEffect = &*IP;
4117        }
4118        for (auto &Reorder : Reorders)
4119          Reorder.first->moveBefore(Reorder.second);
4120      }
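          // Example (sketch): given the sequence
          //
          //   store i32 0, ptr @G   ; guarded, no users
          //   %a = add i32 %x, 1    ; side-effect free, skipped while scanning
          //   store i32 %a, ptr @H  ; guarded
          //
          // the first store is sunk to just before the second one so a single
          // guarded region can cover both.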
4121  
4122      SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
4123  
4124      for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4125        BasicBlock *BB = GuardedI->getParent();
4126        auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
4127            IRPosition::function(*GuardedI->getFunction()), nullptr,
4128            DepClassTy::NONE);
4129        assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
4130        auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
4131        // Continue if instruction is already guarded.
4132        if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
4133          continue;
4134  
4135        Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
4136        for (Instruction &I : *BB) {
4137          // If instruction I needs to be guarded update the guarded region
4138          // bounds.
4139          if (SPMDCompatibilityTracker.contains(&I)) {
4140            CalleeAAFunction.getGuardedInstructions().insert(&I);
4141            if (GuardedRegionStart)
4142              GuardedRegionEnd = &I;
4143            else
4144              GuardedRegionStart = GuardedRegionEnd = &I;
4145  
4146            continue;
4147          }
4148  
4149          // Instruction I does not need guarding, store
4150          // any region found and reset bounds.
4151          if (GuardedRegionStart) {
4152            GuardedRegions.push_back(
4153                std::make_pair(GuardedRegionStart, GuardedRegionEnd));
4154            GuardedRegionStart = nullptr;
4155            GuardedRegionEnd = nullptr;
4156          }
4157        }
4158      }
4159  
4160      for (auto &GR : GuardedRegions)
4161        CreateGuardedRegion(GR.first, GR.second);
4162    }
4163  
4164    void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
4165      // Only allow 1 thread per workgroup to continue executing the user code.
4166      //
4167      //     InitCB = __kmpc_target_init(...)
4168      //     ThreadIdInBlock = __kmpc_get_hardware_thread_id_in_block();
4169      //     if (ThreadIdInBlock != 0) return;
4170      // UserCode:
4171      //     // user code
4172      //
4173      auto &Ctx = getAnchorValue().getContext();
4174      Function *Kernel = getAssociatedFunction();
4175      assert(Kernel && "Expected an associated function!");
4176  
4177      // Create block for user code to branch to from initial block.
4178      BasicBlock *InitBB = KernelInitCB->getParent();
4179      BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
4180          KernelInitCB->getNextNode(), "main.thread.user_code");
4181      BasicBlock *ReturnBB =
4182          BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB);
4183  
4184      // Register blocks with attributor:
4185      A.registerManifestAddedBasicBlock(*InitBB);
4186      A.registerManifestAddedBasicBlock(*UserCodeBB);
4187      A.registerManifestAddedBasicBlock(*ReturnBB);
4188  
4189      // Debug location:
4190      const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4191      ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc);
4192      InitBB->getTerminator()->eraseFromParent();
4193  
4194      // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block.
4195      Module &M = *Kernel->getParent();
4196      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4197      FunctionCallee ThreadIdInBlockFn =
4198          OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4199              M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4200  
4201      // Get thread ID in block.
4202      CallInst *ThreadIdInBlock =
4203          CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB);
4204      OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
4205      ThreadIdInBlock->setDebugLoc(DLoc);
4206  
4207      // Eliminate all threads in the block with ID not equal to 0:
4208      Instruction *IsMainThread =
4209          ICmpInst::Create(ICmpInst::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock,
4210                           ConstantInt::get(ThreadIdInBlock->getType(), 0),
4211                           "thread.is_main", InitBB);
4212      IsMainThread->setDebugLoc(DLoc);
4213      BranchInst::Create(ReturnBB, UserCodeBB, IsMainThread, InitBB);
4214    }
4215  
4216    bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
4217      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4218  
4219      // We cannot change to SPMD mode if the runtime functions aren't available.
4220      if (!OMPInfoCache.runtimeFnsAvailable(
4221              {OMPRTL___kmpc_get_hardware_thread_id_in_block,
4222               OMPRTL___kmpc_barrier_simple_spmd}))
4223        return false;
4224  
4225      if (!SPMDCompatibilityTracker.isAssumed()) {
4226        for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
4227          if (!NonCompatibleI)
4228            continue;
4229  
4230          // Skip diagnostics on calls to known OpenMP runtime functions for now.
4231          if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
4232            if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
4233              continue;
4234  
4235          auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4236            ORA << "Value has potential side effects preventing SPMD-mode "
4237                   "execution";
4238            if (isa<CallBase>(NonCompatibleI)) {
4239              ORA << ". Add `[[omp::assume(\"ompx_spmd_amenable\")]]` to "
4240                     "the called function to override";
4241            }
4242            return ORA << ".";
4243          };
4244          A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
4245                                                   Remark);
4246  
4247          LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
4248                            << *NonCompatibleI << "\n");
4249        }
4250  
4251        return false;
4252      }
4253  
4254      // Get the actual kernel; it could be the caller of the anchor scope if
4255      // we have a debug wrapper.
4256      Function *Kernel = getAnchorScope();
4257      if (Kernel->hasLocalLinkage()) {
4258        assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
4259        auto *CB = cast<CallBase>(Kernel->user_back());
4260        Kernel = CB->getCaller();
4261      }
4262      assert(omp::isOpenMPKernel(*Kernel) && "Expected kernel function!");
4263  
4264      // Check if the kernel is already in SPMD mode, if so, return success.
4265      ConstantStruct *ExistingKernelEnvC =
4266          KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4267      auto *ExecModeC =
4268          KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4269      const int8_t ExecModeVal = ExecModeC->getSExtValue();
4270      if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
4271        return true;
4272  
4273      // We will now unconditionally modify the IR, indicate a change.
4274      Changed = ChangeStatus::CHANGED;
4275  
4276      // Do not use instruction guards when no parallel region is present
4277      // inside the target region.
4278      if (mayContainParallelRegion())
4279        insertInstructionGuardsHelper(A);
4280      else
4281        forceSingleThreadPerWorkgroupHelper(A);
4282  
4283      // Adjust the global exec mode flag that tells the runtime what mode this
4284      // kernel is executed in.
4285      assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
4286             "Initially non-SPMD kernel has SPMD exec mode!");
4287      setExecModeOfKernelEnvironment(
4288          ConstantInt::get(ExecModeC->getIntegerType(),
4289                           ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
4290  
4291      ++NumOpenMPTargetRegionKernelsSPMD;
4292  
4293      auto Remark = [&](OptimizationRemark OR) {
4294        return OR << "Transformed generic-mode kernel to SPMD-mode.";
4295      };
4296      A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
4297      return true;
4298    };
4299  
4300    bool buildCustomStateMachine(Attributor &A, ChangeStatus &Changed) {
4301      // If we have disabled state machine rewrites, don't make a custom one.
4302      if (DisableOpenMPOptStateMachineRewrite)
4303        return false;
4304  
4305      // Don't rewrite the state machine if we are not in a valid state.
4306      if (!ReachedKnownParallelRegions.isValidState())
4307        return false;
4308  
4309      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4310      if (!OMPInfoCache.runtimeFnsAvailable(
4311              {OMPRTL___kmpc_get_hardware_num_threads_in_block,
4312               OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
4313               OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
4314        return false;
4315  
4316      ConstantStruct *ExistingKernelEnvC =
4317          KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4318  
4319      // Check if the current configuration is non-SPMD and uses the generic
4320      // state machine. If we already have SPMD mode or a custom state machine
4321      // we do not need to go any further. If either value is anything but a
4322      // constant, something is weird and we give up.
4323      ConstantInt *UseStateMachineC =
4324          KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4325              ExistingKernelEnvC);
4326      ConstantInt *ModeC =
4327          KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4328  
4329      // If we are stuck with generic mode, try to create a custom device (=GPU)
4330      // state machine which is specialized for the parallel regions that are
4331      // reachable by the kernel.
4332      if (UseStateMachineC->isZero() ||
4333          (ModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
4334        return false;
4335  
4336      Changed = ChangeStatus::CHANGED;
4337  
4338      // If not SPMD mode, indicate we use a custom state machine now.
4339      setUseGenericStateMachineOfKernelEnvironment(
4340          ConstantInt::get(UseStateMachineC->getIntegerType(), false));
4341  
4342      // If we don't actually need a state machine we are done here. This can
4343      // happen if there simply are no parallel regions. In the resulting kernel
4344      // all worker threads will simply exit right away, leaving the main thread
4345      // to do the work alone.
4346      if (!mayContainParallelRegion()) {
4347        ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
4348  
4349        auto Remark = [&](OptimizationRemark OR) {
4350          return OR << "Removing unused state machine from generic-mode kernel.";
4351        };
4352        A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
4353  
4354        return true;
4355      }
4356  
4357      // Keep track in the statistics of our new shiny custom state machine.
4358      if (ReachedUnknownParallelRegions.empty()) {
4359        ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
4360  
4361        auto Remark = [&](OptimizationRemark OR) {
4362          return OR << "Rewriting generic-mode kernel with a customized state "
4363                       "machine.";
4364        };
4365        A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
4366      } else {
4367        ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
4368  
4369        auto Remark = [&](OptimizationRemarkAnalysis OR) {
4370          return OR << "Generic-mode kernel is executed with a customized state "
4371                       "machine that requires a fallback.";
4372        };
4373        A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
4374  
4375        // Tell the user why we ended up with a fallback.
4376        for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
4377          if (!UnknownParallelRegionCB)
4378            continue;
4379          auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4380            return ORA << "Call may contain unknown parallel regions. Use "
4381                       << "`[[omp::assume(\"omp_no_parallelism\")]]` to "
4382                          "override.";
4383          };
4384          A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
4385                                                   "OMP133", Remark);
4386        }
4387      }
4388  
4389      // Create all the blocks:
4390      //
4391      //                       InitCB = __kmpc_target_init(...)
4392      //                       BlockHwSize =
4393      //                         __kmpc_get_hardware_num_threads_in_block();
4394      //                       WarpSize = __kmpc_get_warp_size();
4395      //                       BlockSize = BlockHwSize - WarpSize;
4396      // IsWorkerCheckBB:      bool IsWorker = InitCB != -1;
4397      //                       if (IsWorker) {
4398      //                         if (InitCB >= BlockSize) return;
4399      // SMBeginBB:               __kmpc_barrier_simple_generic(...);
4400      //                         void *WorkFn;
4401      //                         bool Active = __kmpc_kernel_parallel(&WorkFn);
4402      //                         if (!WorkFn) return;
4403      // SMIsActiveCheckBB:       if (Active) {
4404      // SMIfCascadeCurrentBB:      if      (WorkFn == <ParFn0>)
4405      //                              ParFn0(...);
4406      // SMIfCascadeCurrentBB:      else if (WorkFn == <ParFn1>)
4407      //                              ParFn1(...);
4408      //                            ...
4409      // SMIfCascadeCurrentBB:      else
4410      //                              ((WorkFnTy*)WorkFn)(...);
4411      // SMEndParallelBB:           __kmpc_kernel_end_parallel(...);
4412      //                          }
4413      // SMDoneBB:                __kmpc_barrier_simple_generic(...);
4414      //                          goto SMBeginBB;
4415      //                       }
4416      // UserCodeEntryBB:      // user code
4417      //                       __kmpc_target_deinit(...)
4418      //
4419      auto &Ctx = getAnchorValue().getContext();
4420      Function *Kernel = getAssociatedFunction();
4421      assert(Kernel && "Expected an associated function!");
4422  
4423      BasicBlock *InitBB = KernelInitCB->getParent();
4424      BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
4425          KernelInitCB->getNextNode(), "thread.user_code.check");
4426      BasicBlock *IsWorkerCheckBB =
4427          BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
4428      BasicBlock *StateMachineBeginBB = BasicBlock::Create(
4429          Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
4430      BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
4431          Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
4432      BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
4433          Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
4434      BasicBlock *StateMachineIfCascadeCurrentBB =
4435          BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4436                             Kernel, UserCodeEntryBB);
4437      BasicBlock *StateMachineEndParallelBB =
4438          BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
4439                             Kernel, UserCodeEntryBB);
4440      BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
4441          Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
4442      A.registerManifestAddedBasicBlock(*InitBB);
4443      A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
4444      A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
4445      A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
4446      A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
4447      A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
4448      A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
4449      A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
4450      A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
4451  
4452      const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4453      ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
4454      InitBB->getTerminator()->eraseFromParent();
4455  
4456      Instruction *IsWorker =
4457          ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
4458                           ConstantInt::get(KernelInitCB->getType(), -1),
4459                           "thread.is_worker", InitBB);
4460      IsWorker->setDebugLoc(DLoc);
4461      BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
4462  
4463      Module &M = *Kernel->getParent();
4464      FunctionCallee BlockHwSizeFn =
4465          OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4466              M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
4467      FunctionCallee WarpSizeFn =
4468          OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4469              M, OMPRTL___kmpc_get_warp_size);
4470      CallInst *BlockHwSize =
4471          CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
4472      OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
4473      BlockHwSize->setDebugLoc(DLoc);
4474      CallInst *WarpSize =
4475          CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
4476      OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
4477      WarpSize->setDebugLoc(DLoc);
4478      Instruction *BlockSize = BinaryOperator::CreateSub(
4479          BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
4480      BlockSize->setDebugLoc(DLoc);
4481      Instruction *IsMainOrWorker = ICmpInst::Create(
4482          ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
4483          "thread.is_main_or_worker", IsWorkerCheckBB);
4484      IsMainOrWorker->setDebugLoc(DLoc);
4485      BranchInst::Create(StateMachineBeginBB, StateMachineFinishedBB,
4486                         IsMainOrWorker, IsWorkerCheckBB);
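          // Illustration (assumed typical sizes): with a 1024-thread block and
          // a warp size of 32, BlockSize is 992; threads with an id below 992
          // become workers while the reserved last warp (which hosts the main
          // thread) exits the state machine immediately.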
4487  
4488      // Create local storage for the work function pointer.
4489      const DataLayout &DL = M.getDataLayout();
4490      Type *VoidPtrTy = PointerType::getUnqual(Ctx);
4491      Instruction *WorkFnAI =
4492          new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
4493                         "worker.work_fn.addr", Kernel->getEntryBlock().begin());
4494      WorkFnAI->setDebugLoc(DLoc);
4495  
4496      OMPInfoCache.OMPBuilder.updateToLocation(
4497          OpenMPIRBuilder::LocationDescription(
4498              IRBuilder<>::InsertPoint(StateMachineBeginBB,
4499                                       StateMachineBeginBB->end()),
4500              DLoc));
4501  
4502      Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC);
4503      Value *GTid = KernelInitCB;
4504  
4505      FunctionCallee BarrierFn =
4506          OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4507              M, OMPRTL___kmpc_barrier_simple_generic);
4508      CallInst *Barrier =
4509          CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
4510      OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4511      Barrier->setDebugLoc(DLoc);
4512  
4513      if (WorkFnAI->getType()->getPointerAddressSpace() !=
4514          (unsigned int)AddressSpace::Generic) {
4515        WorkFnAI = new AddrSpaceCastInst(
4516            WorkFnAI, PointerType::get(Ctx, (unsigned int)AddressSpace::Generic),
4517            WorkFnAI->getName() + ".generic", StateMachineBeginBB);
4518        WorkFnAI->setDebugLoc(DLoc);
4519      }
4520  
4521      FunctionCallee KernelParallelFn =
4522          OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4523              M, OMPRTL___kmpc_kernel_parallel);
4524      CallInst *IsActiveWorker = CallInst::Create(
4525          KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
4526      OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
4527      IsActiveWorker->setDebugLoc(DLoc);
4528      Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
4529                                         StateMachineBeginBB);
4530      WorkFn->setDebugLoc(DLoc);
4531  
4532      FunctionType *ParallelRegionFnTy = FunctionType::get(
4533          Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
4534          false);
4535  
4536      Instruction *IsDone =
4537          ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
4538                           Constant::getNullValue(VoidPtrTy), "worker.is_done",
4539                           StateMachineBeginBB);
4540      IsDone->setDebugLoc(DLoc);
4541      BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
4542                         IsDone, StateMachineBeginBB)
4543          ->setDebugLoc(DLoc);
4544  
4545      BranchInst::Create(StateMachineIfCascadeCurrentBB,
4546                         StateMachineDoneBarrierBB, IsActiveWorker,
4547                         StateMachineIsActiveCheckBB)
4548          ->setDebugLoc(DLoc);
4549  
4550      Value *ZeroArg =
4551          Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
4552  
4553      const unsigned int WrapperFunctionArgNo = 6;
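          // Argument 6 is the wrapper function pointer of a __kmpc_parallel_51
          // call, whose argument order is roughly (sketch):
          //
          //   (ident, gtid, if_expr, num_threads, proc_bind, fn, wrapper_fn,
          //    args, nargs)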
4554  
4555      // Now that we have most of the CFG skeleton it is time for the if-cascade
4556      // that checks the function pointer we got from the runtime against the
4557      // parallel regions we expect, if there are any.
4558      for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
4559        auto *CB = ReachedKnownParallelRegions[I];
4560        auto *ParallelRegion = dyn_cast<Function>(
4561            CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());
4562        BasicBlock *PRExecuteBB = BasicBlock::Create(
4563            Ctx, "worker_state_machine.parallel_region.execute", Kernel,
4564            StateMachineEndParallelBB);
4565        CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
4566            ->setDebugLoc(DLoc);
4567        BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
4568            ->setDebugLoc(DLoc);
4569  
4570        BasicBlock *PRNextBB =
4571            BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4572                               Kernel, StateMachineEndParallelBB);
4573        A.registerManifestAddedBasicBlock(*PRExecuteBB);
4574        A.registerManifestAddedBasicBlock(*PRNextBB);
4575  
4576        // Check if we need to compare the pointer at all or if we can just
4577        // call the parallel region function.
4578        Value *IsPR;
4579        if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
4580          Instruction *CmpI = ICmpInst::Create(
4581              ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, ParallelRegion,
4582              "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
4583          CmpI->setDebugLoc(DLoc);
4584          IsPR = CmpI;
4585        } else {
4586          IsPR = ConstantInt::getTrue(Ctx);
4587        }
4588  
4589        BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
4590                           StateMachineIfCascadeCurrentBB)
4591            ->setDebugLoc(DLoc);
4592        StateMachineIfCascadeCurrentBB = PRNextBB;
4593      }
4594  
4595      // At the end of the if-cascade we place the indirect function pointer call
4596      // in case we might need it, that is if there can be parallel regions we
4597      // have not handled in the if-cascade above.
4598      if (!ReachedUnknownParallelRegions.empty()) {
4599        StateMachineIfCascadeCurrentBB->setName(
4600            "worker_state_machine.parallel_region.fallback.execute");
4601        CallInst::Create(ParallelRegionFnTy, WorkFn, {ZeroArg, GTid}, "",
4602                         StateMachineIfCascadeCurrentBB)
4603            ->setDebugLoc(DLoc);
4604      }
4605      BranchInst::Create(StateMachineEndParallelBB,
4606                         StateMachineIfCascadeCurrentBB)
4607          ->setDebugLoc(DLoc);
4608  
4609      FunctionCallee EndParallelFn =
4610          OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4611              M, OMPRTL___kmpc_kernel_end_parallel);
4612      CallInst *EndParallel =
4613          CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
4614      OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4615      EndParallel->setDebugLoc(DLoc);
4616      BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
4617          ->setDebugLoc(DLoc);
4618  
4619      CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
4620          ->setDebugLoc(DLoc);
4621      BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
4622          ->setDebugLoc(DLoc);
4623  
4624      return true;
4625    }
4626  
4627    /// Fixpoint iteration update function. Will be called every time a dependence
4628    /// changed its state (and in the beginning).
4629    ChangeStatus updateImpl(Attributor &A) override {
4630      KernelInfoState StateBefore = getState();
4631  
4632      // When we leave this function this RAII will make sure the member
4633      // KernelEnvC is updated properly depending on the state. That member is
4634      // used for simplification of values and needs to be up to date at all
4635      // times.
4636      struct UpdateKernelEnvCRAII {
4637        AAKernelInfoFunction &AA;
4638  
4639        UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {}
4640  
4641        ~UpdateKernelEnvCRAII() {
4642          if (!AA.KernelEnvC)
4643            return;
4644  
4645          ConstantStruct *ExistingKernelEnvC =
4646              KernelInfo::getKernelEnvironementFromKernelInitCB(AA.KernelInitCB);
4647  
4648          if (!AA.isValidState()) {
4649            AA.KernelEnvC = ExistingKernelEnvC;
4650            return;
4651          }
4652  
4653          if (!AA.ReachedKnownParallelRegions.isValidState())
4654            AA.setUseGenericStateMachineOfKernelEnvironment(
4655                KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4656                    ExistingKernelEnvC));
4657  
4658          if (!AA.SPMDCompatibilityTracker.isValidState())
4659            AA.setExecModeOfKernelEnvironment(
4660                KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC));
4661  
4662          ConstantInt *MayUseNestedParallelismC =
4663              KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
4664                  AA.KernelEnvC);
4665          ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
4666              MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);
4667          AA.setMayUseNestedParallelismOfKernelEnvironment(
4668              NewMayUseNestedParallelismC);
4669        }
4670      } RAII(*this);
4671  
4672      // Callback to check a read/write instruction.
4673      auto CheckRWInst = [&](Instruction &I) {
4674        // We handle calls later.
4675        if (isa<CallBase>(I))
4676          return true;
4677        // We only care about write effects.
4678        if (!I.mayWriteToMemory())
4679          return true;
4680        if (auto *SI = dyn_cast<StoreInst>(&I)) {
4681          const auto *UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
4682              *this, IRPosition::value(*SI->getPointerOperand()),
4683              DepClassTy::OPTIONAL);
4684          auto *HS = A.getAAFor<AAHeapToStack>(
4685              *this, IRPosition::function(*I.getFunction()),
4686              DepClassTy::OPTIONAL);
4687          if (UnderlyingObjsAA &&
4688              UnderlyingObjsAA->forallUnderlyingObjects([&](Value &Obj) {
4689                if (AA::isAssumedThreadLocalObject(A, Obj, *this))
4690                  return true;
4691                // Check for AAHeapToStack moved objects which must not be
4692                // guarded.
4693                auto *CB = dyn_cast<CallBase>(&Obj);
4694                return CB && HS && HS->isAssumedHeapToStack(*CB);
4695              }))
4696            return true;
4697        }
4698  
4699        // Insert instruction that needs guarding.
4700        SPMDCompatibilityTracker.insert(&I);
4701        return true;
4702      };
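          // Example (sketch): in
          //
          //   %p = alloca i32
          //   store i32 0, ptr %p   ; thread-local object, not guarded
          //   store i32 1, ptr @G   ; may be observed by other threads,
          //                         ; recorded in SPMDCompatibilityTracker
          //
          // only the second store needs guarding if we execute in SPMD mode.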
4703  
4704      bool UsedAssumedInformationInCheckRWInst = false;
4705      if (!SPMDCompatibilityTracker.isAtFixpoint())
4706        if (!A.checkForAllReadWriteInstructions(
4707                CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4708          SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4709  
4710      bool UsedAssumedInformationFromReachingKernels = false;
4711      if (!IsKernelEntry) {
4712        updateParallelLevels(A);
4713  
4714        bool AllReachingKernelsKnown = true;
4715        updateReachingKernelEntries(A, AllReachingKernelsKnown);
4716        UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4717  
4718        if (!SPMDCompatibilityTracker.empty()) {
4719          if (!ParallelLevels.isValidState())
4720            SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4721          else if (!ReachingKernelEntries.isValidState())
4722            SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4723          else {
4724            // Check if all reaching kernels agree on the mode as we can otherwise
4725            // not guard instructions. We might not be sure about the mode so we
4726            // cannot fix the internal SPMDzation state either.
4727            int SPMD = 0, Generic = 0;
4728            for (auto *Kernel : ReachingKernelEntries) {
4729              auto *CBAA = A.getAAFor<AAKernelInfo>(
4730                  *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
4731              if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
4732                  CBAA->SPMDCompatibilityTracker.isAssumed())
4733                ++SPMD;
4734              else
4735                ++Generic;
4736              if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint())
4737                UsedAssumedInformationFromReachingKernels = true;
4738            }
4739            if (SPMD != 0 && Generic != 0)
4740              SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4741          }
4742        }
4743      }
4744  
4745      // Callback to check a call instruction.
4746      bool AllParallelRegionStatesWereFixed = true;
4747      bool AllSPMDStatesWereFixed = true;
4748      auto CheckCallInst = [&](Instruction &I) {
4749        auto &CB = cast<CallBase>(I);
4750        auto *CBAA = A.getAAFor<AAKernelInfo>(
4751            *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4752        if (!CBAA)
4753          return false;
4754        getState() ^= CBAA->getState();
4755        AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint();
4756        AllParallelRegionStatesWereFixed &=
4757            CBAA->ReachedKnownParallelRegions.isAtFixpoint();
4758        AllParallelRegionStatesWereFixed &=
4759            CBAA->ReachedUnknownParallelRegions.isAtFixpoint();
4760        return true;
4761      };
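          // The join above (operator^=) merges the callee's knowledge into
          // ours; e.g., parallel regions reached by a callee are recorded as
          // reached by this kernel as well (a sketch of the intended
          // semantics, see KernelInfoState::operator^=).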
4762  
4763      bool UsedAssumedInformationInCheckCallInst = false;
4764      if (!A.checkForAllCallLikeInstructions(
4765              CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4766        LLVM_DEBUG(dbgs() << TAG
4767                          << "Failed to visit all call-like instructions!\n";);
4768        return indicatePessimisticFixpoint();
4769      }
4770  
4771      // If we haven't used any assumed information for the reached parallel
4772      // region states we can fix them.
4773      if (!UsedAssumedInformationInCheckCallInst &&
4774          AllParallelRegionStatesWereFixed) {
4775        ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4776        ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4777      }
4778  
4779      // If we haven't used any assumed information for the SPMD state we can fix
4780      // it.
4781      if (!UsedAssumedInformationInCheckRWInst &&
4782          !UsedAssumedInformationInCheckCallInst &&
4783          !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4784        SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4785  
4786      return StateBefore == getState() ? ChangeStatus::UNCHANGED
4787                                       : ChangeStatus::CHANGED;
4788    }
4789  
4790  private:
4791    /// Update info regarding reaching kernels.
4792    void updateReachingKernelEntries(Attributor &A,
4793                                     bool &AllReachingKernelsKnown) {
4794      auto PredCallSite = [&](AbstractCallSite ACS) {
4795        Function *Caller = ACS.getInstruction()->getFunction();
4796  
4797        assert(Caller && "Caller is nullptr");
4798  
4799        auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
4800            IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
4801        if (CAA && CAA->ReachingKernelEntries.isValidState()) {
4802          ReachingKernelEntries ^= CAA->ReachingKernelEntries;
4803          return true;
4804        }
4805  
4806        // We lost track of the caller of the associated function; any kernel
4807        // could reach it now.
4808        ReachingKernelEntries.indicatePessimisticFixpoint();
4809  
4810        return true;
4811      };
4812  
4813      if (!A.checkForAllCallSites(PredCallSite, *this,
4814                                  true /* RequireAllCallSites */,
4815                                  AllReachingKernelsKnown))
4816        ReachingKernelEntries.indicatePessimisticFixpoint();
4817    }
4818  
4819    /// Update info regarding parallel levels.
4820    void updateParallelLevels(Attributor &A) {
4821      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4822      OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
4823          OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
4824  
4825      auto PredCallSite = [&](AbstractCallSite ACS) {
4826        Function *Caller = ACS.getInstruction()->getFunction();
4827  
4828        assert(Caller && "Caller is nullptr");
4829  
4830        auto *CAA =
4831            A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
4832        if (CAA && CAA->ParallelLevels.isValidState()) {
4833          // Any function that is called by `__kmpc_parallel_51` will not be
4834          // folded, as the parallel level in the function is updated. Getting
4835          // this right would make the analysis depend on the implementation,
4836          // and any future change to it could silently invalidate the
4837          // analysis. As a consequence, we are just conservative here.
4838          if (Caller == Parallel51RFI.Declaration) {
4839            ParallelLevels.indicatePessimisticFixpoint();
4840            return true;
4841          }
4842  
4843          ParallelLevels ^= CAA->ParallelLevels;
4844  
4845          return true;
4846        }
4847  
4848        // We lost track of the caller of the associated function; any kernel
4849        // could reach it now.
4850        ParallelLevels.indicatePessimisticFixpoint();
4851  
4852        return true;
4853      };
4854  
4855      bool AllCallSitesKnown = true;
4856      if (!A.checkForAllCallSites(PredCallSite, *this,
4857                                  true /* RequireAllCallSites */,
4858                                  AllCallSitesKnown))
4859        ParallelLevels.indicatePessimisticFixpoint();
4860    }
4861  };
4862  
4863  /// The call site kernel info abstract attribute, basically, what can we say
4864  /// about a call site with regards to the KernelInfoState. For now this simply
4865  /// forwards the information from the callee.
4866  struct AAKernelInfoCallSite : AAKernelInfo {
4867    AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
4868        : AAKernelInfo(IRP, A) {}
4869  
4870    /// See AbstractAttribute::initialize(...).
4871    void initialize(Attributor &A) override {
4872      AAKernelInfo::initialize(A);
4873  
4874      CallBase &CB = cast<CallBase>(getAssociatedValue());
4875      auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
4876          *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4877  
4878      // Check for SPMD-mode assumptions.
4879      if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) {
4880        indicateOptimisticFixpoint();
4881        return;
4882      }
4883  
4884      // First weed out calls we do not care about, that is readonly/readnone
4885      // calls, intrinsics, and "no_openmp" calls. None of these can reach a
4886      // parallel region or anything else we are looking for.
4887      if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
4888        indicateOptimisticFixpoint();
4889        return;
4890      }
4891  
4892      // Next we check if we know the callee. If it is a known OpenMP function
4893      // we will handle it explicitly in the switch below. If it is not, we
4894      // will use an AAKernelInfo object on the callee to gather information and
4895      // merge that into the current state. The latter happens in the updateImpl.
4896      auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
4897        auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4898        const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4899        if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4900          // Unknown callees or declarations are not analyzable; we give up.
4901          if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
4902  
4903            // Unknown callees might contain parallel regions, except if they have
4904            // an appropriate assumption attached.
4905            if (!AssumptionAA ||
4906                !(AssumptionAA->hasAssumption("omp_no_openmp") ||
4907                  AssumptionAA->hasAssumption("omp_no_parallelism")))
4908              ReachedUnknownParallelRegions.insert(&CB);
4909  
4910            // If SPMDCompatibilityTracker is not fixed, we need to give up on the
4911            // idea we can run something unknown in SPMD-mode.
4912            if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4913              SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4914              SPMDCompatibilityTracker.insert(&CB);
4915            }
4916  
4917            // We have updated the state for this unknown call properly; there
4918            // won't be any change, so we indicate a fixpoint.
4919            indicateOptimisticFixpoint();
4920          }
4921          // If the callee is known and can be used in IPO, we will update the
4922          // state based on the callee state in updateImpl.
4923          return;
4924        }
4925        if (NumCallees > 1) {
4926          indicatePessimisticFixpoint();
4927          return;
4928        }
4929  
4930        RuntimeFunction RF = It->getSecond();
4931        switch (RF) {
4932        // All the functions we know are compatible with SPMD mode.
4933        case OMPRTL___kmpc_is_spmd_exec_mode:
4934        case OMPRTL___kmpc_distribute_static_fini:
4935        case OMPRTL___kmpc_for_static_fini:
4936        case OMPRTL___kmpc_global_thread_num:
4937        case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4938        case OMPRTL___kmpc_get_hardware_num_blocks:
4939        case OMPRTL___kmpc_single:
4940        case OMPRTL___kmpc_end_single:
4941        case OMPRTL___kmpc_master:
4942        case OMPRTL___kmpc_end_master:
4943        case OMPRTL___kmpc_barrier:
4944        case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4945        case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4946        case OMPRTL___kmpc_error:
4947        case OMPRTL___kmpc_flush:
4948        case OMPRTL___kmpc_get_hardware_thread_id_in_block:
4949        case OMPRTL___kmpc_get_warp_size:
4950        case OMPRTL_omp_get_thread_num:
4951        case OMPRTL_omp_get_num_threads:
4952        case OMPRTL_omp_get_max_threads:
4953        case OMPRTL_omp_in_parallel:
4954        case OMPRTL_omp_get_dynamic:
4955        case OMPRTL_omp_get_cancellation:
4956        case OMPRTL_omp_get_nested:
4957        case OMPRTL_omp_get_schedule:
4958        case OMPRTL_omp_get_thread_limit:
4959        case OMPRTL_omp_get_supported_active_levels:
4960        case OMPRTL_omp_get_max_active_levels:
4961        case OMPRTL_omp_get_level:
4962        case OMPRTL_omp_get_ancestor_thread_num:
4963        case OMPRTL_omp_get_team_size:
4964        case OMPRTL_omp_get_active_level:
4965        case OMPRTL_omp_in_final:
4966        case OMPRTL_omp_get_proc_bind:
4967        case OMPRTL_omp_get_num_places:
4968        case OMPRTL_omp_get_num_procs:
4969        case OMPRTL_omp_get_place_proc_ids:
4970        case OMPRTL_omp_get_place_num:
4971        case OMPRTL_omp_get_partition_num_places:
4972        case OMPRTL_omp_get_partition_place_nums:
4973        case OMPRTL_omp_get_wtime:
4974          break;
4975        case OMPRTL___kmpc_distribute_static_init_4:
4976        case OMPRTL___kmpc_distribute_static_init_4u:
4977        case OMPRTL___kmpc_distribute_static_init_8:
4978        case OMPRTL___kmpc_distribute_static_init_8u:
4979        case OMPRTL___kmpc_for_static_init_4:
4980        case OMPRTL___kmpc_for_static_init_4u:
4981        case OMPRTL___kmpc_for_static_init_8:
4982        case OMPRTL___kmpc_for_static_init_8u: {
4983          // Check the schedule and allow static schedule in SPMD mode.
4984          unsigned ScheduleArgOpNo = 2;
4985          auto *ScheduleTypeCI =
4986              dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
4987          unsigned ScheduleTypeVal =
4988              ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4989          switch (OMPScheduleType(ScheduleTypeVal)) {
4990          case OMPScheduleType::UnorderedStatic:
4991          case OMPScheduleType::UnorderedStaticChunked:
4992          case OMPScheduleType::OrderedDistribute:
4993          case OMPScheduleType::OrderedDistributeChunked:
4994            break;
4995          default:
4996            SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4997            SPMDCompatibilityTracker.insert(&CB);
4998            break;
4999          }
5000        } break;
5001        case OMPRTL___kmpc_target_init:
5002          KernelInitCB = &CB;
5003          break;
5004        case OMPRTL___kmpc_target_deinit:
5005          KernelDeinitCB = &CB;
5006          break;
5007        case OMPRTL___kmpc_parallel_51:
5008          if (!handleParallel51(A, CB))
5009            indicatePessimisticFixpoint();
5010          return;
5011        case OMPRTL___kmpc_omp_task:
5012        // We do not look into tasks right now; just give up.
5013          SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5014          SPMDCompatibilityTracker.insert(&CB);
5015          ReachedUnknownParallelRegions.insert(&CB);
5016          break;
5017        case OMPRTL___kmpc_alloc_shared:
5018        case OMPRTL___kmpc_free_shared:
5019          // Return without setting a fixpoint, to be resolved in updateImpl.
5020          return;
5021        default:
5022        // Unknown OpenMP runtime calls generally cannot be executed in
5023        // SPMD-mode. However, they do not hide parallel regions.
5024          SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5025          SPMDCompatibilityTracker.insert(&CB);
5026          break;
5027        }
5028        // All other OpenMP runtime calls will not reach parallel regions so they
5029        // can be safely ignored for now. Since it is a known OpenMP runtime call
5030        // we have now modeled all effects and there is no need for any update.
5031        indicateOptimisticFixpoint();
5032      };
5033  
5034      const auto *AACE =
5035          A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5036      if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5037        CheckCallee(getAssociatedFunction(), 1);
5038        return;
5039      }
5040      const auto &OptimisticEdges = AACE->getOptimisticEdges();
5041      for (auto *Callee : OptimisticEdges) {
5042        CheckCallee(Callee, OptimisticEdges.size());
5043        if (isAtFixpoint())
5044          break;
5045      }
5046    }
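
  // Illustrative IR for the __kmpc_{for,distribute}_static_init_* cases
  // handled above (a sketch; the value names are hypothetical). The schedule
  // kind is the third argument (ScheduleArgOpNo == 2); 34 is kmp_sch_static,
  // i.e., OMPScheduleType::UnorderedStatic, which is SPMD-compatible:
  //
  //   call void @__kmpc_for_static_init_4(ptr %loc, i32 %gtid, i32 34,
  //                                       ptr %plastiter, ptr %plower,
  //                                       ptr %pupper, ptr %pstride,
  //                                       i32 1, i32 0)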
5047  
5048    ChangeStatus updateImpl(Attributor &A) override {
5049      // TODO: Once we have call site specific value information we can provide
5050      //       call site specific liveness information and then it makes
5051      //       sense to specialize attributes for call sites arguments instead of
5052      //       redirecting requests to the callee argument.
5053      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5054      KernelInfoState StateBefore = getState();
5055  
5056      auto CheckCallee = [&](Function *F, int NumCallees) {
5057        const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
5058  
5059        // If F is not a runtime function, propagate the AAKernelInfo of the
5060        // callee.
5061        if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
5062          const IRPosition &FnPos = IRPosition::function(*F);
5063          auto *FnAA =
5064              A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
5065          if (!FnAA)
5066            return indicatePessimisticFixpoint();
5067          if (getState() == FnAA->getState())
5068            return ChangeStatus::UNCHANGED;
5069          getState() = FnAA->getState();
5070          return ChangeStatus::CHANGED;
5071        }
5072        if (NumCallees > 1)
5073          return indicatePessimisticFixpoint();
5074  
5075        CallBase &CB = cast<CallBase>(getAssociatedValue());
5076        if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
5077          if (!handleParallel51(A, CB))
5078            return indicatePessimisticFixpoint();
5079          return StateBefore == getState() ? ChangeStatus::UNCHANGED
5080                                           : ChangeStatus::CHANGED;
5081        }
5082  
5083        // F is a runtime function that allocates or frees memory; check
5084        // AAHeapToStack and AAHeapToShared.
5085        assert(
5086            (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
5087             It->getSecond() == OMPRTL___kmpc_free_shared) &&
5088            "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
5089  
5090        auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
5091            *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5092        auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
5093            *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5094  
5095        RuntimeFunction RF = It->getSecond();
5096  
5097        switch (RF) {
5098        // If neither HeapToStack nor HeapToShared assumes the call is removed,
5099        // assume SPMD incompatibility.
5100        case OMPRTL___kmpc_alloc_shared:
5101          if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
5102              (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
5103            SPMDCompatibilityTracker.insert(&CB);
5104          break;
5105        case OMPRTL___kmpc_free_shared:
5106          if ((!HeapToStackAA ||
5107               !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
5108              (!HeapToSharedAA ||
5109               !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
5110            SPMDCompatibilityTracker.insert(&CB);
5111          break;
5112        default:
5113          SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5114          SPMDCompatibilityTracker.insert(&CB);
5115        }
5116        return ChangeStatus::CHANGED;
5117      };
5118  
5119      const auto *AACE =
5120          A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5121      if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5122        if (Function *F = getAssociatedFunction())
5123          CheckCallee(F, /*NumCallees=*/1);
5124      } else {
5125        const auto &OptimisticEdges = AACE->getOptimisticEdges();
5126        for (auto *Callee : OptimisticEdges) {
5127          CheckCallee(Callee, OptimisticEdges.size());
5128          if (isAtFixpoint())
5129            break;
5130        }
5131      }
5132  
5133      return StateBefore == getState() ? ChangeStatus::UNCHANGED
5134                                       : ChangeStatus::CHANGED;
5135    }
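
  // The deglobalization candidates checked above appear in IR roughly as
  // follows (a sketch; names and the size are hypothetical):
  //
  //   %p = call ptr @__kmpc_alloc_shared(i64 8)
  //   store i32 0, ptr %p
  //   call void @__kmpc_free_shared(ptr %p, i64 8)
  //
  // Only if neither AAHeapToStack nor AAHeapToShared is assumed to remove
  // such a pair does the call site get recorded as SPMD-incompatible.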
5136  
5137    /// Deal with a __kmpc_parallel_51 call (\p CB). Returns true if the call was
5138    /// handled; if a problem occurred, false is returned.
5139    bool handleParallel51(Attributor &A, CallBase &CB) {
5140      const unsigned int NonWrapperFunctionArgNo = 5;
5141      const unsigned int WrapperFunctionArgNo = 6;
5142      auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
5143                                       ? NonWrapperFunctionArgNo
5144                                       : WrapperFunctionArgNo;
5145  
5146      auto *ParallelRegion = dyn_cast<Function>(
5147          CB.getArgOperand(ParallelRegionOpArgNo)->stripPointerCasts());
5148      if (!ParallelRegion)
5149        return false;
5150  
5151      ReachedKnownParallelRegions.insert(&CB);
5152      // Check for nested parallelism.
5153      auto *FnAA = A.getAAFor<AAKernelInfo>(
5154          *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
5155      NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
5156                           !FnAA->ReachedKnownParallelRegions.empty() ||
5157                           !FnAA->ReachedKnownParallelRegions.isValidState() ||
5158                           !FnAA->ReachedUnknownParallelRegions.isValidState() ||
5159                           !FnAA->ReachedUnknownParallelRegions.empty();
5160      return true;
5161    }
5162  };
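
// For reference, a sketch of the call handleParallel51 inspects (value names
// are hypothetical; only the argument positions matter here):
//
//   call void @__kmpc_parallel_51(ptr %loc, i32 %gtid, i32 %if_expr,
//                                 i32 %num_threads, i32 %proc_bind,
//                                 ptr @outlined,         ; arg 5, SPMD entry
//                                 ptr @outlined_wrapper, ; arg 6, generic entry
//                                 ptr %args, i64 %nargs)
//
// Depending on the assumed SPMD compatibility, either the outlined function
// (arg 5) or its wrapper (arg 6) is taken as the parallel region to analyze.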
5163  
5164  struct AAFoldRuntimeCall
5165      : public StateWrapper<BooleanState, AbstractAttribute> {
5166    using Base = StateWrapper<BooleanState, AbstractAttribute>;
5167  
5168    AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
5169  
5170    /// Statistics are tracked as part of manifest for now.
5171    void trackStatistics() const override {}
5172  
5173    /// Create an abstract attribute view for the position \p IRP.
5174    static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
5175                                                Attributor &A);
5176  
5177    /// See AbstractAttribute::getName()
5178    const std::string getName() const override { return "AAFoldRuntimeCall"; }
5179  
5180    /// See AbstractAttribute::getIdAddr()
5181    const char *getIdAddr() const override { return &ID; }
5182  
5183    /// This function should return true if the type of the \p AA is
5184    /// AAFoldRuntimeCall
5185    static bool classof(const AbstractAttribute *AA) {
5186      return (AA->getIdAddr() == &ID);
5187    }
5188  
5189    static const char ID;
5190  };
5191  
5192  struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
5193    AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
5194        : AAFoldRuntimeCall(IRP, A) {}
5195  
5196    /// See AbstractAttribute::getAsStr()
5197    const std::string getAsStr(Attributor *) const override {
5198      if (!isValidState())
5199        return "<invalid>";
5200  
5201      std::string Str("simplified value: ");
5202  
5203      if (!SimplifiedValue)
5204        return Str + std::string("none");
5205  
5206      if (!*SimplifiedValue)
5207        return Str + std::string("nullptr");
5208  
5209      if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
5210        return Str + std::to_string(CI->getSExtValue());
5211  
5212      return Str + std::string("unknown");
5213    }
5214  
5215    void initialize(Attributor &A) override {
5216      if (DisableOpenMPOptFolding)
5217        indicatePessimisticFixpoint();
5218  
5219      Function *Callee = getAssociatedFunction();
5220  
5221      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5222      const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
5223      assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
5224             "Expected a known OpenMP runtime function");
5225  
5226      RFKind = It->getSecond();
5227  
5228      CallBase &CB = cast<CallBase>(getAssociatedValue());
5229      A.registerSimplificationCallback(
5230          IRPosition::callsite_returned(CB),
5231          [&](const IRPosition &IRP, const AbstractAttribute *AA,
5232              bool &UsedAssumedInformation) -> std::optional<Value *> {
5233            assert((isValidState() ||
5234                    (SimplifiedValue && *SimplifiedValue == nullptr)) &&
5235                   "Unexpected invalid state!");
5236  
5237            if (!isAtFixpoint()) {
5238              UsedAssumedInformation = true;
5239              if (AA)
5240                A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
5241            }
5242            return SimplifiedValue;
5243          });
5244    }
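
  // Other AAs observe the folded value through the Attributor's generic
  // simplification interface; a minimal consumer sketch (CB is assumed to be
  // a call to one of the foldable runtime functions):
  //
  //   bool UsedAssumedInformation = false;
  //   std::optional<Value *> V = A.getAssumedSimplified(
  //       IRPosition::callsite_returned(CB), /* AA */ nullptr,
  //       UsedAssumedInformation, AA::Interprocedural);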
5245  
5246    ChangeStatus updateImpl(Attributor &A) override {
5247      ChangeStatus Changed = ChangeStatus::UNCHANGED;
5248      switch (RFKind) {
5249      case OMPRTL___kmpc_is_spmd_exec_mode:
5250        Changed |= foldIsSPMDExecMode(A);
5251        break;
5252      case OMPRTL___kmpc_parallel_level:
5253        Changed |= foldParallelLevel(A);
5254        break;
5255      case OMPRTL___kmpc_get_hardware_num_threads_in_block:
5256        Changed |= foldKernelFnAttribute(A, "omp_target_thread_limit");
5257        break;
5258      case OMPRTL___kmpc_get_hardware_num_blocks:
5259        Changed |= foldKernelFnAttribute(A, "omp_target_num_teams");
5260        break;
5261      default:
5262        llvm_unreachable("Unhandled OpenMP runtime function!");
5263      }
5264  
5265      return Changed;
5266    }
5267  
5268    ChangeStatus manifest(Attributor &A) override {
5269      ChangeStatus Changed = ChangeStatus::UNCHANGED;
5270  
5271      if (SimplifiedValue && *SimplifiedValue) {
5272        Instruction &I = *getCtxI();
5273        A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
5274        A.deleteAfterManifest(I);
5275  
5276        CallBase *CB = dyn_cast<CallBase>(&I);
5277        auto Remark = [&](OptimizationRemark OR) {
5278          if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
5279            return OR << "Replacing OpenMP runtime call "
5280                      << CB->getCalledFunction()->getName() << " with "
5281                      << ore::NV("FoldedValue", C->getZExtValue()) << ".";
5282          return OR << "Replacing OpenMP runtime call "
5283                    << CB->getCalledFunction()->getName() << ".";
5284        };
5285  
5286        if (CB && EnableVerboseRemarks)
5287          A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
5288  
5289        LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
5290                          << **SimplifiedValue << "\n");
5291  
5292        Changed = ChangeStatus::CHANGED;
5293      }
5294  
5295      return Changed;
5296    }
5297  
5298    ChangeStatus indicatePessimisticFixpoint() override {
5299      SimplifiedValue = nullptr;
5300      return AAFoldRuntimeCall::indicatePessimisticFixpoint();
5301    }
5302  
5303  private:
5304    /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
5305    ChangeStatus foldIsSPMDExecMode(Attributor &A) {
5306      std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5307  
5308      unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5309      unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5310      auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5311          *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5312  
5313      if (!CallerKernelInfoAA ||
5314          !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5315        return indicatePessimisticFixpoint();
5316  
5317      for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5318        auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5319                                            DepClassTy::REQUIRED);
5320  
5321        if (!AA || !AA->isValidState()) {
5322          SimplifiedValue = nullptr;
5323          return indicatePessimisticFixpoint();
5324        }
5325  
5326        if (AA->SPMDCompatibilityTracker.isAssumed()) {
5327          if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5328            ++KnownSPMDCount;
5329          else
5330            ++AssumedSPMDCount;
5331        } else {
5332          if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5333            ++KnownNonSPMDCount;
5334          else
5335            ++AssumedNonSPMDCount;
5336        }
5337      }
5338  
5339      if ((AssumedSPMDCount + KnownSPMDCount) &&
5340          (AssumedNonSPMDCount + KnownNonSPMDCount))
5341        return indicatePessimisticFixpoint();
5342  
5343      auto &Ctx = getAnchorValue().getContext();
5344      if (KnownSPMDCount || AssumedSPMDCount) {
5345        assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5346               "Expected only SPMD kernels!");
5347        // All reaching kernels are in SPMD mode. Update all function calls to
5348        // __kmpc_is_spmd_exec_mode to 1.
5349        SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
5350      } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
5351        assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5352               "Expected only non-SPMD kernels!");
5353        // All reaching kernels are in non-SPMD mode. Update all function
5354        // calls to __kmpc_is_spmd_exec_mode to 0.
5355        SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
5356      } else {
5357        // We have empty reaching kernels; therefore, we cannot tell if the
5358        // associated call site can be folded. At this moment, SimplifiedValue
5359        // must be none.
5360        assert(!SimplifiedValue && "SimplifiedValue should be none");
5361      }
5362  
5363      return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5364                                                      : ChangeStatus::CHANGED;
5365    }
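
  // Illustrative effect of a successful fold (a sketch, assuming every
  // reaching kernel is in SPMD mode):
  //
  //   %mode = call i8 @__kmpc_is_spmd_exec_mode()   ; before manifest
  //
  // After manifest, all uses of %mode are rewritten to 'i8 1' and the call
  // is deleted.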
5366  
5367    /// Fold __kmpc_parallel_level into a constant if possible.
5368    ChangeStatus foldParallelLevel(Attributor &A) {
5369      std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5370  
5371      auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5372          *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5373  
5374      if (!CallerKernelInfoAA ||
5375          !CallerKernelInfoAA->ParallelLevels.isValidState())
5376        return indicatePessimisticFixpoint();
5377  
5378      if (!CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5379        return indicatePessimisticFixpoint();
5380  
5381      if (CallerKernelInfoAA->ReachingKernelEntries.empty()) {
5382        assert(!SimplifiedValue &&
5383               "SimplifiedValue should keep none at this point");
5384        return ChangeStatus::UNCHANGED;
5385      }
5386  
5387      unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5388      unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5389      for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5390        auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5391                                            DepClassTy::REQUIRED);
5392        if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
5393          return indicatePessimisticFixpoint();
5394  
5395        if (AA->SPMDCompatibilityTracker.isAssumed()) {
5396          if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5397            ++KnownSPMDCount;
5398          else
5399            ++AssumedSPMDCount;
5400        } else {
5401          if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5402            ++KnownNonSPMDCount;
5403          else
5404            ++AssumedNonSPMDCount;
5405        }
5406      }
5407  
5408      if ((AssumedSPMDCount + KnownSPMDCount) &&
5409          (AssumedNonSPMDCount + KnownNonSPMDCount))
5410        return indicatePessimisticFixpoint();
5411  
5412      auto &Ctx = getAnchorValue().getContext();
5413      // If the caller can only be reached by SPMD kernel entries, the parallel
5414      // level is 1. Similarly, if the caller can only be reached by non-SPMD
5415      // kernel entries, it is 0.
5416      if (AssumedSPMDCount || KnownSPMDCount) {
5417        assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5418               "Expected only SPMD kernels!");
5419        SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
5420      } else {
5421        assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5422               "Expected only non-SPMD kernels!");
5423        SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
5424      }
5425      return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5426                                                      : ChangeStatus::CHANGED;
5427    }
5428  
5429    ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
5430      // Specialize only if all reaching kernels agree on the attribute's value.
5431      int32_t CurrentAttrValue = -1;
5432      std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5433  
5434      auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5435          *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5436  
5437      if (!CallerKernelInfoAA ||
5438          !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5439        return indicatePessimisticFixpoint();
5440  
5441      // Iterate over the kernels that reach this function
5442      for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5443        int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Attr, -1);
5444  
5445        if (NextAttrVal == -1 ||
5446            (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
5447          return indicatePessimisticFixpoint();
5448        CurrentAttrValue = NextAttrVal;
5449      }
5450  
5451      if (CurrentAttrValue != -1) {
5452        auto &Ctx = getAnchorValue().getContext();
5453        SimplifiedValue =
5454            ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
5455      }
5456      return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5457                                                      : ChangeStatus::CHANGED;
5458    }
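
  // Illustrative fold (a sketch; @example_kernel and the value 128 are
  // hypothetical, the attribute is emitted by the frontend for the
  // 'thread_limit' clause):
  //
  //   define void @example_kernel() "kernel" "omp_target_thread_limit"="128"
  //
  // If every kernel reaching the call site agrees on the value, a call to
  // __kmpc_get_hardware_num_threads_in_block() folds to 'i32 128'.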
5459  
5460    /// An optional value the associated value is assumed to fold to. That is, we
5461    /// assume the associated value (which is a call) can be replaced by this
5462    /// simplified value.
5463    std::optional<Value *> SimplifiedValue;
5464  
5465    /// The runtime function kind of the callee of the associated call site.
5466    RuntimeFunction RFKind;
5467  };
5468  
5469  } // namespace
5470  
5471  /// Register an AAFoldRuntimeCall attribute for each call site of \p RF.
5472  void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
5473    auto &RFI = OMPInfoCache.RFIs[RF];
5474    RFI.foreachUse(SCC, [&](Use &U, Function &F) {
5475      CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
5476      if (!CI)
5477        return false;
5478      A.getOrCreateAAFor<AAFoldRuntimeCall>(
5479          IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
5480          DepClassTy::NONE, /* ForceUpdate */ false,
5481          /* UpdateAfterInit */ false);
5482      return false;
5483    });
5484  }
5485  
5486  void OpenMPOpt::registerAAs(bool IsModulePass) {
5487    if (SCC.empty())
5488      return;
5489  
5490    if (IsModulePass) {
5491      // Ensure we create the AAKernelInfo AAs first and without triggering an
5492      // update. This will make sure we register all value simplification
5493      // callbacks before any other AA has the chance to create an AAValueSimplify
5494      // or similar.
5495      auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
5496        A.getOrCreateAAFor<AAKernelInfo>(
5497            IRPosition::function(Kernel), /* QueryingAA */ nullptr,
5498            DepClassTy::NONE, /* ForceUpdate */ false,
5499            /* UpdateAfterInit */ false);
5500        return false;
5501      };
5502      OMPInformationCache::RuntimeFunctionInfo &InitRFI =
5503          OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
5504      InitRFI.foreachUse(SCC, CreateKernelInfoCB);
5505  
5506      registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
5507      registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
5508      registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
5509      registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
5510    }
5511  
5512    // Create CallSite AA for all Getters.
5513    if (DeduceICVValues) {
5514      for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
5515        auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
5516  
5517        auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
5518  
5519        auto CreateAA = [&](Use &U, Function &Caller) {
5520          CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
5521          if (!CI)
5522            return false;
5523  
5524          auto &CB = cast<CallBase>(*CI);
5525  
5526          IRPosition CBPos = IRPosition::callsite_function(CB);
5527          A.getOrCreateAAFor<AAICVTracker>(CBPos);
5528          return false;
5529        };
5530  
5531        GetterRFI.foreachUse(SCC, CreateAA);
5532      }
5533    }
5534  
5535    // Create an ExecutionDomain AA for every function and a HeapToStack AA for
5536    // every function if there is a device kernel.
5537    if (!isOpenMPDevice(M))
5538      return;
5539  
5540    for (auto *F : SCC) {
5541      if (F->isDeclaration())
5542        continue;
5543  
5544      // We look at internal functions only on-demand but if any use is not a
5545      // direct call or is outside the current set of analyzed functions, we have
5546      // to do it eagerly.
5547      if (F->hasLocalLinkage()) {
5548        if (llvm::all_of(F->uses(), [this](const Use &U) {
5549              const auto *CB = dyn_cast<CallBase>(U.getUser());
5550              return CB && CB->isCallee(&U) &&
5551                     A.isRunOn(const_cast<Function *>(CB->getCaller()));
5552            }))
5553          continue;
5554      }
5555      registerAAsForFunction(A, *F);
5556    }
5557  }
5558  
5559  void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
5560    if (!DisableOpenMPOptDeglobalization)
5561      A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
5562    A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
5563    if (!DisableOpenMPOptDeglobalization)
5564      A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(F));
5565    if (F.hasFnAttribute(Attribute::Convergent))
5566      A.getOrCreateAAFor<AANonConvergent>(IRPosition::function(F));
5567  
5568    for (auto &I : instructions(F)) {
5569      if (auto *LI = dyn_cast<LoadInst>(&I)) {
5570        bool UsedAssumedInformation = false;
5571        A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
5572                               UsedAssumedInformation, AA::Interprocedural);
5573        continue;
5574      }
5575      if (auto *CI = dyn_cast<CallBase>(&I)) {
5576        if (CI->isIndirectCall())
5577          A.getOrCreateAAFor<AAIndirectCallInfo>(
5578              IRPosition::callsite_function(*CI));
5579      }
5580      if (auto *SI = dyn_cast<StoreInst>(&I)) {
5581        A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
5582        continue;
5583      }
5584      if (auto *FI = dyn_cast<FenceInst>(&I)) {
5585        A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*FI));
5586        continue;
5587      }
5588      if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
5589        if (II->getIntrinsicID() == Intrinsic::assume) {
5590          A.getOrCreateAAFor<AAPotentialValues>(
5591              IRPosition::value(*II->getArgOperand(0)));
5592          continue;
5593        }
5594      }
5595    }
5596  }
5597  
5598  const char AAICVTracker::ID = 0;
5599  const char AAKernelInfo::ID = 0;
5600  const char AAExecutionDomain::ID = 0;
5601  const char AAHeapToShared::ID = 0;
5602  const char AAFoldRuntimeCall::ID = 0;
5603  
5604  AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
5605                                                Attributor &A) {
5606    AAICVTracker *AA = nullptr;
5607    switch (IRP.getPositionKind()) {
5608    case IRPosition::IRP_INVALID:
5609    case IRPosition::IRP_FLOAT:
5610    case IRPosition::IRP_ARGUMENT:
5611    case IRPosition::IRP_CALL_SITE_ARGUMENT:
5612      llvm_unreachable("ICVTracker can only be created for function position!");
5613    case IRPosition::IRP_RETURNED:
5614      AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
5615      break;
5616    case IRPosition::IRP_CALL_SITE_RETURNED:
5617      AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
5618      break;
5619    case IRPosition::IRP_CALL_SITE:
5620      AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
5621      break;
5622    case IRPosition::IRP_FUNCTION:
5623      AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
5624      break;
5625    }
5626  
5627    return *AA;
5628  }
5629  
5630  AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
5631                                                          Attributor &A) {
5632    AAExecutionDomainFunction *AA = nullptr;
5633    switch (IRP.getPositionKind()) {
5634    case IRPosition::IRP_INVALID:
5635    case IRPosition::IRP_FLOAT:
5636    case IRPosition::IRP_ARGUMENT:
5637    case IRPosition::IRP_CALL_SITE_ARGUMENT:
5638    case IRPosition::IRP_RETURNED:
5639    case IRPosition::IRP_CALL_SITE_RETURNED:
5640    case IRPosition::IRP_CALL_SITE:
5641      llvm_unreachable(
5642          "AAExecutionDomain can only be created for function position!");
5643    case IRPosition::IRP_FUNCTION:
5644      AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
5645      break;
5646    }
5647  
5648    return *AA;
5649  }
5650  
5651  AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
5652                                                    Attributor &A) {
5653    AAHeapToSharedFunction *AA = nullptr;
5654    switch (IRP.getPositionKind()) {
5655    case IRPosition::IRP_INVALID:
5656    case IRPosition::IRP_FLOAT:
5657    case IRPosition::IRP_ARGUMENT:
5658    case IRPosition::IRP_CALL_SITE_ARGUMENT:
5659    case IRPosition::IRP_RETURNED:
5660    case IRPosition::IRP_CALL_SITE_RETURNED:
5661    case IRPosition::IRP_CALL_SITE:
5662      llvm_unreachable(
5663          "AAHeapToShared can only be created for function position!");
5664    case IRPosition::IRP_FUNCTION:
5665      AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
5666      break;
5667    }
5668  
5669    return *AA;
5670  }
5671  
5672  AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
5673                                                Attributor &A) {
5674    AAKernelInfo *AA = nullptr;
5675    switch (IRP.getPositionKind()) {
5676    case IRPosition::IRP_INVALID:
5677    case IRPosition::IRP_FLOAT:
5678    case IRPosition::IRP_ARGUMENT:
5679    case IRPosition::IRP_RETURNED:
5680    case IRPosition::IRP_CALL_SITE_RETURNED:
5681    case IRPosition::IRP_CALL_SITE_ARGUMENT:
5682      llvm_unreachable("KernelInfo can only be created for function position!");
5683    case IRPosition::IRP_CALL_SITE:
5684      AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
5685      break;
5686    case IRPosition::IRP_FUNCTION:
5687      AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
5688      break;
5689    }
5690  
5691    return *AA;
5692  }
5693  
5694  AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
5695                                                          Attributor &A) {
5696    AAFoldRuntimeCall *AA = nullptr;
5697    switch (IRP.getPositionKind()) {
5698    case IRPosition::IRP_INVALID:
5699    case IRPosition::IRP_FLOAT:
5700    case IRPosition::IRP_ARGUMENT:
5701    case IRPosition::IRP_RETURNED:
5702    case IRPosition::IRP_FUNCTION:
5703    case IRPosition::IRP_CALL_SITE:
5704    case IRPosition::IRP_CALL_SITE_ARGUMENT:
5705      llvm_unreachable("AAFoldRuntimeCall can only be created for call site returned position!");
5706    case IRPosition::IRP_CALL_SITE_RETURNED:
5707      AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
5708      break;
5709    }
5710  
5711    return *AA;
5712  }
5713  
5714  PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
5715    if (!containsOpenMP(M))
5716      return PreservedAnalyses::all();
5717    if (DisableOpenMPOptimizations)
5718      return PreservedAnalyses::all();
5719  
5720    FunctionAnalysisManager &FAM =
5721        AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
5722    KernelSet Kernels = getDeviceKernels(M);
5723  
5724    if (PrintModuleBeforeOptimizations)
5725      LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);
5726  
5727    auto IsCalled = [&](Function &F) {
5728      if (Kernels.contains(&F))
5729        return true;
5730      for (const User *U : F.users())
5731        if (!isa<BlockAddress>(U))
5732          return true;
5733      return false;
5734    };
5735  
5736    auto EmitRemark = [&](Function &F) {
5737      auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
5738      ORE.emit([&]() {
5739        OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
5740        return ORA << "Could not internalize function. "
5741                   << "Some optimizations may not be possible. [OMP140]";
5742      });
5743    };
5744  
5745    bool Changed = false;
5746  
5747    // Create internal copies of each function if this is a device module. This
5748    // allows interprocedural passes to see every call edge.
5749    DenseMap<Function *, Function *> InternalizedMap;
5750    if (isOpenMPDevice(M)) {
5751      SmallPtrSet<Function *, 16> InternalizeFns;
5752      for (Function &F : M)
5753        if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
5754            !DisableInternalization) {
5755          if (Attributor::isInternalizable(F)) {
5756            InternalizeFns.insert(&F);
5757          } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
5758            EmitRemark(F);
5759          }
5760        }
5761  
5762      Changed |=
5763          Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
5764    }
5765  
5766    // Look at every function in the Module unless it was internalized.
5767    SetVector<Function *> Functions;
5768    SmallVector<Function *, 16> SCC;
5769    for (Function &F : M)
5770      if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) {
5771        SCC.push_back(&F);
5772        Functions.insert(&F);
5773      }
5774  
5775    if (SCC.empty())
5776      return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
5777  
5778    AnalysisGetter AG(FAM);
5779  
5780    auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5781      return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5782    };
5783  
5784    BumpPtrAllocator Allocator;
5785    CallGraphUpdater CGUpdater;
5786  
5787    bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5788                    LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;
5789    OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, PostLink);
5790  
5791    unsigned MaxFixpointIterations =
5792        (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5793  
5794    AttributorConfig AC(CGUpdater);
5795    AC.DefaultInitializeLiveInternals = false;
5796    AC.IsModulePass = true;
5797    AC.RewriteSignatures = false;
5798    AC.MaxFixpointIterations = MaxFixpointIterations;
5799    AC.OREGetter = OREGetter;
5800    AC.PassName = DEBUG_TYPE;
5801    AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5802    AC.IPOAmendableCB = [](const Function &F) {
5803      return F.hasFnAttribute("kernel");
5804    };
5805  
5806    Attributor A(Functions, InfoCache, AC);
5807  
5808    OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5809    Changed |= OMPOpt.run(true);
5810  
5811    // Optionally inline device functions for potentially better performance.
5812    if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
5813      for (Function &F : M)
5814        if (!F.isDeclaration() && !Kernels.contains(&F) &&
5815            !F.hasFnAttribute(Attribute::NoInline))
5816          F.addFnAttr(Attribute::AlwaysInline);
5817  
5818    if (PrintModuleAfterOptimizations)
5819      LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
5820  
5821    if (Changed)
5822      return PreservedAnalyses::none();
5823  
5824    return PreservedAnalyses::all();
5825  }
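
// Usage sketch: the module and CGSCC entry points can be exercised directly
// with opt, under the names registered in PassRegistry.def:
//
//   opt -passes=openmp-opt device.bc -o device.opt.bc
//   opt -passes='cgscc(openmp-opt-cgscc)' device.bc -o device.opt.bc
//
// Both return early unless containsOpenMP(M) holds, i.e., unless the module
// carries the "openmp" module flag (see below).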
5826  
5827  PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
5828                                            CGSCCAnalysisManager &AM,
5829                                            LazyCallGraph &CG,
5830                                            CGSCCUpdateResult &UR) {
5831    if (!containsOpenMP(*C.begin()->getFunction().getParent()))
5832      return PreservedAnalyses::all();
5833    if (DisableOpenMPOptimizations)
5834      return PreservedAnalyses::all();
5835  
5836    SmallVector<Function *, 16> SCC;
5837    // If there are kernels in the module, we have to run on all SCCs.
5838    for (LazyCallGraph::Node &N : C) {
5839      Function *Fn = &N.getFunction();
5840      SCC.push_back(Fn);
5841    }
5842  
5843    if (SCC.empty())
5844      return PreservedAnalyses::all();
5845  
5846    Module &M = *C.begin()->getFunction().getParent();
5847  
5848    if (PrintModuleBeforeOptimizations)
5849      LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M);
5850  
5851    KernelSet Kernels = getDeviceKernels(M);
5852  
5853    FunctionAnalysisManager &FAM =
5854        AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
5855  
5856    AnalysisGetter AG(FAM);
5857  
5858    auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5859      return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5860    };
5861  
5862    BumpPtrAllocator Allocator;
5863    CallGraphUpdater CGUpdater;
5864    CGUpdater.initialize(CG, C, AM, UR);
5865  
5866    bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5867                    LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;
5868    SetVector<Function *> Functions(SCC.begin(), SCC.end());
5869    OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
5870                                  /*CGSCC*/ &Functions, PostLink);
5871  
5872    unsigned MaxFixpointIterations =
5873        (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5874  
5875    AttributorConfig AC(CGUpdater);
5876    AC.DefaultInitializeLiveInternals = false;
5877    AC.IsModulePass = false;
5878    AC.RewriteSignatures = false;
5879    AC.MaxFixpointIterations = MaxFixpointIterations;
5880    AC.OREGetter = OREGetter;
5881    AC.PassName = DEBUG_TYPE;
5882    AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5883  
5884    Attributor A(Functions, InfoCache, AC);
5885  
5886    OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5887    bool Changed = OMPOpt.run(false);
5888  
5889    if (PrintModuleAfterOptimizations)
5890      LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5891  
5892    if (Changed)
5893      return PreservedAnalyses::none();
5894  
5895    return PreservedAnalyses::all();
5896  }
5897  
5898  bool llvm::omp::isOpenMPKernel(Function &Fn) {
5899    return Fn.hasFnAttribute("kernel");
5900  }
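
// A device kernel is identified purely by the "kernel" string attribute; an
// IR sketch (the function name is hypothetical):
//
//   define void @example_kernel() "kernel" {
//     ret void
//   }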
5901  
5902  KernelSet llvm::omp::getDeviceKernels(Module &M) {
5903    // TODO: Create a more cross-platform way of determining device kernels.
5904    NamedMDNode *MD = M.getNamedMetadata("nvvm.annotations");
5905    KernelSet Kernels;
5906  
5907    if (!MD)
5908      return Kernels;
5909  
5910    for (auto *Op : MD->operands()) {
5911      if (Op->getNumOperands() < 2)
5912        continue;
5913      MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
5914      if (!KindID || KindID->getString() != "kernel")
5915        continue;
5916  
5917      Function *KernelFn =
5918          mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
5919      if (!KernelFn)
5920        continue;
5921  
5922      // We are only interested in OpenMP target regions. Others, such as kernels
5923      // generated by CUDA but linked together, are not interesting to this pass.
5924      if (isOpenMPKernel(*KernelFn)) {
5925        ++NumOpenMPTargetRegionKernels;
5926        Kernels.insert(KernelFn);
5927      } else
5928        ++NumNonOpenMPTargetRegionKernels;
5929    }
5930  
5931    return Kernels;
5932  }
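
// The metadata walked above looks as follows in device modules (a sketch;
// operand 0 is the kernel function, operand 1 the "kernel" kind string):
//
//   !nvvm.annotations = !{!0}
//   !0 = !{ptr @example_kernel, !"kernel", i32 1}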
5933  
5934  bool llvm::omp::containsOpenMP(Module &M) {
5935    Metadata *MD = M.getModuleFlag("openmp");
5936    if (!MD)
5937      return false;
5938  
5939    return true;
5940  }
5941  
5942  bool llvm::omp::isOpenMPDevice(Module &M) {
5943    Metadata *MD = M.getModuleFlag("openmp-device");
5944    if (!MD)
5945      return false;
5946  
5947    return true;
5948  }
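
// Both predicates key off module flags emitted by the OpenMP frontend; an IR
// sketch (7 is the 'Max' merge behavior, 50 encodes the OpenMP version):
//
//   !llvm.module.flags = !{!0, !1}
//   !0 = !{i32 7, !"openmp", i32 50}
//   !1 = !{i32 7, !"openmp-device", i32 50}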
5949