xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp (revision 3ceba58a7509418b47b8fca2d2b6bbf088714e26)
1 //===----- HipStdPar.cpp - HIP C++ Standard Parallelism Support Passes ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This file implements two passes that enable HIP C++ Standard Parallelism
9 // Support:
10 //
11 // 1. AcceleratorCodeSelection (required): Given that only algorithms are
12 //    accelerated, and that the accelerated implementation exists in the form of
13 //    a compute kernel, we assume that only the kernel, and all functions
14 //    reachable from it, constitute code that the user expects the accelerator
15 //    to execute. Thus, we identify the set of all functions reachable from
16 //    kernels, and then remove all unreachable ones. This last part is necessary
17 //    because it is possible for code that the user did not expect to execute on
18 //    an accelerator to contain constructs that cannot be handled by the target
19 //    BE, which cannot be provably demonstrated to be dead code in general, and
20 //    thus can lead to mis-compilation. The degenerate case of this is when a
21 //    Module contains no kernels (the parent TU had no algorithm invocations fit
22 //    for acceleration), which we handle by completely emptying said module.
23 //    **NOTE**: The above does not handle indirectly reachable functions i.e.
24 //              it is possible to obtain a case where the target of an indirect
25 //              call is otherwise unreachable and thus is removed; this
26 //              restriction is aligned with the current `-hipstdpar` limitations
27 //              and will be relaxed in the future.
28 //
29 // 2. AllocationInterposition (required only when on-demand paging is
30 //    unsupported): Some accelerators or operating systems might not support
31 //    transparent on-demand paging. Thus, they would only be able to access
32 //    memory that is allocated by an accelerator-aware mechanism. For such cases
33 //    the user can opt into enabling allocation / deallocation interposition,
34 //    whereby we replace calls to known allocation / deallocation functions with
35 //    calls to runtime implemented equivalents that forward the requests to
36 //    accelerator-aware interfaces. We also support freeing system allocated
37 //    memory that ends up in one of the runtime equivalents, since this can
38 //    happen if e.g. a library that was compiled without interposition returns
39 //    an allocation that can be validly passed to `free`.
40 //===----------------------------------------------------------------------===//
41 
42 #include "llvm/Transforms/HipStdPar/HipStdPar.h"
43 
44 #include "llvm/ADT/SmallPtrSet.h"
45 #include "llvm/ADT/SmallVector.h"
46 #include "llvm/ADT/STLExtras.h"
47 #include "llvm/Analysis/CallGraph.h"
48 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
49 #include "llvm/IR/Constants.h"
50 #include "llvm/IR/DebugInfoMetadata.h"
51 #include "llvm/IR/Function.h"
52 #include "llvm/IR/Module.h"
53 #include "llvm/Transforms/Utils/ModuleUtils.h"
54 
55 #include <cassert>
56 #include <string>
57 #include <utility>
58 
59 using namespace llvm;
60 
61 template<typename T>
62 static inline void eraseFromModule(T &ToErase) {
63   ToErase.replaceAllUsesWith(PoisonValue::get(ToErase.getType()));
64   ToErase.eraseFromParent();
65 }
66 
67 static inline bool checkIfSupported(GlobalVariable &G) {
68   if (!G.isThreadLocal())
69     return true;
70 
71   G.dropDroppableUses();
72 
73   if (!G.isConstantUsed())
74     return true;
75 
76   std::string W;
77   raw_string_ostream OS(W);
78 
79   OS << "Accelerator does not support the thread_local variable "
80     << G.getName();
81 
82   Instruction *I = nullptr;
83   SmallVector<User *> Tmp(G.user_begin(), G.user_end());
84   SmallPtrSet<User *, 5> Visited;
85   do {
86     auto U = std::move(Tmp.back());
87     Tmp.pop_back();
88 
89     if (Visited.contains(U))
90       continue;
91 
92     if (isa<Instruction>(U))
93       I = cast<Instruction>(U);
94     else
95       Tmp.insert(Tmp.end(), U->user_begin(), U->user_end());
96 
97     Visited.insert(U);
98   } while (!I && !Tmp.empty());
99 
100   assert(I && "thread_local global should have at least one non-constant use.");
101 
102   G.getContext().diagnose(
103     DiagnosticInfoUnsupported(*I->getParent()->getParent(), W,
104                               I->getDebugLoc(), DS_Error));
105 
106   return false;
107 }
108 
109 static inline void clearModule(Module &M) { // TODO: simplify.
110   while (!M.functions().empty())
111     eraseFromModule(*M.begin());
112   while (!M.globals().empty())
113     eraseFromModule(*M.globals().begin());
114   while (!M.aliases().empty())
115     eraseFromModule(*M.aliases().begin());
116   while (!M.ifuncs().empty())
117     eraseFromModule(*M.ifuncs().begin());
118 }
119 
120 static inline void maybeHandleGlobals(Module &M) {
121   unsigned GlobAS = M.getDataLayout().getDefaultGlobalsAddressSpace();
122   for (auto &&G : M.globals()) { // TODO: should we handle these in the FE?
123     if (!checkIfSupported(G))
124       return clearModule(M);
125 
126     if (G.isThreadLocal())
127       continue;
128     if (G.isConstant())
129       continue;
130     if (G.getAddressSpace() != GlobAS)
131       continue;
132     if (G.getLinkage() != GlobalVariable::ExternalLinkage)
133       continue;
134 
135     G.setLinkage(GlobalVariable::ExternalWeakLinkage);
136     G.setInitializer(nullptr);
137     G.setExternallyInitialized(true);
138   }
139 }
140 
141 template<unsigned N>
142 static inline void removeUnreachableFunctions(
143   const SmallPtrSet<const Function *, N>& Reachable, Module &M) {
144   removeFromUsedLists(M, [&](Constant *C) {
145     if (auto F = dyn_cast<Function>(C))
146       return !Reachable.contains(F);
147 
148     return false;
149   });
150 
151   SmallVector<std::reference_wrapper<Function>> ToRemove;
152   copy_if(M, std::back_inserter(ToRemove), [&](auto &&F) {
153     return !F.isIntrinsic() && !Reachable.contains(&F);
154   });
155 
156   for_each(ToRemove, eraseFromModule<Function>);
157 }
158 
159 static inline bool isAcceleratorExecutionRoot(const Function *F) {
160     if (!F)
161       return false;
162 
163     return F->getCallingConv() == CallingConv::AMDGPU_KERNEL;
164 }
165 
166 static inline bool checkIfSupported(const Function *F, const CallBase *CB) {
167   const auto Dx = F->getName().rfind("__hipstdpar_unsupported");
168 
169   if (Dx == StringRef::npos)
170     return true;
171 
172   const auto N = F->getName().substr(0, Dx);
173 
174   std::string W;
175   raw_string_ostream OS(W);
176 
177   if (N == "__ASM")
178     OS << "Accelerator does not support the ASM block:\n"
179       << cast<ConstantDataArray>(CB->getArgOperand(0))->getAsCString();
180   else
181     OS << "Accelerator does not support the " << N << " function.";
182 
183   auto Caller = CB->getParent()->getParent();
184 
185   Caller->getContext().diagnose(
186     DiagnosticInfoUnsupported(*Caller, W, CB->getDebugLoc(), DS_Error));
187 
188   return false;
189 }
190 
191 PreservedAnalyses
192   HipStdParAcceleratorCodeSelectionPass::run(Module &M,
193                                              ModuleAnalysisManager &MAM) {
194   auto &CGA = MAM.getResult<CallGraphAnalysis>(M);
195 
196   SmallPtrSet<const Function *, 32> Reachable;
197   for (auto &&CGN : CGA) {
198     if (!isAcceleratorExecutionRoot(CGN.first))
199       continue;
200 
201     Reachable.insert(CGN.first);
202 
203     SmallVector<const Function *> Tmp({CGN.first});
204     do {
205       auto F = std::move(Tmp.back());
206       Tmp.pop_back();
207 
208       for (auto &&N : *CGA[F]) {
209         if (!N.second)
210           continue;
211         if (!N.second->getFunction())
212           continue;
213         if (Reachable.contains(N.second->getFunction()))
214           continue;
215 
216         if (!checkIfSupported(N.second->getFunction(),
217                               dyn_cast<CallBase>(*N.first)))
218           return PreservedAnalyses::none();
219 
220         Reachable.insert(N.second->getFunction());
221         Tmp.push_back(N.second->getFunction());
222       }
223     } while (!std::empty(Tmp));
224   }
225 
226   if (std::empty(Reachable))
227     clearModule(M);
228   else
229     removeUnreachableFunctions(Reachable, M);
230 
231   maybeHandleGlobals(M);
232 
233   return PreservedAnalyses::none();
234 }
235 
236 static constexpr std::pair<StringLiteral, StringLiteral> ReplaceMap[]{
237   {"aligned_alloc",             "__hipstdpar_aligned_alloc"},
238   {"calloc",                    "__hipstdpar_calloc"},
239   {"free",                      "__hipstdpar_free"},
240   {"malloc",                    "__hipstdpar_malloc"},
241   {"memalign",                  "__hipstdpar_aligned_alloc"},
242   {"posix_memalign",            "__hipstdpar_posix_aligned_alloc"},
243   {"realloc",                   "__hipstdpar_realloc"},
244   {"reallocarray",              "__hipstdpar_realloc_array"},
245   {"_ZdaPv",                    "__hipstdpar_operator_delete"},
246   {"_ZdaPvm",                   "__hipstdpar_operator_delete_sized"},
247   {"_ZdaPvSt11align_val_t",     "__hipstdpar_operator_delete_aligned"},
248   {"_ZdaPvmSt11align_val_t",    "__hipstdpar_operator_delete_aligned_sized"},
249   {"_ZdlPv",                    "__hipstdpar_operator_delete"},
250   {"_ZdlPvm",                   "__hipstdpar_operator_delete_sized"},
251   {"_ZdlPvSt11align_val_t",     "__hipstdpar_operator_delete_aligned"},
252   {"_ZdlPvmSt11align_val_t",    "__hipstdpar_operator_delete_aligned_sized"},
253   {"_Znam",                     "__hipstdpar_operator_new"},
254   {"_ZnamRKSt9nothrow_t",       "__hipstdpar_operator_new_nothrow"},
255   {"_ZnamSt11align_val_t",      "__hipstdpar_operator_new_aligned"},
256   {"_ZnamSt11align_val_tRKSt9nothrow_t",
257                                 "__hipstdpar_operator_new_aligned_nothrow"},
258 
259   {"_Znwm",                     "__hipstdpar_operator_new"},
260   {"_ZnwmRKSt9nothrow_t",       "__hipstdpar_operator_new_nothrow"},
261   {"_ZnwmSt11align_val_t",      "__hipstdpar_operator_new_aligned"},
262   {"_ZnwmSt11align_val_tRKSt9nothrow_t",
263                                 "__hipstdpar_operator_new_aligned_nothrow"},
264   {"__builtin_calloc",          "__hipstdpar_calloc"},
265   {"__builtin_free",            "__hipstdpar_free"},
266   {"__builtin_malloc",          "__hipstdpar_malloc"},
267   {"__builtin_operator_delete", "__hipstdpar_operator_delete"},
268   {"__builtin_operator_new",    "__hipstdpar_operator_new"},
269   {"__builtin_realloc",         "__hipstdpar_realloc"},
270   {"__libc_calloc",             "__hipstdpar_calloc"},
271   {"__libc_free",               "__hipstdpar_free"},
272   {"__libc_malloc",             "__hipstdpar_malloc"},
273   {"__libc_memalign",           "__hipstdpar_aligned_alloc"},
274   {"__libc_realloc",            "__hipstdpar_realloc"}
275 };
276 
277 PreservedAnalyses
278 HipStdParAllocationInterpositionPass::run(Module &M, ModuleAnalysisManager&) {
279   SmallDenseMap<StringRef, StringRef> AllocReplacements(std::cbegin(ReplaceMap),
280                                                         std::cend(ReplaceMap));
281 
282   for (auto &&F : M) {
283     if (!F.hasName())
284       continue;
285     if (!AllocReplacements.contains(F.getName()))
286       continue;
287 
288     if (auto R = M.getFunction(AllocReplacements[F.getName()])) {
289       F.replaceAllUsesWith(R);
290     } else {
291       std::string W;
292       raw_string_ostream OS(W);
293 
294       OS << "cannot be interposed, missing: " << AllocReplacements[F.getName()]
295         << ". Tried to run the allocation interposition pass without the "
296         << "replacement functions available.";
297 
298       F.getContext().diagnose(DiagnosticInfoUnsupported(F, W,
299                                                         F.getSubprogram(),
300                                                         DS_Warning));
301     }
302   }
303 
304   if (auto F = M.getFunction("__hipstdpar_hidden_free")) {
305     auto LibcFree = M.getOrInsertFunction("__libc_free", F->getFunctionType(),
306                                           F->getAttributes());
307     F->replaceAllUsesWith(LibcFree.getCallee());
308 
309     eraseFromModule(*F);
310   }
311 
312   return PreservedAnalyses::none();
313 }
314