xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===-- AMDGPULowerBufferFatPointers.cpp ---------------------------=//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass lowers operations on buffer fat pointers (addrspace 7) to
10 // operations on buffer resources (addrspace 8) and is needed for correct
11 // codegen.
12 //
13 // # Background
14 //
15 // Address space 7 (the buffer fat pointer) is a 160-bit pointer that consists
16 // of a 128-bit buffer descriptor and a 32-bit offset into that descriptor.
17 // The buffer resource part needs to be a "raw" buffer resource
18 // (it must have a stride of 0 and bounds checks must be in raw buffer mode
19 // or disabled).
20 //
21 // When these requirements are met, a buffer resource can be treated as a
22 // typical (though quite wide) pointer that follows typical LLVM pointer
23 // semantics. This allows the frontend to reason about such buffers (which are
24 // often encountered in the context of SPIR-V kernels).
25 //
26 // However, because of their non-power-of-2 size, these fat pointers cannot be
27 // present during translation to MIR (though this restriction may be lifted
28 // during the transition to GlobalISel). Therefore, this pass is needed in order
29 // to correctly implement these fat pointers.
30 //
31 // The resource intrinsics take the resource part (the address space 8 pointer)
32 // and the offset part (the 32-bit integer) as separate arguments. In addition,
33 // many users of these buffers manipulate the offset while leaving the resource
34 // part alone. For these reasons, we want to typically separate the resource
35 // and offset parts into separate variables, but combine them together when
36 // encountering cases where this is required, such as by inserting these values
37 // into aggregates or moving them to memory.
38 //
39 // Therefore, at a high level, `ptr addrspace(7) %x` becomes `ptr addrspace(8)
40 // %x.rsrc` and `i32 %x.off`, which will be combined into `{ptr addrspace(8),
41 // i32} %x = {%x.rsrc, %x.off}` if needed. Similarly, `vector<Nxp7>` becomes
42 // `{vector<Nxp8>, vector<Nxi32>}` and its component parts.
43 //
44 // # Implementation
45 //
46 // This pass proceeds in three main phases:
47 //
48 // ## Rewriting loads and stores of p7
49 //
50 // The first phase is to rewrite away all loads and stores of `ptr addrspace(7)`,
51 // including aggregates containing such pointers, to ones that use `i160`. This
52 // is handled by `StoreFatPtrsAsIntsVisitor`, which visits loads, stores, and
53 // allocas and, if the loaded or stored type contains `ptr addrspace(7)`,
54 // rewrites that type to one where the p7s are replaced by i160s, copying other
55 // parts of aggregates as needed. In the case of a store, each pointer is
56 // `ptrtoint`d to i160 before storing, and load integers are `inttoptr`d back.
57 // This same transformation is applied to vectors of pointers.
58 //
59 // Such a transformation allows the later phases of the pass to not need
60 // to handle buffer fat pointers moving to and from memory, where we would
61 // have to handle the incompatibility between a `{Nxp8, Nxi32}` representation
62 // and `Nxi160` directly. Instead, that transposing action (where the vectors
63 // of resources and vectors of offsets are concatenated before being stored to
64 // memory) is handled through implementing `inttoptr` and `ptrtoint` only.
65 //
66 // Atomic operations on `ptr addrspace(7)` values are not supported, as the
67 // hardware does not include a 160-bit atomic.
68 //
69 // ## Type remapping
70 //
71 // We use a `ValueMapper` to mangle uses of [vectors of] buffer fat pointers
72 // to the corresponding struct type, which has a resource part and an offset
73 // part.
74 //
75 // This uses a `BufferFatPtrToStructTypeMap` and a `FatPtrConstMaterializer`
76 // to do so, usually by way of `setType`ing values. Constants are handled here
77 // because there isn't a good way to fix them up later.
78 //
79 // This has the downside of leaving the IR in an invalid state (for example,
80 // the instruction `getelementptr {ptr addrspace(8), i32} %p, ...` will exist),
81 // but all such invalid states will be resolved by the third phase.
82 //
83 // Functions that don't take buffer fat pointers are modified in place. Those
84 // that do take such pointers have their basic blocks moved to a new function
85 // with arguments that are {ptr addrspace(8), i32} arguments and return values.
86 // This phase also records intrinsics so that they can be remangled or deleted
87 // later.
88 //
89 //
90 // ## Splitting pointer structs
91 //
92 // The meat of this pass consists of defining semantics for operations that
93 // produce or consume [vectors of] buffer fat pointers in terms of their
94 // resource and offset parts. This is accomplished through the `SplitPtrStructs`
95 // visitor.
96 //
97 // In the first pass through each function that is being lowered, the splitter
98 // inserts new instructions to implement the split-structures behavior, which is
99 // needed for correctness and performance. It records a list of "split users",
100 // instructions that are being replaced by operations on the resource and offset
101 // parts.
102 //
103 // Split users do not necessarily need to produce parts themselves (
104 // a `load float, ptr addrspace(7)` does not, for example), but, if they do not
105 // generate fat buffer pointers, they must RAUW in their replacement
106 // instructions during the initial visit.
107 //
108 // When these new instructions are created, they use the split parts recorded
109 // for their initial arguments in order to generate their replacements, creating
110 // a parallel set of instructions that does not refer to the original fat
111 // pointer values but instead to their resource and offset components.
112 //
113 // Instructions, such as `extractvalue`, that produce buffer fat pointers from
114 // sources that do not have split parts, have such parts generated using
115 // `extractvalue`. This is also the initial handling of PHI nodes, which
116 // are then cleaned up.
117 //
118 // ### Conditionals
119 //
120 // PHI nodes are initially given resource parts via `extractvalue`. However,
121 // this is not an efficient rewrite of such nodes, as, in most cases, the
122 // resource part in a conditional or loop remains constant throughout the loop
123 // and only the offset varies. Failing to optimize away these constant resources
124 // would cause additional registers to be sent around loops and might lead to
125 // waterfall loops being generated for buffer operations due to the
126 // "non-uniform" resource argument.
127 //
128 // Therefore, after all instructions have been visited, the pointer splitter
129 // post-processes all encountered conditionals. Given a PHI node or select,
130 // getPossibleRsrcRoots() collects all values that the resource parts of that
131 // conditional's input could come from as well as collecting all conditional
132 // instructions encountered during the search. If, after filtering out the
133 // initial node itself, the set of encountered conditionals is a subset of the
134 // potential roots and there is a single potential resource that isn't in the
135 // conditional set, that value is the only possible value the resource argument
136 // could have throughout the control flow.
137 //
138 // If that condition is met, then a PHI node can have its resource part changed
139 // to the singleton value and then be replaced by a PHI on the offsets.
140 // Otherwise, each PHI node is split into two, one for the resource part and one
141 // for the offset part, which replace the temporary `extractvalue` instructions
142 // that were added during the first pass.
143 //
144 // Similar logic applies to `select`, where
145 // `%z = select i1 %cond, ptr addrspace(7) %x, ptr addrspace(7) %y`
146 // can be split into `%z.rsrc = %x.rsrc` and
147 // `%z.off = select i1 %cond, i32 %x.off, i32 %y.off`
148 // if both `%x` and `%y` have the same resource part, but two `select`
149 // operations will be needed if they do not.
150 //
151 // ### Final processing
152 //
153 // After conditionals have been cleaned up, the IR for each function is
154 // rewritten to remove all the old instructions that have been split up.
155 //
156 // Any instruction that used to produce a buffer fat pointer (and therefore now
157 // produces a resource-and-offset struct after type remapping) is
158 // replaced as follows:
159 // 1. All debug value annotations are cloned to reflect that the resource part
160 //    and offset parts are computed separately and constitute different
161 //    fragments of the underlying source language variable.
162 // 2. All uses that were themselves split are replaced by a `poison` of the
163 //    struct type, as they will themselves be erased soon. This rule, combined
164 //    with debug handling, should leave the use lists of split instructions
165 //    empty in almost all cases.
166 // 3. If a user of the original struct-valued result remains, the structure
167 //    needed for the new types to work is constructed out of the newly-defined
168 //    parts, and the original instruction is replaced by this structure
169 //    before being erased. Instructions requiring this construction include
170 //    `ret` and `insertvalue`.
171 //
172 // # Consequences
173 //
174 // This pass does not alter the CFG.
175 //
176 // Alias analysis information will become coarser, as the LLVM alias analyzer
177 // cannot handle the buffer intrinsics. Specifically, while we can determine
178 // that the following two loads do not alias:
179 // ```
180 //   %y = getelementptr i32, ptr addrspace(7) %x, i32 1
181 //   %a = load i32, ptr addrspace(7) %x
182 //   %b = load i32, ptr addrspace(7) %y
183 // ```
184 // we cannot (except through some code that runs during scheduling) determine
185 // that the rewritten loads below do not alias.
186 // ```
187 //   %y.off = add i32 %x.off, 1
188 //   %a = call @llvm.amdgcn.raw.ptr.buffer.load(ptr addrspace(8) %x.rsrc, i32
189 //     %x.off, ...)
190 //   %b = call @llvm.amdgcn.raw.ptr.buffer.load(ptr addrspace(8)
191 //     %x.rsrc, i32 %y.off, ...)
192 // ```
193 // However, existing alias information is preserved.
194 //===----------------------------------------------------------------------===//
195 
196 #include "AMDGPU.h"
197 #include "AMDGPUTargetMachine.h"
198 #include "GCNSubtarget.h"
199 #include "SIDefines.h"
200 #include "llvm/ADT/SetOperations.h"
201 #include "llvm/ADT/SmallVector.h"
202 #include "llvm/Analysis/ConstantFolding.h"
203 #include "llvm/Analysis/Utils/Local.h"
204 #include "llvm/CodeGen/TargetPassConfig.h"
205 #include "llvm/IR/AttributeMask.h"
206 #include "llvm/IR/Constants.h"
207 #include "llvm/IR/DebugInfo.h"
208 #include "llvm/IR/DerivedTypes.h"
209 #include "llvm/IR/IRBuilder.h"
210 #include "llvm/IR/InstIterator.h"
211 #include "llvm/IR/InstVisitor.h"
212 #include "llvm/IR/Instructions.h"
213 #include "llvm/IR/Intrinsics.h"
214 #include "llvm/IR/IntrinsicsAMDGPU.h"
215 #include "llvm/IR/Metadata.h"
216 #include "llvm/IR/Operator.h"
217 #include "llvm/IR/PatternMatch.h"
218 #include "llvm/IR/ReplaceConstant.h"
219 #include "llvm/InitializePasses.h"
220 #include "llvm/Pass.h"
221 #include "llvm/Support/AtomicOrdering.h"
222 #include "llvm/Support/Debug.h"
223 #include "llvm/Support/ErrorHandling.h"
224 #include "llvm/Transforms/Utils/Cloning.h"
225 #include "llvm/Transforms/Utils/Local.h"
226 #include "llvm/Transforms/Utils/ValueMapper.h"
227 
228 #define DEBUG_TYPE "amdgpu-lower-buffer-fat-pointers"
229 
230 using namespace llvm;
231 
/// Bit width of the integer used for the offset part of a lowered buffer fat
/// pointer (the `i32` in `{ptr addrspace(8), i32}`).
232 static constexpr unsigned BufferOffsetWidth = 32;
233 
234 namespace {
235 /// Recursively replace instances of ptr addrspace(7) and vector<Nxptr
236 /// addrspace(7)> with some other type as defined by the relevant subclass.
237 class BufferFatPtrTypeLoweringBase : public ValueMapTypeRemapper {
238   DenseMap<Type *, Type *> Map;
239 
240   Type *remapTypeImpl(Type *Ty, SmallPtrSetImpl<StructType *> &Seen);
241 
242 protected:
243   virtual Type *remapScalar(PointerType *PT) = 0;
244   virtual Type *remapVector(VectorType *VT) = 0;
245 
246   const DataLayout &DL;
247 
248 public:
BufferFatPtrTypeLoweringBase(const DataLayout & DL)249   BufferFatPtrTypeLoweringBase(const DataLayout &DL) : DL(DL) {}
250   Type *remapType(Type *SrcTy) override;
clear()251   void clear() { Map.clear(); }
252 };
253 
254 /// Remap ptr addrspace(7) to i160 and vector<Nxptr addrspace(7)> to
255 /// vector<Nxi60> in order to correctly handling loading/storing these values
256 /// from memory.
257 class BufferFatPtrToIntTypeMap : public BufferFatPtrTypeLoweringBase {
258   using BufferFatPtrTypeLoweringBase::BufferFatPtrTypeLoweringBase;
259 
260 protected:
remapScalar(PointerType * PT)261   Type *remapScalar(PointerType *PT) override { return DL.getIntPtrType(PT); }
remapVector(VectorType * VT)262   Type *remapVector(VectorType *VT) override { return DL.getIntPtrType(VT); }
263 };
264 
265 /// Remap ptr addrspace(7) to {ptr addrspace(8), i32} (the resource and offset
266 /// parts of the pointer) so that we can easily rewrite operations on these
267 /// values that aren't loading them from or storing them to memory.
268 class BufferFatPtrToStructTypeMap : public BufferFatPtrTypeLoweringBase {
269   using BufferFatPtrTypeLoweringBase::BufferFatPtrTypeLoweringBase;
270 
271 protected:
272   Type *remapScalar(PointerType *PT) override;
273   Type *remapVector(VectorType *VT) override;
274 };
275 } // namespace
276 
277 // This code is adapted from the type remapper in lib/Linker/IRMover.cpp
remapTypeImpl(Type * Ty,SmallPtrSetImpl<StructType * > & Seen)278 Type *BufferFatPtrTypeLoweringBase::remapTypeImpl(
279     Type *Ty, SmallPtrSetImpl<StructType *> &Seen) {
280   Type **Entry = &Map[Ty];
281   if (*Entry)
282     return *Entry;
283   if (auto *PT = dyn_cast<PointerType>(Ty)) {
284     if (PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) {
285       return *Entry = remapScalar(PT);
286     }
287   }
288   if (auto *VT = dyn_cast<VectorType>(Ty)) {
289     auto *PT = dyn_cast<PointerType>(VT->getElementType());
290     if (PT && PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER) {
291       return *Entry = remapVector(VT);
292     }
293     return *Entry = Ty;
294   }
295   // Whether the type is one that is structurally uniqued - that is, if it is
296   // not a named struct (the only kind of type where multiple structurally
297   // identical types that have a distinct `Type*`)
298   StructType *TyAsStruct = dyn_cast<StructType>(Ty);
299   bool IsUniqued = !TyAsStruct || TyAsStruct->isLiteral();
300   // Base case for ints, floats, opaque pointers, and so on, which don't
301   // require recursion.
302   if (Ty->getNumContainedTypes() == 0 && IsUniqued)
303     return *Entry = Ty;
304   if (!IsUniqued) {
305     // Create a dummy type for recursion purposes.
306     if (!Seen.insert(TyAsStruct).second) {
307       StructType *Placeholder = StructType::create(Ty->getContext());
308       return *Entry = Placeholder;
309     }
310   }
311   bool Changed = false;
312   SmallVector<Type *> ElementTypes(Ty->getNumContainedTypes(), nullptr);
313   for (unsigned int I = 0, E = Ty->getNumContainedTypes(); I < E; ++I) {
314     Type *OldElem = Ty->getContainedType(I);
315     Type *NewElem = remapTypeImpl(OldElem, Seen);
316     ElementTypes[I] = NewElem;
317     Changed |= (OldElem != NewElem);
318   }
319   // Recursive calls to remapTypeImpl() may have invalidated pointer.
320   Entry = &Map[Ty];
321   if (!Changed) {
322     return *Entry = Ty;
323   }
324   if (auto *ArrTy = dyn_cast<ArrayType>(Ty))
325     return *Entry = ArrayType::get(ElementTypes[0], ArrTy->getNumElements());
326   if (auto *FnTy = dyn_cast<FunctionType>(Ty))
327     return *Entry = FunctionType::get(ElementTypes[0],
328                                       ArrayRef(ElementTypes).slice(1),
329                                       FnTy->isVarArg());
330   if (auto *STy = dyn_cast<StructType>(Ty)) {
331     // Genuine opaque types don't have a remapping.
332     if (STy->isOpaque())
333       return *Entry = Ty;
334     bool IsPacked = STy->isPacked();
335     if (IsUniqued)
336       return *Entry = StructType::get(Ty->getContext(), ElementTypes, IsPacked);
337     SmallString<16> Name(STy->getName());
338     STy->setName("");
339     Type **RecursionEntry = &Map[Ty];
340     if (*RecursionEntry) {
341       auto *Placeholder = cast<StructType>(*RecursionEntry);
342       Placeholder->setBody(ElementTypes, IsPacked);
343       Placeholder->setName(Name);
344       return *Entry = Placeholder;
345     }
346     return *Entry = StructType::create(Ty->getContext(), ElementTypes, Name,
347                                        IsPacked);
348   }
349   llvm_unreachable("Unknown type of type that contains elements");
350 }
351 
remapType(Type * SrcTy)352 Type *BufferFatPtrTypeLoweringBase::remapType(Type *SrcTy) {
353   SmallPtrSet<StructType *, 2> Visited;
354   return remapTypeImpl(SrcTy, Visited);
355 }
356 
remapScalar(PointerType * PT)357 Type *BufferFatPtrToStructTypeMap::remapScalar(PointerType *PT) {
358   LLVMContext &Ctx = PT->getContext();
359   return StructType::get(PointerType::get(Ctx, AMDGPUAS::BUFFER_RESOURCE),
360                          IntegerType::get(Ctx, BufferOffsetWidth));
361 }
362 
remapVector(VectorType * VT)363 Type *BufferFatPtrToStructTypeMap::remapVector(VectorType *VT) {
364   ElementCount EC = VT->getElementCount();
365   LLVMContext &Ctx = VT->getContext();
366   Type *RsrcVec =
367       VectorType::get(PointerType::get(Ctx, AMDGPUAS::BUFFER_RESOURCE), EC);
368   Type *OffVec = VectorType::get(IntegerType::get(Ctx, BufferOffsetWidth), EC);
369   return StructType::get(RsrcVec, OffVec);
370 }
371 
isBufferFatPtrOrVector(Type * Ty)372 static bool isBufferFatPtrOrVector(Type *Ty) {
373   if (auto *PT = dyn_cast<PointerType>(Ty->getScalarType()))
374     return PT->getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER;
375   return false;
376 }
377 
378 // True if the type is {ptr addrspace(8), i32} or a struct containing vectors of
379 // those types. Used to quickly skip instructions we don't need to process.
isSplitFatPtr(Type * Ty)380 static bool isSplitFatPtr(Type *Ty) {
381   auto *ST = dyn_cast<StructType>(Ty);
382   if (!ST)
383     return false;
384   if (!ST->isLiteral() || ST->getNumElements() != 2)
385     return false;
386   auto *MaybeRsrc =
387       dyn_cast<PointerType>(ST->getElementType(0)->getScalarType());
388   auto *MaybeOff =
389       dyn_cast<IntegerType>(ST->getElementType(1)->getScalarType());
390   return MaybeRsrc && MaybeOff &&
391          MaybeRsrc->getAddressSpace() == AMDGPUAS::BUFFER_RESOURCE &&
392          MaybeOff->getBitWidth() == BufferOffsetWidth;
393 }
394 
395 // True if the result type or any argument types are buffer fat pointers.
isBufferFatPtrConst(Constant * C)396 static bool isBufferFatPtrConst(Constant *C) {
397   Type *T = C->getType();
398   return isBufferFatPtrOrVector(T) || any_of(C->operands(), [](const Use &U) {
399            return isBufferFatPtrOrVector(U.get()->getType());
400          });
401 }
402 
403 namespace {
404 /// Convert [vectors of] buffer fat pointers to integers when they are read from
405 /// or stored to memory. This ensures that these pointers will have the same
406 /// memory layout as before they are lowered, even though they will no longer
407 /// have their previous layout in registers/in the program (they'll be broken
408 /// down into resource and offset parts). This has the downside of imposing
409 /// marshalling costs when reading or storing these values, but since placing
410 /// such pointers into memory is an uncommon operation at best, we feel that
411 /// this cost is acceptable for better performance in the common case.
412 class StoreFatPtrsAsIntsVisitor
413     : public InstVisitor<StoreFatPtrsAsIntsVisitor, bool> {
414   BufferFatPtrToIntTypeMap *TypeMap;
415 
416   ValueToValueMapTy ConvertedForStore;
417 
418   IRBuilder<> IRB;
419 
420   // Convert all the buffer fat pointers within the input value to inttegers
421   // so that it can be stored in memory.
422   Value *fatPtrsToInts(Value *V, Type *From, Type *To, const Twine &Name);
423   // Convert all the i160s that need to be buffer fat pointers (as specified)
424   // by the To type) into those pointers to preserve the semantics of the rest
425   // of the program.
426   Value *intsToFatPtrs(Value *V, Type *From, Type *To, const Twine &Name);
427 
428 public:
StoreFatPtrsAsIntsVisitor(BufferFatPtrToIntTypeMap * TypeMap,LLVMContext & Ctx)429   StoreFatPtrsAsIntsVisitor(BufferFatPtrToIntTypeMap *TypeMap, LLVMContext &Ctx)
430       : TypeMap(TypeMap), IRB(Ctx) {}
431   bool processFunction(Function &F);
432 
visitInstruction(Instruction & I)433   bool visitInstruction(Instruction &I) { return false; }
434   bool visitAllocaInst(AllocaInst &I);
435   bool visitLoadInst(LoadInst &LI);
436   bool visitStoreInst(StoreInst &SI);
437   bool visitGetElementPtrInst(GetElementPtrInst &I);
438 };
439 } // namespace
440 
fatPtrsToInts(Value * V,Type * From,Type * To,const Twine & Name)441 Value *StoreFatPtrsAsIntsVisitor::fatPtrsToInts(Value *V, Type *From, Type *To,
442                                                 const Twine &Name) {
443   if (From == To)
444     return V;
445   ValueToValueMapTy::iterator Find = ConvertedForStore.find(V);
446   if (Find != ConvertedForStore.end())
447     return Find->second;
448   if (isBufferFatPtrOrVector(From)) {
449     Value *Cast = IRB.CreatePtrToInt(V, To, Name + ".int");
450     ConvertedForStore[V] = Cast;
451     return Cast;
452   }
453   if (From->getNumContainedTypes() == 0)
454     return V;
455   // Structs, arrays, and other compound types.
456   Value *Ret = PoisonValue::get(To);
457   if (auto *AT = dyn_cast<ArrayType>(From)) {
458     Type *FromPart = AT->getArrayElementType();
459     Type *ToPart = cast<ArrayType>(To)->getElementType();
460     for (uint64_t I = 0, E = AT->getArrayNumElements(); I < E; ++I) {
461       Value *Field = IRB.CreateExtractValue(V, I);
462       Value *NewField =
463           fatPtrsToInts(Field, FromPart, ToPart, Name + "." + Twine(I));
464       Ret = IRB.CreateInsertValue(Ret, NewField, I);
465     }
466   } else {
467     for (auto [Idx, FromPart, ToPart] :
468          enumerate(From->subtypes(), To->subtypes())) {
469       Value *Field = IRB.CreateExtractValue(V, Idx);
470       Value *NewField =
471           fatPtrsToInts(Field, FromPart, ToPart, Name + "." + Twine(Idx));
472       Ret = IRB.CreateInsertValue(Ret, NewField, Idx);
473     }
474   }
475   ConvertedForStore[V] = Ret;
476   return Ret;
477 }
478 
intsToFatPtrs(Value * V,Type * From,Type * To,const Twine & Name)479 Value *StoreFatPtrsAsIntsVisitor::intsToFatPtrs(Value *V, Type *From, Type *To,
480                                                 const Twine &Name) {
481   if (From == To)
482     return V;
483   if (isBufferFatPtrOrVector(To)) {
484     Value *Cast = IRB.CreateIntToPtr(V, To, Name + ".ptr");
485     return Cast;
486   }
487   if (From->getNumContainedTypes() == 0)
488     return V;
489   // Structs, arrays, and other compound types.
490   Value *Ret = PoisonValue::get(To);
491   if (auto *AT = dyn_cast<ArrayType>(From)) {
492     Type *FromPart = AT->getArrayElementType();
493     Type *ToPart = cast<ArrayType>(To)->getElementType();
494     for (uint64_t I = 0, E = AT->getArrayNumElements(); I < E; ++I) {
495       Value *Field = IRB.CreateExtractValue(V, I);
496       Value *NewField =
497           intsToFatPtrs(Field, FromPart, ToPart, Name + "." + Twine(I));
498       Ret = IRB.CreateInsertValue(Ret, NewField, I);
499     }
500   } else {
501     for (auto [Idx, FromPart, ToPart] :
502          enumerate(From->subtypes(), To->subtypes())) {
503       Value *Field = IRB.CreateExtractValue(V, Idx);
504       Value *NewField =
505           intsToFatPtrs(Field, FromPart, ToPart, Name + "." + Twine(Idx));
506       Ret = IRB.CreateInsertValue(Ret, NewField, Idx);
507     }
508   }
509   return Ret;
510 }
511 
processFunction(Function & F)512 bool StoreFatPtrsAsIntsVisitor::processFunction(Function &F) {
513   bool Changed = false;
514   // The visitors will mutate GEPs and allocas, but will push loads and stores
515   // to the worklist to avoid invalidation.
516   for (Instruction &I : make_early_inc_range(instructions(F))) {
517     Changed |= visit(I);
518   }
519   ConvertedForStore.clear();
520   return Changed;
521 }
522 
visitAllocaInst(AllocaInst & I)523 bool StoreFatPtrsAsIntsVisitor::visitAllocaInst(AllocaInst &I) {
524   Type *Ty = I.getAllocatedType();
525   Type *NewTy = TypeMap->remapType(Ty);
526   if (Ty == NewTy)
527     return false;
528   I.setAllocatedType(NewTy);
529   return true;
530 }
531 
visitGetElementPtrInst(GetElementPtrInst & I)532 bool StoreFatPtrsAsIntsVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
533   Type *Ty = I.getSourceElementType();
534   Type *NewTy = TypeMap->remapType(Ty);
535   if (Ty == NewTy)
536     return false;
537   // We'll be rewriting the type `ptr addrspace(7)` out of existence soon, so
538   // make sure GEPs don't have different semantics with the new type.
539   I.setSourceElementType(NewTy);
540   I.setResultElementType(TypeMap->remapType(I.getResultElementType()));
541   return true;
542 }
543 
visitLoadInst(LoadInst & LI)544 bool StoreFatPtrsAsIntsVisitor::visitLoadInst(LoadInst &LI) {
545   Type *Ty = LI.getType();
546   Type *IntTy = TypeMap->remapType(Ty);
547   if (Ty == IntTy)
548     return false;
549 
550   IRB.SetInsertPoint(&LI);
551   auto *NLI = cast<LoadInst>(LI.clone());
552   NLI->mutateType(IntTy);
553   NLI = IRB.Insert(NLI);
554   copyMetadataForLoad(*NLI, LI);
555   NLI->takeName(&LI);
556 
557   Value *CastBack = intsToFatPtrs(NLI, IntTy, Ty, NLI->getName());
558   LI.replaceAllUsesWith(CastBack);
559   LI.eraseFromParent();
560   return true;
561 }
562 
visitStoreInst(StoreInst & SI)563 bool StoreFatPtrsAsIntsVisitor::visitStoreInst(StoreInst &SI) {
564   Value *V = SI.getValueOperand();
565   Type *Ty = V->getType();
566   Type *IntTy = TypeMap->remapType(Ty);
567   if (Ty == IntTy)
568     return false;
569 
570   IRB.SetInsertPoint(&SI);
571   Value *IntV = fatPtrsToInts(V, Ty, IntTy, V->getName());
572   for (auto *Dbg : at::getAssignmentMarkers(&SI))
573     Dbg->setValue(IntV);
574 
575   SI.setOperand(0, IntV);
576   return true;
577 }
578 
579 /// Return the ptr addrspace(8) and i32 (resource and offset parts) in a lowered
580 /// buffer fat pointer constant.
581 static std::pair<Constant *, Constant *>
splitLoweredFatBufferConst(Constant * C)582 splitLoweredFatBufferConst(Constant *C) {
583   assert(isSplitFatPtr(C->getType()) && "Not a split fat buffer pointer");
584   return std::make_pair(C->getAggregateElement(0u), C->getAggregateElement(1u));
585 }
586 
587 namespace {
588 /// Handle the remapping of ptr addrspace(7) constants.
589 class FatPtrConstMaterializer final : public ValueMaterializer {
590   BufferFatPtrToStructTypeMap *TypeMap;
591   // An internal mapper that is used to recurse into the arguments of constants.
592   // While the documentation for `ValueMapper` specifies not to use it
593   // recursively, examination of the logic in mapValue() shows that it can
594   // safely be used recursively when handling constants, like it does in its own
595   // logic.
596   ValueMapper InternalMapper;
597 
598   Constant *materializeBufferFatPtrConst(Constant *C);
599 
600 public:
601   // UnderlyingMap is the value map this materializer will be filling.
FatPtrConstMaterializer(BufferFatPtrToStructTypeMap * TypeMap,ValueToValueMapTy & UnderlyingMap)602   FatPtrConstMaterializer(BufferFatPtrToStructTypeMap *TypeMap,
603                           ValueToValueMapTy &UnderlyingMap)
604       : TypeMap(TypeMap),
605         InternalMapper(UnderlyingMap, RF_None, TypeMap, this) {}
606   virtual ~FatPtrConstMaterializer() = default;
607 
608   Value *materialize(Value *V) override;
609 };
610 } // namespace
611 
materializeBufferFatPtrConst(Constant * C)612 Constant *FatPtrConstMaterializer::materializeBufferFatPtrConst(Constant *C) {
613   Type *SrcTy = C->getType();
614   auto *NewTy = dyn_cast<StructType>(TypeMap->remapType(SrcTy));
615   if (C->isNullValue())
616     return ConstantAggregateZero::getNullValue(NewTy);
617   if (isa<PoisonValue>(C)) {
618     return ConstantStruct::get(NewTy,
619                                {PoisonValue::get(NewTy->getElementType(0)),
620                                 PoisonValue::get(NewTy->getElementType(1))});
621   }
622   if (isa<UndefValue>(C)) {
623     return ConstantStruct::get(NewTy,
624                                {UndefValue::get(NewTy->getElementType(0)),
625                                 UndefValue::get(NewTy->getElementType(1))});
626   }
627 
628   if (auto *VC = dyn_cast<ConstantVector>(C)) {
629     if (Constant *S = VC->getSplatValue()) {
630       Constant *NewS = InternalMapper.mapConstant(*S);
631       if (!NewS)
632         return nullptr;
633       auto [Rsrc, Off] = splitLoweredFatBufferConst(NewS);
634       auto EC = VC->getType()->getElementCount();
635       return ConstantStruct::get(NewTy, {ConstantVector::getSplat(EC, Rsrc),
636                                          ConstantVector::getSplat(EC, Off)});
637     }
638     SmallVector<Constant *> Rsrcs;
639     SmallVector<Constant *> Offs;
640     for (Value *Op : VC->operand_values()) {
641       auto *NewOp = dyn_cast_or_null<Constant>(InternalMapper.mapValue(*Op));
642       if (!NewOp)
643         return nullptr;
644       auto [Rsrc, Off] = splitLoweredFatBufferConst(NewOp);
645       Rsrcs.push_back(Rsrc);
646       Offs.push_back(Off);
647     }
648     Constant *RsrcVec = ConstantVector::get(Rsrcs);
649     Constant *OffVec = ConstantVector::get(Offs);
650     return ConstantStruct::get(NewTy, {RsrcVec, OffVec});
651   }
652 
653   if (isa<GlobalValue>(C))
654     report_fatal_error("Global values containing ptr addrspace(7) (buffer "
655                        "fat pointer) values are not supported");
656 
657   if (isa<ConstantExpr>(C))
658     report_fatal_error("Constant exprs containing ptr addrspace(7) (buffer "
659                        "fat pointer) values should have been expanded earlier");
660 
661   return nullptr;
662 }
663 
materialize(Value * V)664 Value *FatPtrConstMaterializer::materialize(Value *V) {
665   Constant *C = dyn_cast<Constant>(V);
666   if (!C)
667     return nullptr;
668   // Structs and other types that happen to contain fat pointers get remapped
669   // by the mapValue() logic.
670   if (!isBufferFatPtrConst(C))
671     return nullptr;
672   return materializeBufferFatPtrConst(C);
673 }
674 
// Pair of values describing the parts of a split buffer fat pointer; used as
// the return type of the SplitPtrStructs visitor.
// NOTE(review): presumably ordered (resource, offset) - confirm against
// getPtrParts().
675 using PtrParts = std::pair<Value *, Value *>;
676 namespace {
677 // The visitor returns the resource and offset parts for an instruction if they
678 // can be computed, or (nullptr, nullptr) for cases that don't have a meaningful
679 // value mapping.
680 class SplitPtrStructs : public InstVisitor<SplitPtrStructs, PtrParts> {
681   ValueToValueMapTy RsrcParts;
682   ValueToValueMapTy OffParts;
683 
684   // Track instructions that have been rewritten into a user of the component
685   // parts of their ptr addrspace(7) input. Instructions that produced
686   // ptr addrspace(7) parts should **not** be RAUW'd before being added to this
687   // set, as that replacement will be handled in a post-visit step. However,
688   // instructions that yield values that aren't fat pointers (ex. ptrtoint)
689   // should RAUW themselves with new instructions that use the split parts
690   // of their arguments during processing.
691   DenseSet<Instruction *> SplitUsers;
692 
693   // Nodes that need a second look once we've computed the parts for all other
694   // instructions to see if, for example, we really need to phi on the resource
695   // part.
696   SmallVector<Instruction *> Conditionals;
697   // Temporary instructions produced while lowering conditionals that should be
698   // killed.
699   SmallVector<Instruction *> ConditionalTemps;
700 
701   // Subtarget info, needed for determining what cache control bits to set.
702   const TargetMachine *TM;
703   const GCNSubtarget *ST = nullptr;
704 
705   IRBuilder<> IRB;
706 
707   // Copy metadata between instructions if applicable.
708   void copyMetadata(Value *Dest, Value *Src);
709 
710   // Get the resource and offset parts of the value V, inserting appropriate
711   // extractvalue calls if needed.
712   PtrParts getPtrParts(Value *V);
713 
714   // Given an instruction that could produce multiple resource parts (a PHI or
715   // select), collect the set of possible instructions that could have provided
716   // its resource parts  that it could have (the `Roots`) and the set of
717   // conditional instructions visited during the search (`Seen`). If, after
718   // removing the root of the search from `Seen` and `Roots`, `Seen` is a subset
719   // of `Roots` and `Roots - Seen` contains one element, the resource part of
720   // that element can replace the resource part of all other elements in `Seen`.
721   void getPossibleRsrcRoots(Instruction *I, SmallPtrSetImpl<Value *> &Roots,
722                             SmallPtrSetImpl<Value *> &Seen);
723   void processConditionals();
724 
725   // If an instruction hav been split into resource and offset parts,
726   // delete that instruction. If any of its uses have not themselves been split
727   // into parts (for example, an insertvalue), construct the structure
728   // that the type rewrites declared should be produced by the dying instruction
729   // and use that.
730   // Also, kill the temporary extractvalue operations produced by the two-stage
731   // lowering of PHIs and conditionals.
732   void killAndReplaceSplitInstructions(SmallVectorImpl<Instruction *> &Origs);
733 
734   void setAlign(CallInst *Intr, Align A, unsigned RsrcArgIdx);
735   void insertPreMemOpFence(AtomicOrdering Order, SyncScope::ID SSID);
736   void insertPostMemOpFence(AtomicOrdering Order, SyncScope::ID SSID);
737   Value *handleMemoryInst(Instruction *I, Value *Arg, Value *Ptr, Type *Ty,
738                           Align Alignment, AtomicOrdering Order,
739                           bool IsVolatile, SyncScope::ID SSID);
740 
741 public:
SplitPtrStructs(LLVMContext & Ctx,const TargetMachine * TM)742   SplitPtrStructs(LLVMContext &Ctx, const TargetMachine *TM)
743       : TM(TM), IRB(Ctx) {}
744 
745   void processFunction(Function &F);
746 
747   PtrParts visitInstruction(Instruction &I);
748   PtrParts visitLoadInst(LoadInst &LI);
749   PtrParts visitStoreInst(StoreInst &SI);
750   PtrParts visitAtomicRMWInst(AtomicRMWInst &AI);
751   PtrParts visitAtomicCmpXchgInst(AtomicCmpXchgInst &AI);
752   PtrParts visitGetElementPtrInst(GetElementPtrInst &GEP);
753 
754   PtrParts visitPtrToIntInst(PtrToIntInst &PI);
755   PtrParts visitIntToPtrInst(IntToPtrInst &IP);
756   PtrParts visitAddrSpaceCastInst(AddrSpaceCastInst &I);
757   PtrParts visitICmpInst(ICmpInst &Cmp);
758   PtrParts visitFreezeInst(FreezeInst &I);
759 
760   PtrParts visitExtractElementInst(ExtractElementInst &I);
761   PtrParts visitInsertElementInst(InsertElementInst &I);
762   PtrParts visitShuffleVectorInst(ShuffleVectorInst &I);
763 
764   PtrParts visitPHINode(PHINode &PHI);
765   PtrParts visitSelectInst(SelectInst &SI);
766 
767   PtrParts visitIntrinsicInst(IntrinsicInst &II);
768 };
769 } // namespace
770 
copyMetadata(Value * Dest,Value * Src)771 void SplitPtrStructs::copyMetadata(Value *Dest, Value *Src) {
772   auto *DestI = dyn_cast<Instruction>(Dest);
773   auto *SrcI = dyn_cast<Instruction>(Src);
774 
775   if (!DestI || !SrcI)
776     return;
777 
778   DestI->copyMetadata(*SrcI);
779 }
780 
getPtrParts(Value * V)781 PtrParts SplitPtrStructs::getPtrParts(Value *V) {
782   assert(isSplitFatPtr(V->getType()) && "it's not meaningful to get the parts "
783                                         "of something that wasn't rewritten");
784   auto *RsrcEntry = &RsrcParts[V];
785   auto *OffEntry = &OffParts[V];
786   if (*RsrcEntry && *OffEntry)
787     return {*RsrcEntry, *OffEntry};
788 
789   if (auto *C = dyn_cast<Constant>(V)) {
790     auto [Rsrc, Off] = splitLoweredFatBufferConst(C);
791     return {*RsrcEntry = Rsrc, *OffEntry = Off};
792   }
793 
794   IRBuilder<>::InsertPointGuard Guard(IRB);
795   if (auto *I = dyn_cast<Instruction>(V)) {
796     LLVM_DEBUG(dbgs() << "Recursing to split parts of " << *I << "\n");
797     auto [Rsrc, Off] = visit(*I);
798     if (Rsrc && Off)
799       return {*RsrcEntry = Rsrc, *OffEntry = Off};
800     // We'll be creating the new values after the relevant instruction.
801     // This instruction generates a value and so isn't a terminator.
802     IRB.SetInsertPoint(*I->getInsertionPointAfterDef());
803     IRB.SetCurrentDebugLocation(I->getDebugLoc());
804   } else if (auto *A = dyn_cast<Argument>(V)) {
805     IRB.SetInsertPointPastAllocas(A->getParent());
806     IRB.SetCurrentDebugLocation(DebugLoc());
807   }
808   Value *Rsrc = IRB.CreateExtractValue(V, 0, V->getName() + ".rsrc");
809   Value *Off = IRB.CreateExtractValue(V, 1, V->getName() + ".off");
810   return {*RsrcEntry = Rsrc, *OffEntry = Off};
811 }
812 
813 /// Returns the instruction that defines the resource part of the value V.
814 /// Note that this is not getUnderlyingObject(), since that looks through
815 /// operations like ptrmask which might modify the resource part.
816 ///
817 /// We can limit ourselves to just looking through GEPs followed by looking
818 /// through addrspacecasts because only those two operations preserve the
819 /// resource part, and because operations on an `addrspace(8)` (which is the
820 /// legal input to this addrspacecast) would produce a different resource part.
rsrcPartRoot(Value * V)821 static Value *rsrcPartRoot(Value *V) {
822   while (auto *GEP = dyn_cast<GEPOperator>(V))
823     V = GEP->getPointerOperand();
824   while (auto *ASC = dyn_cast<AddrSpaceCastOperator>(V))
825     V = ASC->getPointerOperand();
826   return V;
827 }
828 
getPossibleRsrcRoots(Instruction * I,SmallPtrSetImpl<Value * > & Roots,SmallPtrSetImpl<Value * > & Seen)829 void SplitPtrStructs::getPossibleRsrcRoots(Instruction *I,
830                                            SmallPtrSetImpl<Value *> &Roots,
831                                            SmallPtrSetImpl<Value *> &Seen) {
832   if (auto *PHI = dyn_cast<PHINode>(I)) {
833     if (!Seen.insert(I).second)
834       return;
835     for (Value *In : PHI->incoming_values()) {
836       In = rsrcPartRoot(In);
837       Roots.insert(In);
838       if (isa<PHINode, SelectInst>(In))
839         getPossibleRsrcRoots(cast<Instruction>(In), Roots, Seen);
840     }
841   } else if (auto *SI = dyn_cast<SelectInst>(I)) {
842     if (!Seen.insert(SI).second)
843       return;
844     Value *TrueVal = rsrcPartRoot(SI->getTrueValue());
845     Value *FalseVal = rsrcPartRoot(SI->getFalseValue());
846     Roots.insert(TrueVal);
847     Roots.insert(FalseVal);
848     if (isa<PHINode, SelectInst>(TrueVal))
849       getPossibleRsrcRoots(cast<Instruction>(TrueVal), Roots, Seen);
850     if (isa<PHINode, SelectInst>(FalseVal))
851       getPossibleRsrcRoots(cast<Instruction>(FalseVal), Roots, Seen);
852   } else {
853     llvm_unreachable("getPossibleRsrcParts() only works on phi and select");
854   }
855 }
856 
processConditionals()857 void SplitPtrStructs::processConditionals() {
858   SmallDenseMap<Instruction *, Value *> FoundRsrcs;
859   SmallPtrSet<Value *, 4> Roots;
860   SmallPtrSet<Value *, 4> Seen;
861   for (Instruction *I : Conditionals) {
862     // These have to exist by now because we've visited these nodes.
863     Value *Rsrc = RsrcParts[I];
864     Value *Off = OffParts[I];
865     assert(Rsrc && Off && "must have visited conditionals by now");
866 
867     std::optional<Value *> MaybeRsrc;
868     auto MaybeFoundRsrc = FoundRsrcs.find(I);
869     if (MaybeFoundRsrc != FoundRsrcs.end()) {
870       MaybeRsrc = MaybeFoundRsrc->second;
871     } else {
872       IRBuilder<>::InsertPointGuard Guard(IRB);
873       Roots.clear();
874       Seen.clear();
875       getPossibleRsrcRoots(I, Roots, Seen);
876       LLVM_DEBUG(dbgs() << "Processing conditional: " << *I << "\n");
877 #ifndef NDEBUG
878       for (Value *V : Roots)
879         LLVM_DEBUG(dbgs() << "Root: " << *V << "\n");
880       for (Value *V : Seen)
881         LLVM_DEBUG(dbgs() << "Seen: " << *V << "\n");
882 #endif
883       // If we are our own possible root, then we shouldn't block our
884       // replacement with a valid incoming value.
885       Roots.erase(I);
886       // We don't want to block the optimization for conditionals that don't
887       // refer to themselves but did see themselves during the traversal.
888       Seen.erase(I);
889 
890       if (set_is_subset(Seen, Roots)) {
891         auto Diff = set_difference(Roots, Seen);
892         if (Diff.size() == 1) {
893           Value *RootVal = *Diff.begin();
894           // Handle the case where previous loops already looked through
895           // an addrspacecast.
896           if (isSplitFatPtr(RootVal->getType()))
897             MaybeRsrc = std::get<0>(getPtrParts(RootVal));
898           else
899             MaybeRsrc = RootVal;
900         }
901       }
902     }
903 
904     if (auto *PHI = dyn_cast<PHINode>(I)) {
905       Value *NewRsrc;
906       StructType *PHITy = cast<StructType>(PHI->getType());
907       IRB.SetInsertPoint(*PHI->getInsertionPointAfterDef());
908       IRB.SetCurrentDebugLocation(PHI->getDebugLoc());
909       if (MaybeRsrc) {
910         NewRsrc = *MaybeRsrc;
911       } else {
912         Type *RsrcTy = PHITy->getElementType(0);
913         auto *RsrcPHI = IRB.CreatePHI(RsrcTy, PHI->getNumIncomingValues());
914         RsrcPHI->takeName(Rsrc);
915         for (auto [V, BB] : llvm::zip(PHI->incoming_values(), PHI->blocks())) {
916           Value *VRsrc = std::get<0>(getPtrParts(V));
917           RsrcPHI->addIncoming(VRsrc, BB);
918         }
919         copyMetadata(RsrcPHI, PHI);
920         NewRsrc = RsrcPHI;
921       }
922 
923       Type *OffTy = PHITy->getElementType(1);
924       auto *NewOff = IRB.CreatePHI(OffTy, PHI->getNumIncomingValues());
925       NewOff->takeName(Off);
926       for (auto [V, BB] : llvm::zip(PHI->incoming_values(), PHI->blocks())) {
927         assert(OffParts.count(V) && "An offset part had to be created by now");
928         Value *VOff = std::get<1>(getPtrParts(V));
929         NewOff->addIncoming(VOff, BB);
930       }
931       copyMetadata(NewOff, PHI);
932 
933       // Note: We don't eraseFromParent() the temporaries because we don't want
934       // to put the corrections maps in an inconstent state. That'll be handed
935       // during the rest of the killing. Also, `ValueToValueMapTy` guarantees
936       // that references in that map will be updated as well.
937       ConditionalTemps.push_back(cast<Instruction>(Rsrc));
938       ConditionalTemps.push_back(cast<Instruction>(Off));
939       Rsrc->replaceAllUsesWith(NewRsrc);
940       Off->replaceAllUsesWith(NewOff);
941 
942       // Save on recomputing the cycle traversals in known-root cases.
943       if (MaybeRsrc)
944         for (Value *V : Seen)
945           FoundRsrcs[cast<Instruction>(V)] = NewRsrc;
946     } else if (isa<SelectInst>(I)) {
947       if (MaybeRsrc) {
948         ConditionalTemps.push_back(cast<Instruction>(Rsrc));
949         Rsrc->replaceAllUsesWith(*MaybeRsrc);
950         for (Value *V : Seen)
951           FoundRsrcs[cast<Instruction>(V)] = *MaybeRsrc;
952       }
953     } else {
954       llvm_unreachable("Only PHIs and selects go in the conditionals list");
955     }
956   }
957 }
958 
// Final cleanup once every instruction has been visited.
//
// Erases the temporaries queued while lowering conditionals. Then, for every
// original instruction recorded in `SplitUsers`, it splits attached debug
// values into a resource fragment and an offset fragment, poisons uses by
// other split instructions, and deletes the instruction. If some uses remain,
// the instruction is first replaced by a struct rebuilt from its parts.
void SplitPtrStructs::killAndReplaceSplitInstructions(
    SmallVectorImpl<Instruction *> &Origs) {
  for (Instruction *I : ConditionalTemps)
    I->eraseFromParent();

  for (Instruction *I : Origs) {
    // Only instructions that were actually rewritten are candidates.
    if (!SplitUsers.contains(I))
      continue;

    SmallVector<DbgValueInst *> Dbgs;
    findDbgValues(Dbgs, I);
    for (auto *Dbg : Dbgs) {
      IRB.SetInsertPoint(Dbg);
      auto &DL = I->getDataLayout();
      assert(isSplitFatPtr(I->getType()) &&
             "We should've RAUW'd away loads, stores, etc. at this point");
      // The existing debug value will describe the resource part; this clone
      // will describe the offset part.
      auto *OffDbg = cast<DbgValueInst>(Dbg->clone());
      copyMetadata(OffDbg, Dbg);
      auto [Rsrc, Off] = getPtrParts(I);

      int64_t RsrcSz = DL.getTypeSizeInBits(Rsrc->getType());
      int64_t OffSz = DL.getTypeSizeInBits(Off->getType());

      // Fragment layout: the resource occupies bits [0, RsrcSz) and the
      // offset the OffSz bits immediately after it.
      std::optional<DIExpression *> RsrcExpr =
          DIExpression::createFragmentExpression(Dbg->getExpression(), 0,
                                                 RsrcSz);
      std::optional<DIExpression *> OffExpr =
          DIExpression::createFragmentExpression(Dbg->getExpression(), RsrcSz,
                                                 OffSz);
      if (OffExpr) {
        OffDbg->setExpression(*OffExpr);
        OffDbg->replaceVariableLocationOp(I, Off);
        IRB.Insert(OffDbg);
      } else {
        // No fragment could be formed; the clone was never inserted, so it
        // is simply destroyed.
        OffDbg->deleteValue();
      }
      if (RsrcExpr) {
        Dbg->setExpression(*RsrcExpr);
        Dbg->replaceVariableLocationOp(I, Rsrc);
      } else {
        // No fragment could be formed: point the debug value at undef rather
        // than at the instruction that is about to be removed.
        Dbg->replaceVariableLocationOp(I, UndefValue::get(I->getType()));
      }
    }

    // Users that were themselves split no longer consume I; give them poison
    // so they do not keep I alive.
    Value *Poison = PoisonValue::get(I->getType());
    I->replaceUsesWithIf(Poison, [&](const Use &U) -> bool {
      if (const auto *UI = dyn_cast<Instruction>(U.getUser()))
        return SplitUsers.contains(UI);
      return false;
    });

    if (I->use_empty()) {
      I->eraseFromParent();
      continue;
    }
    // Some users still want the whole struct: rebuild it from the parts
    // right after the definition and hand it to them.
    IRB.SetInsertPoint(*I->getInsertionPointAfterDef());
    IRB.SetCurrentDebugLocation(I->getDebugLoc());
    auto [Rsrc, Off] = getPtrParts(I);
    Value *Struct = PoisonValue::get(I->getType());
    Struct = IRB.CreateInsertValue(Struct, Rsrc, 0);
    Struct = IRB.CreateInsertValue(Struct, Off, 1);
    copyMetadata(Struct, I);
    Struct->takeName(I);
    I->replaceAllUsesWith(Struct);
    I->eraseFromParent();
  }
}
1026 
setAlign(CallInst * Intr,Align A,unsigned RsrcArgIdx)1027 void SplitPtrStructs::setAlign(CallInst *Intr, Align A, unsigned RsrcArgIdx) {
1028   LLVMContext &Ctx = Intr->getContext();
1029   Intr->addParamAttr(RsrcArgIdx, Attribute::getWithAlignment(Ctx, A));
1030 }
1031 
insertPreMemOpFence(AtomicOrdering Order,SyncScope::ID SSID)1032 void SplitPtrStructs::insertPreMemOpFence(AtomicOrdering Order,
1033                                           SyncScope::ID SSID) {
1034   switch (Order) {
1035   case AtomicOrdering::Release:
1036   case AtomicOrdering::AcquireRelease:
1037   case AtomicOrdering::SequentiallyConsistent:
1038     IRB.CreateFence(AtomicOrdering::Release, SSID);
1039     break;
1040   default:
1041     break;
1042   }
1043 }
1044 
insertPostMemOpFence(AtomicOrdering Order,SyncScope::ID SSID)1045 void SplitPtrStructs::insertPostMemOpFence(AtomicOrdering Order,
1046                                            SyncScope::ID SSID) {
1047   switch (Order) {
1048   case AtomicOrdering::Acquire:
1049   case AtomicOrdering::AcquireRelease:
1050   case AtomicOrdering::SequentiallyConsistent:
1051     IRB.CreateFence(AtomicOrdering::Acquire, SSID);
1052     break;
1053   default:
1054     break;
1055   }
1056 }
1057 
handleMemoryInst(Instruction * I,Value * Arg,Value * Ptr,Type * Ty,Align Alignment,AtomicOrdering Order,bool IsVolatile,SyncScope::ID SSID)1058 Value *SplitPtrStructs::handleMemoryInst(Instruction *I, Value *Arg, Value *Ptr,
1059                                          Type *Ty, Align Alignment,
1060                                          AtomicOrdering Order, bool IsVolatile,
1061                                          SyncScope::ID SSID) {
1062   IRB.SetInsertPoint(I);
1063 
1064   auto [Rsrc, Off] = getPtrParts(Ptr);
1065   SmallVector<Value *, 5> Args;
1066   if (Arg)
1067     Args.push_back(Arg);
1068   Args.push_back(Rsrc);
1069   Args.push_back(Off);
1070   insertPreMemOpFence(Order, SSID);
1071   // soffset is always 0 for these cases, where we always want any offset to be
1072   // part of bounds checking and we don't know which parts of the GEPs is
1073   // uniform.
1074   Args.push_back(IRB.getInt32(0));
1075 
1076   uint32_t Aux = 0;
1077   bool IsInvariant =
1078       (isa<LoadInst>(I) && I->getMetadata(LLVMContext::MD_invariant_load));
1079   bool IsNonTemporal = I->getMetadata(LLVMContext::MD_nontemporal);
1080   // Atomic loads and stores need glc, atomic read-modify-write doesn't.
1081   bool IsOneWayAtomic =
1082       !isa<AtomicRMWInst>(I) && Order != AtomicOrdering::NotAtomic;
1083   if (IsOneWayAtomic)
1084     Aux |= AMDGPU::CPol::GLC;
1085   if (IsNonTemporal && !IsInvariant)
1086     Aux |= AMDGPU::CPol::SLC;
1087   if (isa<LoadInst>(I) && ST->getGeneration() == AMDGPUSubtarget::GFX10)
1088     Aux |= (Aux & AMDGPU::CPol::GLC ? AMDGPU::CPol::DLC : 0);
1089   if (IsVolatile)
1090     Aux |= AMDGPU::CPol::VOLATILE;
1091   Args.push_back(IRB.getInt32(Aux));
1092 
1093   Intrinsic::ID IID = Intrinsic::not_intrinsic;
1094   if (isa<LoadInst>(I))
1095     IID = Order == AtomicOrdering::NotAtomic
1096               ? Intrinsic::amdgcn_raw_ptr_buffer_load
1097               : Intrinsic::amdgcn_raw_ptr_atomic_buffer_load;
1098   else if (isa<StoreInst>(I))
1099     IID = Intrinsic::amdgcn_raw_ptr_buffer_store;
1100   else if (auto *RMW = dyn_cast<AtomicRMWInst>(I)) {
1101     switch (RMW->getOperation()) {
1102     case AtomicRMWInst::Xchg:
1103       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_swap;
1104       break;
1105     case AtomicRMWInst::Add:
1106       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_add;
1107       break;
1108     case AtomicRMWInst::Sub:
1109       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_sub;
1110       break;
1111     case AtomicRMWInst::And:
1112       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_and;
1113       break;
1114     case AtomicRMWInst::Or:
1115       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_or;
1116       break;
1117     case AtomicRMWInst::Xor:
1118       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_xor;
1119       break;
1120     case AtomicRMWInst::Max:
1121       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_smax;
1122       break;
1123     case AtomicRMWInst::Min:
1124       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_smin;
1125       break;
1126     case AtomicRMWInst::UMax:
1127       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_umax;
1128       break;
1129     case AtomicRMWInst::UMin:
1130       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_umin;
1131       break;
1132     case AtomicRMWInst::FAdd:
1133       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fadd;
1134       break;
1135     case AtomicRMWInst::FMax:
1136       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fmax;
1137       break;
1138     case AtomicRMWInst::FMin:
1139       IID = Intrinsic::amdgcn_raw_ptr_buffer_atomic_fmin;
1140       break;
1141     case AtomicRMWInst::FSub: {
1142       report_fatal_error("atomic floating point subtraction not supported for "
1143                          "buffer resources and should've been expanded away");
1144       break;
1145     }
1146     case AtomicRMWInst::Nand:
1147       report_fatal_error("atomic nand not supported for buffer resources and "
1148                          "should've been expanded away");
1149       break;
1150     case AtomicRMWInst::UIncWrap:
1151     case AtomicRMWInst::UDecWrap:
1152       report_fatal_error("wrapping increment/decrement not supported for "
1153                          "buffer resources and should've ben expanded away");
1154       break;
1155     case AtomicRMWInst::BAD_BINOP:
1156       llvm_unreachable("Not sure how we got a bad binop");
1157     }
1158   }
1159 
1160   auto *Call = IRB.CreateIntrinsic(IID, Ty, Args);
1161   copyMetadata(Call, I);
1162   setAlign(Call, Alignment, Arg ? 1 : 0);
1163   Call->takeName(I);
1164 
1165   insertPostMemOpFence(Order, SSID);
1166   // The "no moving p7 directly" rewrites ensure that this load or store won't
1167   // itself need to be split into parts.
1168   SplitUsers.insert(I);
1169   I->replaceAllUsesWith(Call);
1170   return Call;
1171 }
1172 
visitInstruction(Instruction & I)1173 PtrParts SplitPtrStructs::visitInstruction(Instruction &I) {
1174   return {nullptr, nullptr};
1175 }
1176 
visitLoadInst(LoadInst & LI)1177 PtrParts SplitPtrStructs::visitLoadInst(LoadInst &LI) {
1178   if (!isSplitFatPtr(LI.getPointerOperandType()))
1179     return {nullptr, nullptr};
1180   handleMemoryInst(&LI, nullptr, LI.getPointerOperand(), LI.getType(),
1181                    LI.getAlign(), LI.getOrdering(), LI.isVolatile(),
1182                    LI.getSyncScopeID());
1183   return {nullptr, nullptr};
1184 }
1185 
visitStoreInst(StoreInst & SI)1186 PtrParts SplitPtrStructs::visitStoreInst(StoreInst &SI) {
1187   if (!isSplitFatPtr(SI.getPointerOperandType()))
1188     return {nullptr, nullptr};
1189   Value *Arg = SI.getValueOperand();
1190   handleMemoryInst(&SI, Arg, SI.getPointerOperand(), Arg->getType(),
1191                    SI.getAlign(), SI.getOrdering(), SI.isVolatile(),
1192                    SI.getSyncScopeID());
1193   return {nullptr, nullptr};
1194 }
1195 
visitAtomicRMWInst(AtomicRMWInst & AI)1196 PtrParts SplitPtrStructs::visitAtomicRMWInst(AtomicRMWInst &AI) {
1197   if (!isSplitFatPtr(AI.getPointerOperand()->getType()))
1198     return {nullptr, nullptr};
1199   Value *Arg = AI.getValOperand();
1200   handleMemoryInst(&AI, Arg, AI.getPointerOperand(), Arg->getType(),
1201                    AI.getAlign(), AI.getOrdering(), AI.isVolatile(),
1202                    AI.getSyncScopeID());
1203   return {nullptr, nullptr};
1204 }
1205 
1206 // Unlike load, store, and RMW, cmpxchg needs special handling to account
1207 // for the boolean argument.
visitAtomicCmpXchgInst(AtomicCmpXchgInst & AI)1208 PtrParts SplitPtrStructs::visitAtomicCmpXchgInst(AtomicCmpXchgInst &AI) {
1209   Value *Ptr = AI.getPointerOperand();
1210   if (!isSplitFatPtr(Ptr->getType()))
1211     return {nullptr, nullptr};
1212   IRB.SetInsertPoint(&AI);
1213 
1214   Type *Ty = AI.getNewValOperand()->getType();
1215   AtomicOrdering Order = AI.getMergedOrdering();
1216   SyncScope::ID SSID = AI.getSyncScopeID();
1217   bool IsNonTemporal = AI.getMetadata(LLVMContext::MD_nontemporal);
1218 
1219   auto [Rsrc, Off] = getPtrParts(Ptr);
1220   insertPreMemOpFence(Order, SSID);
1221 
1222   uint32_t Aux = 0;
1223   if (IsNonTemporal)
1224     Aux |= AMDGPU::CPol::SLC;
1225   if (AI.isVolatile())
1226     Aux |= AMDGPU::CPol::VOLATILE;
1227   auto *Call =
1228       IRB.CreateIntrinsic(Intrinsic::amdgcn_raw_ptr_buffer_atomic_cmpswap, Ty,
1229                           {AI.getNewValOperand(), AI.getCompareOperand(), Rsrc,
1230                            Off, IRB.getInt32(0), IRB.getInt32(Aux)});
1231   copyMetadata(Call, &AI);
1232   setAlign(Call, AI.getAlign(), 2);
1233   Call->takeName(&AI);
1234   insertPostMemOpFence(Order, SSID);
1235 
1236   Value *Res = PoisonValue::get(AI.getType());
1237   Res = IRB.CreateInsertValue(Res, Call, 0);
1238   if (!AI.isWeak()) {
1239     Value *Succeeded = IRB.CreateICmpEQ(Call, AI.getCompareOperand());
1240     Res = IRB.CreateInsertValue(Res, Succeeded, 1);
1241   }
1242   SplitUsers.insert(&AI);
1243   AI.replaceAllUsesWith(Res);
1244   return {nullptr, nullptr};
1245 }
1246 
visitGetElementPtrInst(GetElementPtrInst & GEP)1247 PtrParts SplitPtrStructs::visitGetElementPtrInst(GetElementPtrInst &GEP) {
1248   using namespace llvm::PatternMatch;
1249   Value *Ptr = GEP.getPointerOperand();
1250   if (!isSplitFatPtr(Ptr->getType()))
1251     return {nullptr, nullptr};
1252   IRB.SetInsertPoint(&GEP);
1253 
1254   auto [Rsrc, Off] = getPtrParts(Ptr);
1255   const DataLayout &DL = GEP.getDataLayout();
1256   bool InBounds = GEP.isInBounds();
1257 
1258   // In order to call emitGEPOffset() and thus not have to reimplement it,
1259   // we need the GEP result to have ptr addrspace(7) type.
1260   Type *FatPtrTy = IRB.getPtrTy(AMDGPUAS::BUFFER_FAT_POINTER);
1261   if (auto *VT = dyn_cast<VectorType>(Off->getType()))
1262     FatPtrTy = VectorType::get(FatPtrTy, VT->getElementCount());
1263   GEP.mutateType(FatPtrTy);
1264   Value *OffAccum = emitGEPOffset(&IRB, DL, &GEP);
1265   GEP.mutateType(Ptr->getType());
1266   if (match(OffAccum, m_Zero())) { // Constant-zero offset
1267     SplitUsers.insert(&GEP);
1268     return {Rsrc, Off};
1269   }
1270 
1271   bool HasNonNegativeOff = false;
1272   if (auto *CI = dyn_cast<ConstantInt>(OffAccum)) {
1273     HasNonNegativeOff = !CI->isNegative();
1274   }
1275   Value *NewOff;
1276   if (match(Off, m_Zero())) {
1277     NewOff = OffAccum;
1278   } else {
1279     NewOff = IRB.CreateAdd(Off, OffAccum, "",
1280                            /*hasNUW=*/InBounds && HasNonNegativeOff,
1281                            /*hasNSW=*/false);
1282   }
1283   copyMetadata(NewOff, &GEP);
1284   NewOff->takeName(&GEP);
1285   SplitUsers.insert(&GEP);
1286   return {Rsrc, NewOff};
1287 }
1288 
visitPtrToIntInst(PtrToIntInst & PI)1289 PtrParts SplitPtrStructs::visitPtrToIntInst(PtrToIntInst &PI) {
1290   Value *Ptr = PI.getPointerOperand();
1291   if (!isSplitFatPtr(Ptr->getType()))
1292     return {nullptr, nullptr};
1293   IRB.SetInsertPoint(&PI);
1294 
1295   Type *ResTy = PI.getType();
1296   unsigned Width = ResTy->getScalarSizeInBits();
1297 
1298   auto [Rsrc, Off] = getPtrParts(Ptr);
1299   const DataLayout &DL = PI.getDataLayout();
1300   unsigned FatPtrWidth = DL.getPointerSizeInBits(AMDGPUAS::BUFFER_FAT_POINTER);
1301 
1302   Value *Res;
1303   if (Width <= BufferOffsetWidth) {
1304     Res = IRB.CreateIntCast(Off, ResTy, /*isSigned=*/false,
1305                             PI.getName() + ".off");
1306   } else {
1307     Value *RsrcInt = IRB.CreatePtrToInt(Rsrc, ResTy, PI.getName() + ".rsrc");
1308     Value *Shl = IRB.CreateShl(
1309         RsrcInt,
1310         ConstantExpr::getIntegerValue(ResTy, APInt(Width, BufferOffsetWidth)),
1311         "", Width >= FatPtrWidth, Width > FatPtrWidth);
1312     Value *OffCast = IRB.CreateIntCast(Off, ResTy, /*isSigned=*/false,
1313                                        PI.getName() + ".off");
1314     Res = IRB.CreateOr(Shl, OffCast);
1315   }
1316 
1317   copyMetadata(Res, &PI);
1318   Res->takeName(&PI);
1319   SplitUsers.insert(&PI);
1320   PI.replaceAllUsesWith(Res);
1321   return {nullptr, nullptr};
1322 }
1323 
visitIntToPtrInst(IntToPtrInst & IP)1324 PtrParts SplitPtrStructs::visitIntToPtrInst(IntToPtrInst &IP) {
1325   if (!isSplitFatPtr(IP.getType()))
1326     return {nullptr, nullptr};
1327   IRB.SetInsertPoint(&IP);
1328   const DataLayout &DL = IP.getDataLayout();
1329   unsigned RsrcPtrWidth = DL.getPointerSizeInBits(AMDGPUAS::BUFFER_RESOURCE);
1330   Value *Int = IP.getOperand(0);
1331   Type *IntTy = Int->getType();
1332   Type *RsrcIntTy = IntTy->getWithNewBitWidth(RsrcPtrWidth);
1333   unsigned Width = IntTy->getScalarSizeInBits();
1334 
1335   auto *RetTy = cast<StructType>(IP.getType());
1336   Type *RsrcTy = RetTy->getElementType(0);
1337   Type *OffTy = RetTy->getElementType(1);
1338   Value *RsrcPart = IRB.CreateLShr(
1339       Int,
1340       ConstantExpr::getIntegerValue(IntTy, APInt(Width, BufferOffsetWidth)));
1341   Value *RsrcInt = IRB.CreateIntCast(RsrcPart, RsrcIntTy, /*isSigned=*/false);
1342   Value *Rsrc = IRB.CreateIntToPtr(RsrcInt, RsrcTy, IP.getName() + ".rsrc");
1343   Value *Off =
1344       IRB.CreateIntCast(Int, OffTy, /*IsSigned=*/false, IP.getName() + ".off");
1345 
1346   copyMetadata(Rsrc, &IP);
1347   SplitUsers.insert(&IP);
1348   return {Rsrc, Off};
1349 }
1350 
visitAddrSpaceCastInst(AddrSpaceCastInst & I)1351 PtrParts SplitPtrStructs::visitAddrSpaceCastInst(AddrSpaceCastInst &I) {
1352   if (!isSplitFatPtr(I.getType()))
1353     return {nullptr, nullptr};
1354   IRB.SetInsertPoint(&I);
1355   Value *In = I.getPointerOperand();
1356   // No-op casts preserve parts
1357   if (In->getType() == I.getType()) {
1358     auto [Rsrc, Off] = getPtrParts(In);
1359     SplitUsers.insert(&I);
1360     return {Rsrc, Off};
1361   }
1362   if (I.getSrcAddressSpace() != AMDGPUAS::BUFFER_RESOURCE)
1363     report_fatal_error("Only buffer resources (addrspace 8) can be cast to "
1364                        "buffer fat pointers (addrspace 7)");
1365   Type *OffTy = cast<StructType>(I.getType())->getElementType(1);
1366   Value *ZeroOff = Constant::getNullValue(OffTy);
1367   SplitUsers.insert(&I);
1368   return {In, ZeroOff};
1369 }
1370 
visitICmpInst(ICmpInst & Cmp)1371 PtrParts SplitPtrStructs::visitICmpInst(ICmpInst &Cmp) {
1372   Value *Lhs = Cmp.getOperand(0);
1373   if (!isSplitFatPtr(Lhs->getType()))
1374     return {nullptr, nullptr};
1375   Value *Rhs = Cmp.getOperand(1);
1376   IRB.SetInsertPoint(&Cmp);
1377   ICmpInst::Predicate Pred = Cmp.getPredicate();
1378 
1379   assert((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) &&
1380          "Pointer comparison is only equal or unequal");
1381   auto [LhsRsrc, LhsOff] = getPtrParts(Lhs);
1382   auto [RhsRsrc, RhsOff] = getPtrParts(Rhs);
1383   Value *RsrcCmp =
1384       IRB.CreateICmp(Pred, LhsRsrc, RhsRsrc, Cmp.getName() + ".rsrc");
1385   copyMetadata(RsrcCmp, &Cmp);
1386   Value *OffCmp = IRB.CreateICmp(Pred, LhsOff, RhsOff, Cmp.getName() + ".off");
1387   copyMetadata(OffCmp, &Cmp);
1388 
1389   Value *Res = nullptr;
1390   if (Pred == ICmpInst::ICMP_EQ)
1391     Res = IRB.CreateAnd(RsrcCmp, OffCmp);
1392   else if (Pred == ICmpInst::ICMP_NE)
1393     Res = IRB.CreateOr(RsrcCmp, OffCmp);
1394   copyMetadata(Res, &Cmp);
1395   Res->takeName(&Cmp);
1396   SplitUsers.insert(&Cmp);
1397   Cmp.replaceAllUsesWith(Res);
1398   return {nullptr, nullptr};
1399 }
1400 
visitFreezeInst(FreezeInst & I)1401 PtrParts SplitPtrStructs::visitFreezeInst(FreezeInst &I) {
1402   if (!isSplitFatPtr(I.getType()))
1403     return {nullptr, nullptr};
1404   IRB.SetInsertPoint(&I);
1405   auto [Rsrc, Off] = getPtrParts(I.getOperand(0));
1406 
1407   Value *RsrcRes = IRB.CreateFreeze(Rsrc, I.getName() + ".rsrc");
1408   copyMetadata(RsrcRes, &I);
1409   Value *OffRes = IRB.CreateFreeze(Off, I.getName() + ".off");
1410   copyMetadata(OffRes, &I);
1411   SplitUsers.insert(&I);
1412   return {RsrcRes, OffRes};
1413 }
1414 
visitExtractElementInst(ExtractElementInst & I)1415 PtrParts SplitPtrStructs::visitExtractElementInst(ExtractElementInst &I) {
1416   if (!isSplitFatPtr(I.getType()))
1417     return {nullptr, nullptr};
1418   IRB.SetInsertPoint(&I);
1419   Value *Vec = I.getVectorOperand();
1420   Value *Idx = I.getIndexOperand();
1421   auto [Rsrc, Off] = getPtrParts(Vec);
1422 
1423   Value *RsrcRes = IRB.CreateExtractElement(Rsrc, Idx, I.getName() + ".rsrc");
1424   copyMetadata(RsrcRes, &I);
1425   Value *OffRes = IRB.CreateExtractElement(Off, Idx, I.getName() + ".off");
1426   copyMetadata(OffRes, &I);
1427   SplitUsers.insert(&I);
1428   return {RsrcRes, OffRes};
1429 }
1430 
visitInsertElementInst(InsertElementInst & I)1431 PtrParts SplitPtrStructs::visitInsertElementInst(InsertElementInst &I) {
1432   // The mutated instructions temporarily don't return vectors, and so
1433   // we need the generic getType() here to avoid crashes.
1434   if (!isSplitFatPtr(cast<Instruction>(I).getType()))
1435     return {nullptr, nullptr};
1436   IRB.SetInsertPoint(&I);
1437   Value *Vec = I.getOperand(0);
1438   Value *Elem = I.getOperand(1);
1439   Value *Idx = I.getOperand(2);
1440   auto [VecRsrc, VecOff] = getPtrParts(Vec);
1441   auto [ElemRsrc, ElemOff] = getPtrParts(Elem);
1442 
1443   Value *RsrcRes =
1444       IRB.CreateInsertElement(VecRsrc, ElemRsrc, Idx, I.getName() + ".rsrc");
1445   copyMetadata(RsrcRes, &I);
1446   Value *OffRes =
1447       IRB.CreateInsertElement(VecOff, ElemOff, Idx, I.getName() + ".off");
1448   copyMetadata(OffRes, &I);
1449   SplitUsers.insert(&I);
1450   return {RsrcRes, OffRes};
1451 }
1452 
visitShuffleVectorInst(ShuffleVectorInst & I)1453 PtrParts SplitPtrStructs::visitShuffleVectorInst(ShuffleVectorInst &I) {
1454   // Cast is needed for the same reason as insertelement's.
1455   if (!isSplitFatPtr(cast<Instruction>(I).getType()))
1456     return {nullptr, nullptr};
1457   IRB.SetInsertPoint(&I);
1458 
1459   Value *V1 = I.getOperand(0);
1460   Value *V2 = I.getOperand(1);
1461   ArrayRef<int> Mask = I.getShuffleMask();
1462   auto [V1Rsrc, V1Off] = getPtrParts(V1);
1463   auto [V2Rsrc, V2Off] = getPtrParts(V2);
1464 
1465   Value *RsrcRes =
1466       IRB.CreateShuffleVector(V1Rsrc, V2Rsrc, Mask, I.getName() + ".rsrc");
1467   copyMetadata(RsrcRes, &I);
1468   Value *OffRes =
1469       IRB.CreateShuffleVector(V1Off, V2Off, Mask, I.getName() + ".off");
1470   copyMetadata(OffRes, &I);
1471   SplitUsers.insert(&I);
1472   return {RsrcRes, OffRes};
1473 }
1474 
visitPHINode(PHINode & PHI)1475 PtrParts SplitPtrStructs::visitPHINode(PHINode &PHI) {
1476   if (!isSplitFatPtr(PHI.getType()))
1477     return {nullptr, nullptr};
1478   IRB.SetInsertPoint(*PHI.getInsertionPointAfterDef());
1479   // Phi nodes will be handled in post-processing after we've visited every
1480   // instruction. However, instead of just returning {nullptr, nullptr},
1481   // we explicitly create the temporary extractvalue operations that are our
1482   // temporary results so that they end up at the beginning of the block with
1483   // the PHIs.
1484   Value *TmpRsrc = IRB.CreateExtractValue(&PHI, 0, PHI.getName() + ".rsrc");
1485   Value *TmpOff = IRB.CreateExtractValue(&PHI, 1, PHI.getName() + ".off");
1486   Conditionals.push_back(&PHI);
1487   SplitUsers.insert(&PHI);
1488   return {TmpRsrc, TmpOff};
1489 }
1490 
visitSelectInst(SelectInst & SI)1491 PtrParts SplitPtrStructs::visitSelectInst(SelectInst &SI) {
1492   if (!isSplitFatPtr(SI.getType()))
1493     return {nullptr, nullptr};
1494   IRB.SetInsertPoint(&SI);
1495 
1496   Value *Cond = SI.getCondition();
1497   Value *True = SI.getTrueValue();
1498   Value *False = SI.getFalseValue();
1499   auto [TrueRsrc, TrueOff] = getPtrParts(True);
1500   auto [FalseRsrc, FalseOff] = getPtrParts(False);
1501 
1502   Value *RsrcRes =
1503       IRB.CreateSelect(Cond, TrueRsrc, FalseRsrc, SI.getName() + ".rsrc", &SI);
1504   copyMetadata(RsrcRes, &SI);
1505   Conditionals.push_back(&SI);
1506   Value *OffRes =
1507       IRB.CreateSelect(Cond, TrueOff, FalseOff, SI.getName() + ".off", &SI);
1508   copyMetadata(OffRes, &SI);
1509   SplitUsers.insert(&SI);
1510   return {RsrcRes, OffRes};
1511 }
1512 
1513 /// Returns true if this intrinsic needs to be removed when it is
1514 /// applied to `ptr addrspace(7)` values. Calls to these intrinsics are
1515 /// rewritten into calls to versions of that intrinsic on the resource
1516 /// descriptor.
isRemovablePointerIntrinsic(Intrinsic::ID IID)1517 static bool isRemovablePointerIntrinsic(Intrinsic::ID IID) {
1518   switch (IID) {
1519   default:
1520     return false;
1521   case Intrinsic::ptrmask:
1522   case Intrinsic::invariant_start:
1523   case Intrinsic::invariant_end:
1524   case Intrinsic::launder_invariant_group:
1525   case Intrinsic::strip_invariant_group:
1526     return true;
1527   }
1528 }
1529 
visitIntrinsicInst(IntrinsicInst & I)1530 PtrParts SplitPtrStructs::visitIntrinsicInst(IntrinsicInst &I) {
1531   Intrinsic::ID IID = I.getIntrinsicID();
1532   switch (IID) {
1533   default:
1534     break;
1535   case Intrinsic::ptrmask: {
1536     Value *Ptr = I.getArgOperand(0);
1537     if (!isSplitFatPtr(Ptr->getType()))
1538       return {nullptr, nullptr};
1539     Value *Mask = I.getArgOperand(1);
1540     IRB.SetInsertPoint(&I);
1541     auto [Rsrc, Off] = getPtrParts(Ptr);
1542     if (Mask->getType() != Off->getType())
1543       report_fatal_error("offset width is not equal to index width of fat "
1544                          "pointer (data layout not set up correctly?)");
1545     Value *OffRes = IRB.CreateAnd(Off, Mask, I.getName() + ".off");
1546     copyMetadata(OffRes, &I);
1547     SplitUsers.insert(&I);
1548     return {Rsrc, OffRes};
1549   }
1550   // Pointer annotation intrinsics that, given their object-wide nature
1551   // operate on the resource part.
1552   case Intrinsic::invariant_start: {
1553     Value *Ptr = I.getArgOperand(1);
1554     if (!isSplitFatPtr(Ptr->getType()))
1555       return {nullptr, nullptr};
1556     IRB.SetInsertPoint(&I);
1557     auto [Rsrc, Off] = getPtrParts(Ptr);
1558     Type *NewTy = PointerType::get(I.getContext(), AMDGPUAS::BUFFER_RESOURCE);
1559     auto *NewRsrc = IRB.CreateIntrinsic(IID, {NewTy}, {I.getOperand(0), Rsrc});
1560     copyMetadata(NewRsrc, &I);
1561     NewRsrc->takeName(&I);
1562     SplitUsers.insert(&I);
1563     I.replaceAllUsesWith(NewRsrc);
1564     return {nullptr, nullptr};
1565   }
1566   case Intrinsic::invariant_end: {
1567     Value *RealPtr = I.getArgOperand(2);
1568     if (!isSplitFatPtr(RealPtr->getType()))
1569       return {nullptr, nullptr};
1570     IRB.SetInsertPoint(&I);
1571     Value *RealRsrc = getPtrParts(RealPtr).first;
1572     Value *InvPtr = I.getArgOperand(0);
1573     Value *Size = I.getArgOperand(1);
1574     Value *NewRsrc = IRB.CreateIntrinsic(IID, {RealRsrc->getType()},
1575                                          {InvPtr, Size, RealRsrc});
1576     copyMetadata(NewRsrc, &I);
1577     NewRsrc->takeName(&I);
1578     SplitUsers.insert(&I);
1579     I.replaceAllUsesWith(NewRsrc);
1580     return {nullptr, nullptr};
1581   }
1582   case Intrinsic::launder_invariant_group:
1583   case Intrinsic::strip_invariant_group: {
1584     Value *Ptr = I.getArgOperand(0);
1585     if (!isSplitFatPtr(Ptr->getType()))
1586       return {nullptr, nullptr};
1587     IRB.SetInsertPoint(&I);
1588     auto [Rsrc, Off] = getPtrParts(Ptr);
1589     Value *NewRsrc = IRB.CreateIntrinsic(IID, {Rsrc->getType()}, {Rsrc});
1590     copyMetadata(NewRsrc, &I);
1591     NewRsrc->takeName(&I);
1592     SplitUsers.insert(&I);
1593     return {NewRsrc, Off};
1594   }
1595   }
1596   return {nullptr, nullptr};
1597 }
1598 
processFunction(Function & F)1599 void SplitPtrStructs::processFunction(Function &F) {
1600   ST = &TM->getSubtarget<GCNSubtarget>(F);
1601   SmallVector<Instruction *, 0> Originals;
1602   LLVM_DEBUG(dbgs() << "Splitting pointer structs in function: " << F.getName()
1603                     << "\n");
1604   for (Instruction &I : instructions(F))
1605     Originals.push_back(&I);
1606   for (Instruction *I : Originals) {
1607     auto [Rsrc, Off] = visit(I);
1608     assert(((Rsrc && Off) || (!Rsrc && !Off)) &&
1609            "Can't have a resource but no offset");
1610     if (Rsrc)
1611       RsrcParts[I] = Rsrc;
1612     if (Off)
1613       OffParts[I] = Off;
1614   }
1615   processConditionals();
1616   killAndReplaceSplitInstructions(Originals);
1617 
1618   // Clean up after ourselves to save on memory.
1619   RsrcParts.clear();
1620   OffParts.clear();
1621   SplitUsers.clear();
1622   Conditionals.clear();
1623   ConditionalTemps.clear();
1624 }
1625 
1626 namespace {
1627 class AMDGPULowerBufferFatPointers : public ModulePass {
1628 public:
1629   static char ID;
1630 
AMDGPULowerBufferFatPointers()1631   AMDGPULowerBufferFatPointers() : ModulePass(ID) {
1632     initializeAMDGPULowerBufferFatPointersPass(
1633         *PassRegistry::getPassRegistry());
1634   }
1635 
1636   bool run(Module &M, const TargetMachine &TM);
1637   bool runOnModule(Module &M) override;
1638 
1639   void getAnalysisUsage(AnalysisUsage &AU) const override;
1640 };
1641 } // namespace
1642 
1643 /// Returns true if there are values that have a buffer fat pointer in them,
1644 /// which means we'll need to perform rewrites on this function. As a side
1645 /// effect, this will populate the type remapping cache.
containsBufferFatPointers(const Function & F,BufferFatPtrToStructTypeMap * TypeMap)1646 static bool containsBufferFatPointers(const Function &F,
1647                                       BufferFatPtrToStructTypeMap *TypeMap) {
1648   bool HasFatPointers = false;
1649   for (const BasicBlock &BB : F)
1650     for (const Instruction &I : BB)
1651       HasFatPointers |= (I.getType() != TypeMap->remapType(I.getType()));
1652   return HasFatPointers;
1653 }
1654 
hasFatPointerInterface(const Function & F,BufferFatPtrToStructTypeMap * TypeMap)1655 static bool hasFatPointerInterface(const Function &F,
1656                                    BufferFatPtrToStructTypeMap *TypeMap) {
1657   Type *Ty = F.getFunctionType();
1658   return Ty != TypeMap->remapType(Ty);
1659 }
1660 
1661 /// Move the body of `OldF` into a new function, returning it.
moveFunctionAdaptingType(Function * OldF,FunctionType * NewTy,ValueToValueMapTy & CloneMap)1662 static Function *moveFunctionAdaptingType(Function *OldF, FunctionType *NewTy,
1663                                           ValueToValueMapTy &CloneMap) {
1664   bool IsIntrinsic = OldF->isIntrinsic();
1665   Function *NewF =
1666       Function::Create(NewTy, OldF->getLinkage(), OldF->getAddressSpace());
1667   NewF->IsNewDbgInfoFormat = OldF->IsNewDbgInfoFormat;
1668   NewF->copyAttributesFrom(OldF);
1669   NewF->copyMetadata(OldF, 0);
1670   NewF->takeName(OldF);
1671   NewF->updateAfterNameChange();
1672   NewF->setDLLStorageClass(OldF->getDLLStorageClass());
1673   OldF->getParent()->getFunctionList().insertAfter(OldF->getIterator(), NewF);
1674 
1675   while (!OldF->empty()) {
1676     BasicBlock *BB = &OldF->front();
1677     BB->removeFromParent();
1678     BB->insertInto(NewF);
1679     CloneMap[BB] = BB;
1680     for (Instruction &I : *BB) {
1681       CloneMap[&I] = &I;
1682     }
1683   }
1684 
1685   AttributeMask PtrOnlyAttrs;
1686   for (auto K :
1687        {Attribute::Dereferenceable, Attribute::DereferenceableOrNull,
1688         Attribute::NoAlias, Attribute::NoCapture, Attribute::NoFree,
1689         Attribute::NonNull, Attribute::NullPointerIsValid, Attribute::ReadNone,
1690         Attribute::ReadOnly, Attribute::WriteOnly}) {
1691     PtrOnlyAttrs.addAttribute(K);
1692   }
1693   SmallVector<AttributeSet> ArgAttrs;
1694   AttributeList OldAttrs = OldF->getAttributes();
1695 
1696   for (auto [I, OldArg, NewArg] : enumerate(OldF->args(), NewF->args())) {
1697     CloneMap[&NewArg] = &OldArg;
1698     NewArg.takeName(&OldArg);
1699     Type *OldArgTy = OldArg.getType(), *NewArgTy = NewArg.getType();
1700     // Temporarily mutate type of `NewArg` to allow RAUW to work.
1701     NewArg.mutateType(OldArgTy);
1702     OldArg.replaceAllUsesWith(&NewArg);
1703     NewArg.mutateType(NewArgTy);
1704 
1705     AttributeSet ArgAttr = OldAttrs.getParamAttrs(I);
1706     // Intrinsics get their attributes fixed later.
1707     if (OldArgTy != NewArgTy && !IsIntrinsic)
1708       ArgAttr = ArgAttr.removeAttributes(NewF->getContext(), PtrOnlyAttrs);
1709     ArgAttrs.push_back(ArgAttr);
1710   }
1711   AttributeSet RetAttrs = OldAttrs.getRetAttrs();
1712   if (OldF->getReturnType() != NewF->getReturnType() && !IsIntrinsic)
1713     RetAttrs = RetAttrs.removeAttributes(NewF->getContext(), PtrOnlyAttrs);
1714   NewF->setAttributes(AttributeList::get(
1715       NewF->getContext(), OldAttrs.getFnAttrs(), RetAttrs, ArgAttrs));
1716   return NewF;
1717 }
1718 
makeCloneInPraceMap(Function * F,ValueToValueMapTy & CloneMap)1719 static void makeCloneInPraceMap(Function *F, ValueToValueMapTy &CloneMap) {
1720   for (Argument &A : F->args())
1721     CloneMap[&A] = &A;
1722   for (BasicBlock &BB : *F) {
1723     CloneMap[&BB] = &BB;
1724     for (Instruction &I : BB)
1725       CloneMap[&I] = &I;
1726   }
1727 }
1728 
run(Module & M,const TargetMachine & TM)1729 bool AMDGPULowerBufferFatPointers::run(Module &M, const TargetMachine &TM) {
1730   bool Changed = false;
1731   const DataLayout &DL = M.getDataLayout();
1732   // Record the functions which need to be remapped.
1733   // The second element of the pair indicates whether the function has to have
1734   // its arguments or return types adjusted.
1735   SmallVector<std::pair<Function *, bool>> NeedsRemap;
1736 
1737   BufferFatPtrToStructTypeMap StructTM(DL);
1738   BufferFatPtrToIntTypeMap IntTM(DL);
1739   for (const GlobalVariable &GV : M.globals()) {
1740     if (GV.getAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER)
1741       report_fatal_error("Global variables with a buffer fat pointer address "
1742                          "space (7) are not supported");
1743     Type *VT = GV.getValueType();
1744     if (VT != StructTM.remapType(VT))
1745       report_fatal_error("Global variables that contain buffer fat pointers "
1746                          "(address space 7 pointers) are unsupported. Use "
1747                          "buffer resource pointers (address space 8) instead.");
1748   }
1749 
1750   {
1751     // Collect all constant exprs and aggregates referenced by any function.
1752     SmallVector<Constant *, 8> Worklist;
1753     for (Function &F : M.functions())
1754       for (Instruction &I : instructions(F))
1755         for (Value *Op : I.operands())
1756           if (isa<ConstantExpr>(Op) || isa<ConstantAggregate>(Op))
1757             Worklist.push_back(cast<Constant>(Op));
1758 
1759     // Recursively look for any referenced buffer pointer constants.
1760     SmallPtrSet<Constant *, 8> Visited;
1761     SetVector<Constant *> BufferFatPtrConsts;
1762     while (!Worklist.empty()) {
1763       Constant *C = Worklist.pop_back_val();
1764       if (!Visited.insert(C).second)
1765         continue;
1766       if (isBufferFatPtrOrVector(C->getType()))
1767         BufferFatPtrConsts.insert(C);
1768       for (Value *Op : C->operands())
1769         if (isa<ConstantExpr>(Op) || isa<ConstantAggregate>(Op))
1770           Worklist.push_back(cast<Constant>(Op));
1771     }
1772 
1773     // Expand all constant expressions using fat buffer pointers to
1774     // instructions.
1775     Changed |= convertUsersOfConstantsToInstructions(
1776         BufferFatPtrConsts.getArrayRef(), /*RestrictToFunc=*/nullptr,
1777         /*RemoveDeadConstants=*/false, /*IncludeSelf=*/true);
1778   }
1779 
1780   StoreFatPtrsAsIntsVisitor MemOpsRewrite(&IntTM, M.getContext());
1781   for (Function &F : M.functions()) {
1782     bool InterfaceChange = hasFatPointerInterface(F, &StructTM);
1783     bool BodyChanges = containsBufferFatPointers(F, &StructTM);
1784     Changed |= MemOpsRewrite.processFunction(F);
1785     if (InterfaceChange || BodyChanges)
1786       NeedsRemap.push_back(std::make_pair(&F, InterfaceChange));
1787   }
1788   if (NeedsRemap.empty())
1789     return Changed;
1790 
1791   SmallVector<Function *> NeedsPostProcess;
1792   SmallVector<Function *> Intrinsics;
1793   // Keep one big map so as to memoize constants across functions.
1794   ValueToValueMapTy CloneMap;
1795   FatPtrConstMaterializer Materializer(&StructTM, CloneMap);
1796 
1797   ValueMapper LowerInFuncs(CloneMap, RF_None, &StructTM, &Materializer);
1798   for (auto [F, InterfaceChange] : NeedsRemap) {
1799     Function *NewF = F;
1800     if (InterfaceChange)
1801       NewF = moveFunctionAdaptingType(
1802           F, cast<FunctionType>(StructTM.remapType(F->getFunctionType())),
1803           CloneMap);
1804     else
1805       makeCloneInPraceMap(F, CloneMap);
1806     LowerInFuncs.remapFunction(*NewF);
1807     if (NewF->isIntrinsic())
1808       Intrinsics.push_back(NewF);
1809     else
1810       NeedsPostProcess.push_back(NewF);
1811     if (InterfaceChange) {
1812       F->replaceAllUsesWith(NewF);
1813       F->eraseFromParent();
1814     }
1815     Changed = true;
1816   }
1817   StructTM.clear();
1818   IntTM.clear();
1819   CloneMap.clear();
1820 
1821   SplitPtrStructs Splitter(M.getContext(), &TM);
1822   for (Function *F : NeedsPostProcess)
1823     Splitter.processFunction(*F);
1824   for (Function *F : Intrinsics) {
1825     if (isRemovablePointerIntrinsic(F->getIntrinsicID())) {
1826       F->eraseFromParent();
1827     } else {
1828       std::optional<Function *> NewF = Intrinsic::remangleIntrinsicFunction(F);
1829       if (NewF)
1830         F->replaceAllUsesWith(*NewF);
1831     }
1832   }
1833   return Changed;
1834 }
1835 
runOnModule(Module & M)1836 bool AMDGPULowerBufferFatPointers::runOnModule(Module &M) {
1837   TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
1838   const TargetMachine &TM = TPC.getTM<TargetMachine>();
1839   return run(M, TM);
1840 }
1841 
// Storage for the pass ID that the constructor hands to ModulePass.
char AMDGPULowerBufferFatPointers::ID = 0;

