// xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AMDGPU/R600OpenCLImageTypeLoweringPass.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
//===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This pass resolves calls to OpenCL image attribute, image resource ID and
/// sampler resource ID getter functions.
///
/// Image attributes (size and format) are expected to be passed to the kernel
/// as kernel arguments immediately following the image argument itself;
/// therefore this pass adds image size and format arguments to the kernel
/// functions in the module. The kernel functions with image arguments are
/// re-created using the new signature. The new arguments are added to the
/// kernel metadata with kernel_arg_type and kernel_arg_base_type set to
/// "__llvm_image_size" or "__llvm_image_format".
/// Note: this pass may invalidate pointers to functions.
///
/// Resource IDs of read-only images, write-only images and samplers are
/// defined to be their index among the kernel arguments of the same
/// type and access qualifier.
//
//===----------------------------------------------------------------------===//
26 
#include "R600.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include <iterator>
#include <tuple>

using namespace llvm;
39 
40 static StringRef GetImageSizeFunc =         "llvm.OpenCL.image.get.size";
41 static StringRef GetImageFormatFunc =       "llvm.OpenCL.image.get.format";
42 static StringRef GetImageResourceIDFunc =   "llvm.OpenCL.image.get.resource.id";
43 static StringRef GetSamplerResourceIDFunc =
44     "llvm.OpenCL.sampler.get.resource.id";
45 
46 static StringRef ImageSizeArgMDType =   "__llvm_image_size";
47 static StringRef ImageFormatArgMDType = "__llvm_image_format";
48 
49 static StringRef KernelsMDNodeName = "opencl.kernels";
50 static StringRef KernelArgMDNodeNames[] = {
51   "kernel_arg_addr_space",
52   "kernel_arg_access_qual",
53   "kernel_arg_type",
54   "kernel_arg_base_type",
55   "kernel_arg_type_qual"};
56 static const unsigned NumKernelArgMDNodes = 5;
57 
58 namespace {
59 
60 using MDVector = SmallVector<Metadata *, 8>;
61 struct KernelArgMD {
62   MDVector ArgVector[NumKernelArgMDNodes];
63 };
64 
65 } // end anonymous namespace
66 
67 static inline bool
IsImageType(StringRef TypeString)68 IsImageType(StringRef TypeString) {
69   return TypeString == "image2d_t" || TypeString == "image3d_t";
70 }
71 
72 static inline bool
IsSamplerType(StringRef TypeString)73 IsSamplerType(StringRef TypeString) {
74   return TypeString == "sampler_t";
75 }
76 
77 static Function *
GetFunctionFromMDNode(MDNode * Node)78 GetFunctionFromMDNode(MDNode *Node) {
79   if (!Node)
80     return nullptr;
81 
82   size_t NumOps = Node->getNumOperands();
83   if (NumOps != NumKernelArgMDNodes + 1)
84     return nullptr;
85 
86   auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
87   if (!F)
88     return nullptr;
89 
90   // Validation checks.
91   size_t ExpectNumArgNodeOps = F->arg_size() + 1;
92   for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
93     MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
94     if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
95       return nullptr;
96     if (!ArgNode->getOperand(0))
97       return nullptr;
98 
99     // FIXME: It should be possible to do image lowering when some metadata
100     // args missing or not in the expected order.
101     MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
102     if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
103       return nullptr;
104   }
105 
106   return F;
107 }
108 
109 static StringRef
AccessQualFromMD(MDNode * KernelMDNode,unsigned ArgIdx)110 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
111   MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
112   return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
113 }
114 
115 static StringRef
ArgTypeFromMD(MDNode * KernelMDNode,unsigned ArgIdx)116 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
117   MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
118   return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
119 }
120 
121 static MDVector
GetArgMD(MDNode * KernelMDNode,unsigned OpIdx)122 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
123   MDVector Res;
124   for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
125     MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
126     Res.push_back(Node->getOperand(OpIdx));
127   }
128   return Res;
129 }
130 
131 static void
PushArgMD(KernelArgMD & MD,const MDVector & V)132 PushArgMD(KernelArgMD &MD, const MDVector &V) {
133   assert(V.size() == NumKernelArgMDNodes);
134   for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
135     MD.ArgVector[i].push_back(V[i]);
136   }
137 }
138 
139 namespace {
140 
141 class R600OpenCLImageTypeLoweringPass : public ModulePass {
142   static char ID;
143 
144   LLVMContext *Context;
145   Type *Int32Type;
146   Type *ImageSizeType;
147   Type *ImageFormatType;
148   SmallVector<Instruction *, 4> InstsToErase;
149 
replaceImageUses(Argument & ImageArg,uint32_t ResourceID,Argument & ImageSizeArg,Argument & ImageFormatArg)150   bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
151                         Argument &ImageSizeArg,
152                         Argument &ImageFormatArg) {
153     bool Modified = false;
154 
155     for (auto &Use : ImageArg.uses()) {
156       auto Inst = dyn_cast<CallInst>(Use.getUser());
157       if (!Inst) {
158         continue;
159       }
160 
161       Function *F = Inst->getCalledFunction();
162       if (!F)
163         continue;
164 
165       Value *Replacement = nullptr;
166       StringRef Name = F->getName();
167       if (Name.starts_with(GetImageResourceIDFunc)) {
168         Replacement = ConstantInt::get(Int32Type, ResourceID);
169       } else if (Name.starts_with(GetImageSizeFunc)) {
170         Replacement = &ImageSizeArg;
171       } else if (Name.starts_with(GetImageFormatFunc)) {
172         Replacement = &ImageFormatArg;
173       } else {
174         continue;
175       }
176 
177       Inst->replaceAllUsesWith(Replacement);
178       InstsToErase.push_back(Inst);
179       Modified = true;
180     }
181 
182     return Modified;
183   }
184 
replaceSamplerUses(Argument & SamplerArg,uint32_t ResourceID)185   bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
186     bool Modified = false;
187 
188     for (const auto &Use : SamplerArg.uses()) {
189       auto Inst = dyn_cast<CallInst>(Use.getUser());
190       if (!Inst) {
191         continue;
192       }
193 
194       Function *F = Inst->getCalledFunction();
195       if (!F)
196         continue;
197 
198       Value *Replacement = nullptr;
199       StringRef Name = F->getName();
200       if (Name == GetSamplerResourceIDFunc) {
201         Replacement = ConstantInt::get(Int32Type, ResourceID);
202       } else {
203         continue;
204       }
205 
206       Inst->replaceAllUsesWith(Replacement);
207       InstsToErase.push_back(Inst);
208       Modified = true;
209     }
210 
211     return Modified;
212   }
213 
replaceImageAndSamplerUses(Function * F,MDNode * KernelMDNode)214   bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
215     uint32_t NumReadOnlyImageArgs = 0;
216     uint32_t NumWriteOnlyImageArgs = 0;
217     uint32_t NumSamplerArgs = 0;
218 
219     bool Modified = false;
220     InstsToErase.clear();
221     for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
222       Argument &Arg = *ArgI;
223       StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
224 
225       // Handle image types.
226       if (IsImageType(Type)) {
227         StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
228         uint32_t ResourceID;
229         if (AccessQual == "read_only") {
230           ResourceID = NumReadOnlyImageArgs++;
231         } else if (AccessQual == "write_only") {
232           ResourceID = NumWriteOnlyImageArgs++;
233         } else {
234           llvm_unreachable("Wrong image access qualifier.");
235         }
236 
237         Argument &SizeArg = *(++ArgI);
238         Argument &FormatArg = *(++ArgI);
239         Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
240 
241       // Handle sampler type.
242       } else if (IsSamplerType(Type)) {
243         uint32_t ResourceID = NumSamplerArgs++;
244         Modified |= replaceSamplerUses(Arg, ResourceID);
245       }
246     }
247     for (auto *Inst : InstsToErase)
248       Inst->eraseFromParent();
249 
250     return Modified;
251   }
252 
253   std::tuple<Function *, MDNode *>
addImplicitArgs(Function * F,MDNode * KernelMDNode)254   addImplicitArgs(Function *F, MDNode *KernelMDNode) {
255     bool Modified = false;
256 
257     FunctionType *FT = F->getFunctionType();
258     SmallVector<Type *, 8> ArgTypes;
259 
260     // Metadata operands for new MDNode.
261     KernelArgMD NewArgMDs;
262     PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
263 
264     // Add implicit arguments to the signature.
265     for (unsigned i = 0; i < FT->getNumParams(); ++i) {
266       ArgTypes.push_back(FT->getParamType(i));
267       MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
268       PushArgMD(NewArgMDs, ArgMD);
269 
270       if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
271         continue;
272 
273       // Add size implicit argument.
274       ArgTypes.push_back(ImageSizeType);
275       ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
276       PushArgMD(NewArgMDs, ArgMD);
277 
278       // Add format implicit argument.
279       ArgTypes.push_back(ImageFormatType);
280       ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
281       PushArgMD(NewArgMDs, ArgMD);
282 
283       Modified = true;
284     }
285     if (!Modified) {
286       return std::tuple(nullptr, nullptr);
287     }
288 
289     // Create function with new signature and clone the old body into it.
290     auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
291     auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
292     ValueToValueMapTy VMap;
293     auto NewFArgIt = NewF->arg_begin();
294     for (auto &Arg: F->args()) {
295       auto ArgName = Arg.getName();
296       NewFArgIt->setName(ArgName);
297       VMap[&Arg] = &(*NewFArgIt++);
298       if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
299         (NewFArgIt++)->setName(Twine("__size_") + ArgName);
300         (NewFArgIt++)->setName(Twine("__format_") + ArgName);
301       }
302     }
303     SmallVector<ReturnInst*, 8> Returns;
304     CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
305                       Returns);
306 
307     // Build new MDNode.
308     SmallVector<Metadata *, 6> KernelMDArgs;
309     KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
310     for (const MDVector &MDV : NewArgMDs.ArgVector)
311       KernelMDArgs.push_back(MDNode::get(*Context, MDV));
312     MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
313 
314     return std::tuple(NewF, NewMDNode);
315   }
316 
transformKernels(Module & M)317   bool transformKernels(Module &M) {
318     NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
319     if (!KernelsMDNode)
320       return false;
321 
322     bool Modified = false;
323     for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
324       MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
325       Function *F = GetFunctionFromMDNode(KernelMDNode);
326       if (!F)
327         continue;
328 
329       Function *NewF;
330       MDNode *NewMDNode;
331       std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
332       if (NewF) {
333         // Replace old function and metadata with new ones.
334         F->eraseFromParent();
335         M.getFunctionList().push_back(NewF);
336         M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
337                               NewF->getAttributes());
338         KernelsMDNode->setOperand(i, NewMDNode);
339 
340         F = NewF;
341         KernelMDNode = NewMDNode;
342         Modified = true;
343       }
344 
345       Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
346     }
347 
348     return Modified;
349   }
350 
351 public:
R600OpenCLImageTypeLoweringPass()352   R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {}
353 
runOnModule(Module & M)354   bool runOnModule(Module &M) override {
355     Context = &M.getContext();
356     Int32Type = Type::getInt32Ty(M.getContext());
357     ImageSizeType = ArrayType::get(Int32Type, 3);
358     ImageFormatType = ArrayType::get(Int32Type, 2);
359 
360     return transformKernels(M);
361   }
362 
getPassName() const363   StringRef getPassName() const override {
364     return "R600 OpenCL Image Type Pass";
365   }
366 };
367 
368 } // end anonymous namespace
369 
370 char R600OpenCLImageTypeLoweringPass::ID = 0;
371 
createR600OpenCLImageTypeLoweringPass()372 ModulePass *llvm::createR600OpenCLImageTypeLoweringPass() {
373   return new R600OpenCLImageTypeLoweringPass();
374 }
375