// Publishes the pass ID in namespace llvm as a reference to the member above.
char &llvm::AMDGPULowerBufferFatPointersID = AMDGPULowerBufferFatPointers::ID;
1845 
// Declares TargetPassConfig as a required analysis; runOnModule() fetches it
// with getAnalysis<TargetPassConfig>() to obtain the TargetMachine.
void AMDGPULowerBufferFatPointers::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
}
1849 
// Pass registration boilerplate. PASS_DESC is a file-local helper macro for
// the description string and is undefined again right after use.
// NOTE(review): these macros presumably generate the
// initializeAMDGPULowerBufferFatPointersPass() function called from the
// constructor, and the two `false` arguments are presumably the usual
// "CFG only" / "is analysis" flags — confirm against the macro definitions.
#define PASS_DESC "Lower buffer fat pointer operations to buffer resources"
INITIALIZE_PASS_BEGIN(AMDGPULowerBufferFatPointers, DEBUG_TYPE, PASS_DESC,
                      false, false)
// Matches the addRequired<TargetPassConfig>() in getAnalysisUsage().
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(AMDGPULowerBufferFatPointers, DEBUG_TYPE, PASS_DESC, false,
                    false)
#undef PASS_DESC
1857 
1858 ModulePass *llvm::createAMDGPULowerBufferFatPointersPass() {
1859   return new AMDGPULowerBufferFatPointers();
1860 }
1861 
1862 PreservedAnalyses
run(Module & M,ModuleAnalysisManager & MA)1863 AMDGPULowerBufferFatPointersPass::run(Module &M, ModuleAnalysisManager &MA) {
1864   return AMDGPULowerBufferFatPointers().run(M, TM) ? PreservedAnalyses::none()
1865                                                    : PreservedAnalyses::all();
1866 }
1867